Support client cert passwords, new TLS options (#118)
* Support client cert passwords, new TLS options * Update test_config.py * Switch to try-except for post_handshake_auth=True SSLContext.post_handshake_auth raises AttributeError if the property is available but cannot be written to (needs OpenSSL 1.1.1+) * Also try-except for hostname_checks_common_name=False * Custom implementation of trustme.CA() that emits encrypted PKs * lint * Split name of test * Updates from review comments * Don't load default CAs yet
This commit is contained in:
parent
5442006a41
commit
ba83e29aee
@ -7,7 +7,7 @@ import certifi
|
||||
|
||||
from .__version__ import __version__
|
||||
|
||||
CertTypes = typing.Union[str, typing.Tuple[str, str]]
|
||||
CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
|
||||
VerifyTypes = typing.Union[str, bool]
|
||||
TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
|
||||
|
||||
@ -43,6 +43,8 @@ class SSLConfig:
|
||||
self.cert = cert
|
||||
self.verify = verify
|
||||
|
||||
self.ssl_context: typing.Optional[ssl.SSLContext] = None
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
@ -64,7 +66,7 @@ class SSLConfig:
|
||||
return SSLConfig(cert=cert, verify=verify)
|
||||
|
||||
async def load_ssl_context(self) -> ssl.SSLContext:
|
||||
if not hasattr(self, "ssl_context"):
|
||||
if self.ssl_context is None:
|
||||
if not self.verify:
|
||||
self.ssl_context = self.load_ssl_context_no_verify()
|
||||
else:
|
||||
@ -80,11 +82,9 @@ class SSLConfig:
|
||||
"""
|
||||
Return an SSL context for unverified connections.
|
||||
"""
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
context.options |= ssl.OP_NO_SSLv2
|
||||
context.options |= ssl.OP_NO_SSLv3
|
||||
context.options |= ssl.OP_NO_COMPRESSION
|
||||
context.set_default_verify_paths()
|
||||
context = self._create_default_ssl_context()
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
context.check_hostname = False
|
||||
return context
|
||||
|
||||
def load_ssl_context_verify(self) -> ssl.SSLContext:
|
||||
@ -101,20 +101,23 @@ class SSLConfig:
|
||||
"invalid path: {}".format(self.verify)
|
||||
)
|
||||
|
||||
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
|
||||
|
||||
context = self._create_default_ssl_context()
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
context.check_hostname = True
|
||||
|
||||
context.options |= ssl.OP_NO_SSLv2
|
||||
context.options |= ssl.OP_NO_SSLv3
|
||||
context.options |= ssl.OP_NO_COMPRESSION
|
||||
# Signal to server support for PHA in TLS 1.3. Raises an
|
||||
# AttributeError if only read-only access is implemented.
|
||||
try:
|
||||
context.post_handshake_auth = True
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
context.set_ciphers(DEFAULT_CIPHERS)
|
||||
|
||||
if ssl.HAS_ALPN:
|
||||
context.set_alpn_protocols(["h2", "http/1.1"])
|
||||
if ssl.HAS_NPN:
|
||||
context.set_npn_protocols(["h2", "http/1.1"])
|
||||
# Disable using 'commonName' for SSLContext.check_hostname
|
||||
# when the 'subjectAltName' extension isn't available.
|
||||
try:
|
||||
context.hostname_checks_common_name = False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if os.path.isfile(ca_bundle_path):
|
||||
context.load_verify_locations(cafile=ca_bundle_path)
|
||||
@ -124,8 +127,30 @@ class SSLConfig:
|
||||
if self.cert is not None:
|
||||
if isinstance(self.cert, str):
|
||||
context.load_cert_chain(certfile=self.cert)
|
||||
else:
|
||||
elif isinstance(self.cert, tuple) and len(self.cert) == 2:
|
||||
context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
|
||||
else:
|
||||
context.load_cert_chain(
|
||||
certfile=self.cert[0], keyfile=self.cert[1], password=self.cert[2]
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
def _create_default_ssl_context(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Creates the default SSLContext object that's used for both verified
|
||||
and unverified connections.
|
||||
"""
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
|
||||
context.options |= ssl.OP_NO_SSLv2
|
||||
context.options |= ssl.OP_NO_SSLv3
|
||||
context.options |= ssl.OP_NO_COMPRESSION
|
||||
context.set_ciphers(DEFAULT_CIPHERS)
|
||||
|
||||
if ssl.HAS_ALPN:
|
||||
context.set_alpn_protocols(["h2", "http/1.1"])
|
||||
if ssl.HAS_NPN:
|
||||
context.set_npn_protocols(["h2", "http/1.1"])
|
||||
|
||||
return context
|
||||
|
||||
@ -175,7 +200,7 @@ class TimeoutConfig:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
if len(set([self.connect_timeout, self.read_timeout, self.write_timeout])) == 1:
|
||||
if len({self.connect_timeout, self.read_timeout, self.write_timeout}) == 1:
|
||||
return f"{class_name}(timeout={self.connect_timeout})"
|
||||
return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})"
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ mkdocs-material
|
||||
# Testing
|
||||
autoflake
|
||||
black
|
||||
cryptography
|
||||
isort
|
||||
mypy
|
||||
pytest
|
||||
|
||||
@ -2,6 +2,11 @@ import asyncio
|
||||
|
||||
import pytest
|
||||
import trustme
|
||||
from cryptography.hazmat.primitives.serialization import (
|
||||
BestAvailableEncryption,
|
||||
Encoding,
|
||||
PrivateFormat,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.main import Server
|
||||
|
||||
@ -72,12 +77,45 @@ async def echo_body(scope, receive, send):
|
||||
await send({"type": "http.response.body", "body": body})
|
||||
|
||||
|
||||
class CAWithPKEncryption(trustme.CA):
|
||||
"""Implementation of trustme.CA() that can emit
|
||||
private keys that are encrypted with a password.
|
||||
"""
|
||||
|
||||
@property
|
||||
def encrypted_private_key_pem(self):
|
||||
return trustme.Blob(
|
||||
self._private_key.private_bytes(
|
||||
Encoding.PEM,
|
||||
PrivateFormat.TraditionalOpenSSL,
|
||||
BestAvailableEncryption(password=b"password"),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cert_and_key_paths():
|
||||
ca = trustme.CA()
|
||||
def example_cert():
|
||||
ca = CAWithPKEncryption()
|
||||
ca.issue_cert("example.org")
|
||||
with ca.cert_pem.tempfile() as cert_temp_path, ca.private_key_pem.tempfile() as key_temp_path:
|
||||
yield cert_temp_path, key_temp_path
|
||||
return ca
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cert_pem_file(example_cert):
|
||||
with example_cert.cert_pem.tempfile() as tmp:
|
||||
yield tmp
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cert_private_key_file(example_cert):
|
||||
with example_cert.private_key_pem.tempfile() as tmp:
|
||||
yield tmp
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cert_encrypted_private_key_file(example_cert):
|
||||
with example_cert.encrypted_private_key_pem.tempfile() as tmp:
|
||||
yield tmp
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -95,10 +133,13 @@ async def server():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def https_server(cert_and_key_paths):
|
||||
cert_path, key_path = cert_and_key_paths
|
||||
async def https_server(cert_pem_file, cert_private_key_file):
|
||||
config = Config(
|
||||
app=app, lifespan="off", ssl_certfile=cert_path, ssl_keyfile=key_path, port=8001
|
||||
app=app,
|
||||
lifespan="off",
|
||||
ssl_certfile=cert_pem_file,
|
||||
ssl_keyfile=cert_private_key_file,
|
||||
port=8001,
|
||||
)
|
||||
server = Server(config=config)
|
||||
task = asyncio.ensure_future(server.serve())
|
||||
|
||||
@ -11,6 +11,7 @@ async def test_load_ssl_config():
|
||||
ssl_config = http3.SSLConfig()
|
||||
context = await ssl_config.load_ssl_context()
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -25,6 +26,7 @@ async def test_load_ssl_config_verify_existing_file():
|
||||
ssl_config = http3.SSLConfig(verify=http3.config.DEFAULT_CA_BUNDLE_PATH)
|
||||
context = await ssl_config.load_ssl_context()
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -33,29 +35,55 @@ async def test_load_ssl_config_verify_directory():
|
||||
ssl_config = http3.SSLConfig(verify=path)
|
||||
context = await ssl_config.load_ssl_context()
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_ssl_config_cert_and_key(cert_and_key_paths):
|
||||
cert_path, key_path = cert_and_key_paths
|
||||
ssl_config = http3.SSLConfig(cert=(cert_path, key_path))
|
||||
async def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file):
|
||||
ssl_config = http3.SSLConfig(cert=(cert_pem_file, cert_private_key_file))
|
||||
context = await ssl_config.load_ssl_context()
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_ssl_config_cert_without_key_raises(cert_and_key_paths):
|
||||
cert_path, _ = cert_and_key_paths
|
||||
ssl_config = http3.SSLConfig(cert=cert_path)
|
||||
@pytest.mark.parametrize("password", [b"password", "password"])
|
||||
async def test_load_ssl_config_cert_and_encrypted_key(
|
||||
cert_pem_file, cert_encrypted_private_key_file, password
|
||||
):
|
||||
ssl_config = http3.SSLConfig(
|
||||
cert=(cert_pem_file, cert_encrypted_private_key_file, password)
|
||||
)
|
||||
context = await ssl_config.load_ssl_context()
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_ssl_config_cert_and_key_invalid_password(
|
||||
cert_pem_file, cert_encrypted_private_key_file
|
||||
):
|
||||
ssl_config = http3.SSLConfig(
|
||||
cert=(cert_pem_file, cert_encrypted_private_key_file, "password1")
|
||||
)
|
||||
|
||||
with pytest.raises(ssl.SSLError):
|
||||
await ssl_config.load_ssl_context()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_ssl_config_no_verify(verify=False):
|
||||
async def test_load_ssl_config_cert_without_key_raises(cert_pem_file):
|
||||
ssl_config = http3.SSLConfig(cert=cert_pem_file)
|
||||
with pytest.raises(ssl.SSLError):
|
||||
await ssl_config.load_ssl_context()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_ssl_config_no_verify():
|
||||
ssl_config = http3.SSLConfig(verify=False)
|
||||
context = await ssl_config.load_ssl_context()
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_NONE
|
||||
assert context.check_hostname is False
|
||||
|
||||
|
||||
def test_ssl_repr():
|
||||
@ -102,5 +130,5 @@ def test_timeout_from_tuple():
|
||||
|
||||
|
||||
def test_timeout_from_config_instance():
|
||||
timeout = http3.TimeoutConfig(timeout=(5.0))
|
||||
timeout = http3.TimeoutConfig(timeout=5.0)
|
||||
assert http3.TimeoutConfig(timeout) == http3.TimeoutConfig(timeout=5.0)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user