Accept SSLContext into SSLConfig(verify=...) (#215)
This commit is contained in:
parent
af907e85e1
commit
df8874b733
@ -7,7 +7,7 @@ import certifi
|
||||
from .__version__ import __version__
|
||||
|
||||
CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
|
||||
VerifyTypes = typing.Union[str, bool]
|
||||
VerifyTypes = typing.Union[str, bool, ssl.SSLContext]
|
||||
TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
|
||||
|
||||
|
||||
@ -40,9 +40,17 @@ class SSLConfig:
|
||||
|
||||
def __init__(self, *, cert: CertTypes = None, verify: VerifyTypes = True):
|
||||
self.cert = cert
|
||||
self.verify = verify
|
||||
|
||||
self.ssl_context: typing.Optional[ssl.SSLContext] = None
|
||||
# Allow passing in our own SSLContext object that's pre-configured.
|
||||
# If you do this we assume that you want verify=True as well.
|
||||
ssl_context = None
|
||||
if isinstance(verify, ssl.SSLContext):
|
||||
ssl_context = verify
|
||||
verify = True
|
||||
self._load_client_certs(ssl_context)
|
||||
|
||||
self.ssl_context: typing.Optional[ssl.SSLContext] = ssl_context
|
||||
self.verify: typing.Union[str, bool] = verify
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
@ -121,17 +129,7 @@ class SSLConfig:
|
||||
elif ca_bundle_path.is_dir():
|
||||
context.load_verify_locations(capath=str(ca_bundle_path))
|
||||
|
||||
if self.cert is not None:
|
||||
if isinstance(self.cert, str):
|
||||
context.load_cert_chain(certfile=self.cert)
|
||||
elif isinstance(self.cert, tuple) and len(self.cert) == 2:
|
||||
context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
|
||||
elif isinstance(self.cert, tuple) and len(self.cert) == 3:
|
||||
context.load_cert_chain(
|
||||
certfile=self.cert[0],
|
||||
keyfile=self.cert[1],
|
||||
password=self.cert[2], # type: ignore
|
||||
)
|
||||
self._load_client_certs(context)
|
||||
|
||||
return context
|
||||
|
||||
@ -155,6 +153,22 @@ class SSLConfig:
|
||||
|
||||
return context
|
||||
|
||||
def _load_client_certs(self, ssl_context: ssl.SSLContext) -> None:
|
||||
"""
|
||||
Loads client certificates into our SSLContext object
|
||||
"""
|
||||
if self.cert is not None:
|
||||
if isinstance(self.cert, str):
|
||||
ssl_context.load_cert_chain(certfile=self.cert)
|
||||
elif isinstance(self.cert, tuple) and len(self.cert) == 2:
|
||||
ssl_context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
|
||||
elif isinstance(self.cert, tuple) and len(self.cert) == 3:
|
||||
ssl_context.load_cert_chain(
|
||||
certfile=self.cert[0],
|
||||
keyfile=self.cert[1],
|
||||
password=self.cert[2], # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class TimeoutConfig:
|
||||
"""
|
||||
|
||||
@ -76,6 +76,15 @@ def test_load_ssl_config_no_verify():
|
||||
assert context.check_hostname is False
|
||||
|
||||
|
||||
def test_load_ssl_context():
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_config = httpx.SSLConfig(verify=ssl_context)
|
||||
|
||||
assert ssl_config.verify is True
|
||||
assert ssl_config.ssl_context is ssl_context
|
||||
assert repr(ssl_config) == "SSLConfig(cert=None, verify=True)"
|
||||
|
||||
|
||||
def test_ssl_repr():
|
||||
ssl = httpx.SSLConfig(verify=False)
|
||||
assert repr(ssl) == "SSLConfig(cert=None, verify=False)"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user