Support client cert passwords, new TLS options (#118)

* Support client cert passwords, new TLS options

* Update test_config.py

* Switch to try-except for post_handshake_auth=True

SSLContext.post_handshake_auth raises AttributeError if the property is available but cannot be written to (needs OpenSSL 1.1.1+)

* Also try-except for hostname_checks_common_name=False

* Custom implementation of trustme.CA() that emits encrypted PKs

* lint

* Split name of test

* Updates from review comments

* Don't load default CAs yet
This commit is contained in:
Seth Michael Larson 2019-07-16 04:08:57 -05:00 committed by Tom Christie
parent 5442006a41
commit ba83e29aee
4 changed files with 130 additions and 35 deletions

View File

@ -7,7 +7,7 @@ import certifi
from .__version__ import __version__
CertTypes = typing.Union[str, typing.Tuple[str, str]]
CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
VerifyTypes = typing.Union[str, bool]
TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
@ -43,6 +43,8 @@ class SSLConfig:
self.cert = cert
self.verify = verify
self.ssl_context: typing.Optional[ssl.SSLContext] = None
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, self.__class__)
@ -64,7 +66,7 @@ class SSLConfig:
return SSLConfig(cert=cert, verify=verify)
async def load_ssl_context(self) -> ssl.SSLContext:
if not hasattr(self, "ssl_context"):
if self.ssl_context is None:
if not self.verify:
self.ssl_context = self.load_ssl_context_no_verify()
else:
@ -80,11 +82,9 @@ class SSLConfig:
"""
Return an SSL context for unverified connections.
"""
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_COMPRESSION
context.set_default_verify_paths()
context = self._create_default_ssl_context()
context.verify_mode = ssl.CERT_NONE
context.check_hostname = False
return context
def load_ssl_context_verify(self) -> ssl.SSLContext:
@ -101,20 +101,23 @@ class SSLConfig:
"invalid path: {}".format(self.verify)
)
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
context = self._create_default_ssl_context()
context.verify_mode = ssl.CERT_REQUIRED
context.check_hostname = True
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_COMPRESSION
# Signal to server support for PHA in TLS 1.3. Raises an
# AttributeError if only read-only access is implemented.
try:
context.post_handshake_auth = True
except AttributeError:
pass
context.set_ciphers(DEFAULT_CIPHERS)
if ssl.HAS_ALPN:
context.set_alpn_protocols(["h2", "http/1.1"])
if ssl.HAS_NPN:
context.set_npn_protocols(["h2", "http/1.1"])
# Disable using 'commonName' for SSLContext.check_hostname
# when the 'subjectAltName' extension isn't available.
try:
context.hostname_checks_common_name = False
except AttributeError:
pass
if os.path.isfile(ca_bundle_path):
context.load_verify_locations(cafile=ca_bundle_path)
@ -124,8 +127,30 @@ class SSLConfig:
if self.cert is not None:
if isinstance(self.cert, str):
context.load_cert_chain(certfile=self.cert)
else:
elif isinstance(self.cert, tuple) and len(self.cert) == 2:
context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
else:
context.load_cert_chain(
certfile=self.cert[0], keyfile=self.cert[1], password=self.cert[2]
)
return context
def _create_default_ssl_context(self) -> ssl.SSLContext:
"""
Creates the default SSLContext object that's used for both verified
and unverified connections.
"""
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_COMPRESSION
context.set_ciphers(DEFAULT_CIPHERS)
if ssl.HAS_ALPN:
context.set_alpn_protocols(["h2", "http/1.1"])
if ssl.HAS_NPN:
context.set_npn_protocols(["h2", "http/1.1"])
return context
@ -175,7 +200,7 @@ class TimeoutConfig:
def __repr__(self) -> str:
class_name = self.__class__.__name__
if len(set([self.connect_timeout, self.read_timeout, self.write_timeout])) == 1:
if len({self.connect_timeout, self.read_timeout, self.write_timeout}) == 1:
return f"{class_name}(timeout={self.connect_timeout})"
return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})"

View File

@ -15,6 +15,7 @@ mkdocs-material
# Testing
autoflake
black
cryptography
isort
mypy
pytest

View File

@ -2,6 +2,11 @@ import asyncio
import pytest
import trustme
from cryptography.hazmat.primitives.serialization import (
BestAvailableEncryption,
Encoding,
PrivateFormat,
)
from uvicorn.config import Config
from uvicorn.main import Server
@ -72,12 +77,45 @@ async def echo_body(scope, receive, send):
await send({"type": "http.response.body", "body": body})
class CAWithPKEncryption(trustme.CA):
"""Implementation of trustme.CA() that can emit
private keys that are encrypted with a password.
"""
@property
def encrypted_private_key_pem(self):
return trustme.Blob(
self._private_key.private_bytes(
Encoding.PEM,
PrivateFormat.TraditionalOpenSSL,
BestAvailableEncryption(password=b"password"),
)
)
@pytest.fixture
def cert_and_key_paths():
ca = trustme.CA()
def example_cert():
ca = CAWithPKEncryption()
ca.issue_cert("example.org")
with ca.cert_pem.tempfile() as cert_temp_path, ca.private_key_pem.tempfile() as key_temp_path:
yield cert_temp_path, key_temp_path
return ca
@pytest.fixture
def cert_pem_file(example_cert):
with example_cert.cert_pem.tempfile() as tmp:
yield tmp
@pytest.fixture
def cert_private_key_file(example_cert):
with example_cert.private_key_pem.tempfile() as tmp:
yield tmp
@pytest.fixture
def cert_encrypted_private_key_file(example_cert):
with example_cert.encrypted_private_key_pem.tempfile() as tmp:
yield tmp
@pytest.fixture
@ -95,10 +133,13 @@ async def server():
@pytest.fixture
async def https_server(cert_and_key_paths):
cert_path, key_path = cert_and_key_paths
async def https_server(cert_pem_file, cert_private_key_file):
config = Config(
app=app, lifespan="off", ssl_certfile=cert_path, ssl_keyfile=key_path, port=8001
app=app,
lifespan="off",
ssl_certfile=cert_pem_file,
ssl_keyfile=cert_private_key_file,
port=8001,
)
server = Server(config=config)
task = asyncio.ensure_future(server.serve())

View File

@ -11,6 +11,7 @@ async def test_load_ssl_config():
ssl_config = http3.SSLConfig()
context = await ssl_config.load_ssl_context()
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
@pytest.mark.asyncio
@ -25,6 +26,7 @@ async def test_load_ssl_config_verify_existing_file():
ssl_config = http3.SSLConfig(verify=http3.config.DEFAULT_CA_BUNDLE_PATH)
context = await ssl_config.load_ssl_context()
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
@pytest.mark.asyncio
@ -33,29 +35,55 @@ async def test_load_ssl_config_verify_directory():
ssl_config = http3.SSLConfig(verify=path)
context = await ssl_config.load_ssl_context()
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
@pytest.mark.asyncio
async def test_load_ssl_config_cert_and_key(cert_and_key_paths):
cert_path, key_path = cert_and_key_paths
ssl_config = http3.SSLConfig(cert=(cert_path, key_path))
async def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file):
ssl_config = http3.SSLConfig(cert=(cert_pem_file, cert_private_key_file))
context = await ssl_config.load_ssl_context()
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
@pytest.mark.asyncio
async def test_load_ssl_config_cert_without_key_raises(cert_and_key_paths):
cert_path, _ = cert_and_key_paths
ssl_config = http3.SSLConfig(cert=cert_path)
@pytest.mark.parametrize("password", [b"password", "password"])
async def test_load_ssl_config_cert_and_encrypted_key(
cert_pem_file, cert_encrypted_private_key_file, password
):
ssl_config = http3.SSLConfig(
cert=(cert_pem_file, cert_encrypted_private_key_file, password)
)
context = await ssl_config.load_ssl_context()
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
@pytest.mark.asyncio
async def test_load_ssl_config_cert_and_key_invalid_password(
cert_pem_file, cert_encrypted_private_key_file
):
ssl_config = http3.SSLConfig(
cert=(cert_pem_file, cert_encrypted_private_key_file, "password1")
)
with pytest.raises(ssl.SSLError):
await ssl_config.load_ssl_context()
@pytest.mark.asyncio
async def test_load_ssl_config_no_verify(verify=False):
async def test_load_ssl_config_cert_without_key_raises(cert_pem_file):
ssl_config = http3.SSLConfig(cert=cert_pem_file)
with pytest.raises(ssl.SSLError):
await ssl_config.load_ssl_context()
@pytest.mark.asyncio
async def test_load_ssl_config_no_verify():
ssl_config = http3.SSLConfig(verify=False)
context = await ssl_config.load_ssl_context()
assert context.verify_mode == ssl.VerifyMode.CERT_NONE
assert context.check_hostname is False
def test_ssl_repr():
@ -102,5 +130,5 @@ def test_timeout_from_tuple():
def test_timeout_from_config_instance():
timeout = http3.TimeoutConfig(timeout=(5.0))
timeout = http3.TimeoutConfig(timeout=5.0)
assert http3.TimeoutConfig(timeout) == http3.TimeoutConfig(timeout=5.0)