Accept SSLContext into SSLConfig(verify=...) (#215)

This commit is contained in:
Seth Michael Larson 2019-08-14 22:30:02 -05:00 committed by GitHub
parent af907e85e1
commit df8874b733
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 14 deletions

View File

@ -7,7 +7,7 @@ import certifi
from .__version__ import __version__
CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
VerifyTypes = typing.Union[str, bool]
VerifyTypes = typing.Union[str, bool, ssl.SSLContext]
TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
@ -40,9 +40,17 @@ class SSLConfig:
def __init__(self, *, cert: CertTypes = None, verify: VerifyTypes = True):
self.cert = cert
self.verify = verify
self.ssl_context: typing.Optional[ssl.SSLContext] = None
# Allow passing in our own SSLContext object that's pre-configured.
# If you do this we assume that you want verify=True as well.
ssl_context = None
if isinstance(verify, ssl.SSLContext):
ssl_context = verify
verify = True
self._load_client_certs(ssl_context)
self.ssl_context: typing.Optional[ssl.SSLContext] = ssl_context
self.verify: typing.Union[str, bool] = verify
def __eq__(self, other: typing.Any) -> bool:
return (
@ -121,17 +129,7 @@ class SSLConfig:
elif ca_bundle_path.is_dir():
context.load_verify_locations(capath=str(ca_bundle_path))
if self.cert is not None:
if isinstance(self.cert, str):
context.load_cert_chain(certfile=self.cert)
elif isinstance(self.cert, tuple) and len(self.cert) == 2:
context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
elif isinstance(self.cert, tuple) and len(self.cert) == 3:
context.load_cert_chain(
certfile=self.cert[0],
keyfile=self.cert[1],
password=self.cert[2], # type: ignore
)
self._load_client_certs(context)
return context
@ -155,6 +153,22 @@ class SSLConfig:
return context
def _load_client_certs(self, ssl_context: ssl.SSLContext) -> None:
"""
Loads client certificates into our SSLContext object
"""
if self.cert is not None:
if isinstance(self.cert, str):
ssl_context.load_cert_chain(certfile=self.cert)
elif isinstance(self.cert, tuple) and len(self.cert) == 2:
ssl_context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
elif isinstance(self.cert, tuple) and len(self.cert) == 3:
ssl_context.load_cert_chain(
certfile=self.cert[0],
keyfile=self.cert[1],
password=self.cert[2], # type: ignore
)
class TimeoutConfig:
"""

View File

@ -76,6 +76,15 @@ def test_load_ssl_config_no_verify():
assert context.check_hostname is False
def test_load_ssl_context():
ssl_context = ssl.create_default_context()
ssl_config = httpx.SSLConfig(verify=ssl_context)
assert ssl_config.verify is True
assert ssl_config.ssl_context is ssl_context
assert repr(ssl_config) == "SSLConfig(cert=None, verify=True)"
def test_ssl_repr():
ssl = httpx.SSLConfig(verify=False)
assert repr(ssl) == "SSLConfig(cert=None, verify=False)"