Load SSL Context on init (#709)
This commit is contained in:
parent
f5eaec7ab3
commit
79a9748ae6
@ -63,19 +63,10 @@ class SSLConfig:
|
||||
http2: bool = False,
|
||||
):
|
||||
self.cert = cert
|
||||
|
||||
# Allow passing in our own SSLContext object that's pre-configured.
|
||||
# If you do this we assume that you want verify=True as well.
|
||||
ssl_context = None
|
||||
if isinstance(verify, ssl.SSLContext):
|
||||
ssl_context = verify
|
||||
verify = True
|
||||
self._load_client_certs(ssl_context)
|
||||
|
||||
self.ssl_context: typing.Optional[ssl.SSLContext] = ssl_context
|
||||
self.verify: typing.Union[str, bool] = verify
|
||||
self.verify = verify
|
||||
self.trust_env = trust_env
|
||||
self.http2 = http2
|
||||
self.ssl_context = self.load_ssl_context()
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
@ -97,15 +88,9 @@ class SSLConfig:
|
||||
f"http2={self.http2!r}"
|
||||
)
|
||||
|
||||
if self.ssl_context is None:
|
||||
self.ssl_context = (
|
||||
self.load_ssl_context_verify()
|
||||
if self.verify
|
||||
else self.load_ssl_context_no_verify()
|
||||
)
|
||||
|
||||
assert self.ssl_context is not None
|
||||
return self.ssl_context
|
||||
if self.verify:
|
||||
return self.load_ssl_context_verify()
|
||||
return self.load_ssl_context_no_verify()
|
||||
|
||||
def load_ssl_context_no_verify(self) -> ssl.SSLContext:
|
||||
"""
|
||||
@ -125,7 +110,12 @@ class SSLConfig:
|
||||
if ca_bundle is not None:
|
||||
self.verify = ca_bundle # type: ignore
|
||||
|
||||
if isinstance(self.verify, bool):
|
||||
if isinstance(self.verify, ssl.SSLContext):
|
||||
# Allow passing in our own SSLContext object that's pre-configured.
|
||||
context = self.verify
|
||||
self._load_client_certs(context)
|
||||
return context
|
||||
elif isinstance(self.verify, bool):
|
||||
ca_bundle_path = self.DEFAULT_CA_BUNDLE_PATH
|
||||
elif Path(self.verify).exists():
|
||||
ca_bundle_path = Path(self.verify)
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import functools
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
import h11
|
||||
@ -49,7 +48,7 @@ class HTTPConnection(Dispatcher):
|
||||
) -> typing.Union[HTTP11Connection, HTTP2Connection]:
|
||||
host = self.origin.host
|
||||
port = self.origin.port
|
||||
ssl_context = self.get_ssl_context()
|
||||
ssl_context = None if not self.origin.is_ssl else self.ssl.ssl_context
|
||||
|
||||
if self.release_func is None:
|
||||
on_release = None
|
||||
@ -104,8 +103,7 @@ class HTTPConnection(Dispatcher):
|
||||
if origin.is_ssl:
|
||||
# Pull the socket stream off the internal HTTP connection object,
|
||||
# and run start_tls().
|
||||
ssl_context = self.get_ssl_context()
|
||||
assert ssl_context is not None
|
||||
ssl_context = self.ssl.ssl_context
|
||||
|
||||
logger.trace(f"tunnel_start_tls proxy_url={proxy_url!r} origin={origin!r}")
|
||||
socket = await socket.start_tls(
|
||||
@ -130,11 +128,6 @@ class HTTPConnection(Dispatcher):
|
||||
else:
|
||||
self.connection = HTTP11Connection(socket, on_release=on_release)
|
||||
|
||||
def get_ssl_context(self) -> typing.Optional[ssl.SSLContext]:
|
||||
if not self.origin.is_ssl:
|
||||
return None
|
||||
return self.ssl.load_ssl_context()
|
||||
|
||||
async def close(self) -> None:
|
||||
logger.trace("close_connection")
|
||||
if self.connection is not None:
|
||||
|
||||
@ -13,20 +13,19 @@ from httpx.config import SSLConfig
|
||||
|
||||
def test_load_ssl_config():
|
||||
ssl_config = SSLConfig()
|
||||
context = ssl_config.load_ssl_context()
|
||||
context = ssl_config.ssl_context
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
|
||||
def test_load_ssl_config_verify_non_existing_path():
|
||||
ssl_config = SSLConfig(verify="/path/to/nowhere")
|
||||
with pytest.raises(IOError):
|
||||
ssl_config.load_ssl_context()
|
||||
SSLConfig(verify="/path/to/nowhere")
|
||||
|
||||
|
||||
def test_load_ssl_config_verify_existing_file():
|
||||
ssl_config = SSLConfig(verify=certifi.where())
|
||||
context = ssl_config.load_ssl_context()
|
||||
context = ssl_config.ssl_context
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
@ -39,7 +38,7 @@ def test_load_ssl_config_verify_env_file(https_server, ca_cert_pem_file, config)
|
||||
else str(Path(ca_cert_pem_file).parent)
|
||||
)
|
||||
ssl_config = SSLConfig(trust_env=True)
|
||||
context = ssl_config.load_ssl_context()
|
||||
context = ssl_config.ssl_context
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
assert ssl_config.verify == os.environ[config]
|
||||
@ -58,14 +57,14 @@ def test_load_ssl_config_verify_env_file(https_server, ca_cert_pem_file, config)
|
||||
def test_load_ssl_config_verify_directory():
|
||||
path = Path(certifi.where()).parent
|
||||
ssl_config = SSLConfig(verify=path)
|
||||
context = ssl_config.load_ssl_context()
|
||||
context = ssl_config.ssl_context
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
|
||||
def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file):
|
||||
ssl_config = SSLConfig(cert=(cert_pem_file, cert_private_key_file))
|
||||
context = ssl_config.load_ssl_context()
|
||||
context = ssl_config.ssl_context
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
@ -77,7 +76,7 @@ def test_load_ssl_config_cert_and_encrypted_key(
|
||||
ssl_config = SSLConfig(
|
||||
cert=(cert_pem_file, cert_encrypted_private_key_file, password)
|
||||
)
|
||||
context = ssl_config.load_ssl_context()
|
||||
context = ssl_config.ssl_context
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
@ -85,23 +84,18 @@ def test_load_ssl_config_cert_and_encrypted_key(
|
||||
def test_load_ssl_config_cert_and_key_invalid_password(
|
||||
cert_pem_file, cert_encrypted_private_key_file
|
||||
):
|
||||
ssl_config = SSLConfig(
|
||||
cert=(cert_pem_file, cert_encrypted_private_key_file, "password1")
|
||||
)
|
||||
|
||||
with pytest.raises(ssl.SSLError):
|
||||
ssl_config.load_ssl_context()
|
||||
SSLConfig(cert=(cert_pem_file, cert_encrypted_private_key_file, "password1"))
|
||||
|
||||
|
||||
def test_load_ssl_config_cert_without_key_raises(cert_pem_file):
|
||||
ssl_config = SSLConfig(cert=cert_pem_file)
|
||||
with pytest.raises(ssl.SSLError):
|
||||
ssl_config.load_ssl_context()
|
||||
SSLConfig(cert=cert_pem_file)
|
||||
|
||||
|
||||
def test_load_ssl_config_no_verify():
|
||||
ssl_config = SSLConfig(verify=False)
|
||||
context = ssl_config.load_ssl_context()
|
||||
context = ssl_config.ssl_context
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_NONE
|
||||
assert context.check_hostname is False
|
||||
|
||||
@ -110,9 +104,7 @@ def test_load_ssl_context():
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_config = SSLConfig(verify=ssl_context)
|
||||
|
||||
assert ssl_config.verify is True
|
||||
assert ssl_config.ssl_context is ssl_context
|
||||
assert repr(ssl_config) == "SSLConfig(cert=None, verify=True)"
|
||||
|
||||
|
||||
def test_ssl_repr():
|
||||
@ -199,7 +191,6 @@ def test_ssl_config_support_for_keylog_file(tmpdir, monkeypatch): # pragma: noc
|
||||
m.delenv("SSLKEYLOGFILE", raising=False)
|
||||
|
||||
ssl_config = SSLConfig(trust_env=True)
|
||||
ssl_config.load_ssl_context()
|
||||
|
||||
assert ssl_config.ssl_context.keylog_filename is None
|
||||
|
||||
@ -209,11 +200,9 @@ def test_ssl_config_support_for_keylog_file(tmpdir, monkeypatch): # pragma: noc
|
||||
m.setenv("SSLKEYLOGFILE", filename)
|
||||
|
||||
ssl_config = SSLConfig(trust_env=True)
|
||||
ssl_config.load_ssl_context()
|
||||
|
||||
assert ssl_config.ssl_context.keylog_filename == filename
|
||||
|
||||
ssl_config = SSLConfig(trust_env=False)
|
||||
ssl_config.load_ssl_context()
|
||||
|
||||
assert ssl_config.ssl_context.keylog_filename is None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user