SSLConfig refactor (#706)
* SSLConfig includes 'http2' argument on init. * Pass SSL config to HTTPConnection as a single argument * Don't run SSL context loading in threadpool
This commit is contained in:
parent
a9f4d018e1
commit
b0bf2a7513
@ -60,6 +60,7 @@ class SSLConfig:
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = True,
|
||||
trust_env: bool = None,
|
||||
http2: bool = False,
|
||||
):
|
||||
self.cert = cert
|
||||
|
||||
@ -74,6 +75,7 @@ class SSLConfig:
|
||||
self.ssl_context: typing.Optional[ssl.SSLContext] = ssl_context
|
||||
self.verify: typing.Union[str, bool] = verify
|
||||
self.trust_env = trust_env
|
||||
self.http2 = http2
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
@ -86,35 +88,35 @@ class SSLConfig:
|
||||
class_name = self.__class__.__name__
|
||||
return f"{class_name}(cert={self.cert}, verify={self.verify})"
|
||||
|
||||
def load_ssl_context(self, http2: bool = False) -> ssl.SSLContext:
|
||||
def load_ssl_context(self) -> ssl.SSLContext:
|
||||
logger.trace(
|
||||
f"load_ssl_context "
|
||||
f"verify={self.verify!r} "
|
||||
f"cert={self.cert!r} "
|
||||
f"trust_env={self.trust_env!r} "
|
||||
f"http2={http2!r}"
|
||||
f"http2={self.http2!r}"
|
||||
)
|
||||
|
||||
if self.ssl_context is None:
|
||||
self.ssl_context = (
|
||||
self.load_ssl_context_verify(http2=http2)
|
||||
self.load_ssl_context_verify()
|
||||
if self.verify
|
||||
else self.load_ssl_context_no_verify(http2=http2)
|
||||
else self.load_ssl_context_no_verify()
|
||||
)
|
||||
|
||||
assert self.ssl_context is not None
|
||||
return self.ssl_context
|
||||
|
||||
def load_ssl_context_no_verify(self, http2: bool = False) -> ssl.SSLContext:
|
||||
def load_ssl_context_no_verify(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Return an SSL context for unverified connections.
|
||||
"""
|
||||
context = self._create_default_ssl_context(http2=http2)
|
||||
context = self._create_default_ssl_context()
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
context.check_hostname = False
|
||||
return context
|
||||
|
||||
def load_ssl_context_verify(self, http2: bool = False) -> ssl.SSLContext:
|
||||
def load_ssl_context_verify(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Return an SSL context for verified connections.
|
||||
"""
|
||||
@ -133,7 +135,7 @@ class SSLConfig:
|
||||
"invalid path: {}".format(self.verify)
|
||||
)
|
||||
|
||||
context = self._create_default_ssl_context(http2=http2)
|
||||
context = self._create_default_ssl_context()
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
context.check_hostname = True
|
||||
|
||||
@ -162,7 +164,7 @@ class SSLConfig:
|
||||
|
||||
return context
|
||||
|
||||
def _create_default_ssl_context(self, http2: bool) -> ssl.SSLContext:
|
||||
def _create_default_ssl_context(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Creates the default SSLContext object that's used for both verified
|
||||
and unverified connections.
|
||||
@ -176,7 +178,7 @@ class SSLConfig:
|
||||
context.set_ciphers(DEFAULT_CIPHERS)
|
||||
|
||||
if ssl.HAS_ALPN:
|
||||
alpn_idents = ["http/1.1", "h2"] if http2 else ["http/1.1"]
|
||||
alpn_idents = ["http/1.1", "h2"] if self.http2 else ["http/1.1"]
|
||||
context.set_alpn_protocols(alpn_idents)
|
||||
|
||||
if hasattr(context, "keylog_filename"): # pragma: nocover (Available in 3.8+)
|
||||
|
||||
@ -5,7 +5,7 @@ import typing
|
||||
import h11
|
||||
|
||||
from ..backends.base import ConcurrencyBackend, lookup_backend
|
||||
from ..config import CertTypes, SSLConfig, Timeout, VerifyTypes
|
||||
from ..config import SSLConfig, Timeout
|
||||
from ..models import URL, Origin, Request, Response
|
||||
from ..utils import get_logger
|
||||
from .base import Dispatcher
|
||||
@ -23,17 +23,13 @@ class HTTPConnection(Dispatcher):
|
||||
def __init__(
|
||||
self,
|
||||
origin: typing.Union[str, Origin],
|
||||
verify: VerifyTypes = True,
|
||||
cert: CertTypes = None,
|
||||
trust_env: bool = None,
|
||||
http2: bool = False,
|
||||
ssl: SSLConfig = None,
|
||||
backend: typing.Union[str, ConcurrencyBackend] = "auto",
|
||||
release_func: typing.Optional[ReleaseCallback] = None,
|
||||
uds: typing.Optional[str] = None,
|
||||
):
|
||||
self.origin = Origin(origin) if isinstance(origin, str) else origin
|
||||
self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env)
|
||||
self.http2 = http2
|
||||
self.ssl = SSLConfig() if ssl is None else ssl
|
||||
self.backend = lookup_backend(backend)
|
||||
self.release_func = release_func
|
||||
self.uds = uds
|
||||
@ -53,7 +49,7 @@ class HTTPConnection(Dispatcher):
|
||||
) -> typing.Union[HTTP11Connection, HTTP2Connection]:
|
||||
host = self.origin.host
|
||||
port = self.origin.port
|
||||
ssl_context = await self.get_ssl_context(self.ssl)
|
||||
ssl_context = self.get_ssl_context()
|
||||
|
||||
if self.release_func is None:
|
||||
on_release = None
|
||||
@ -108,7 +104,7 @@ class HTTPConnection(Dispatcher):
|
||||
if origin.is_ssl:
|
||||
# Pull the socket stream off the internal HTTP connection object,
|
||||
# and run start_tls().
|
||||
ssl_context = await self.get_ssl_context(self.ssl)
|
||||
ssl_context = self.get_ssl_context()
|
||||
assert ssl_context is not None
|
||||
|
||||
logger.trace(f"tunnel_start_tls proxy_url={proxy_url!r} origin={origin!r}")
|
||||
@ -134,12 +130,10 @@ class HTTPConnection(Dispatcher):
|
||||
else:
|
||||
self.connection = HTTP11Connection(socket, on_release=on_release)
|
||||
|
||||
async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
|
||||
def get_ssl_context(self) -> typing.Optional[ssl.SSLContext]:
|
||||
if not self.origin.is_ssl:
|
||||
return None
|
||||
|
||||
# Run the SSL loading in a threadpool, since it may make disk accesses.
|
||||
return await self.backend.run_in_threadpool(ssl.load_ssl_context, self.http2)
|
||||
return self.ssl.load_ssl_context()
|
||||
|
||||
async def close(self) -> None:
|
||||
logger.trace("close_connection")
|
||||
|
||||
@ -1,7 +1,14 @@
|
||||
import typing
|
||||
|
||||
from ..backends.base import BaseSemaphore, ConcurrencyBackend, lookup_backend
|
||||
from ..config import DEFAULT_POOL_LIMITS, CertTypes, PoolLimits, Timeout, VerifyTypes
|
||||
from ..config import (
|
||||
DEFAULT_POOL_LIMITS,
|
||||
CertTypes,
|
||||
PoolLimits,
|
||||
SSLConfig,
|
||||
Timeout,
|
||||
VerifyTypes,
|
||||
)
|
||||
from ..exceptions import PoolTimeout
|
||||
from ..models import Origin, Request, Response
|
||||
from ..utils import get_logger
|
||||
@ -92,12 +99,9 @@ class ConnectionPool(Dispatcher):
|
||||
backend: typing.Union[str, ConcurrencyBackend] = "auto",
|
||||
uds: typing.Optional[str] = None,
|
||||
):
|
||||
self.verify = verify
|
||||
self.cert = cert
|
||||
self.ssl = SSLConfig(verify=verify, cert=cert, trust_env=trust_env, http2=http2)
|
||||
self.pool_limits = pool_limits
|
||||
self.http2 = http2
|
||||
self.is_closed = False
|
||||
self.trust_env = trust_env
|
||||
self.uds = uds
|
||||
|
||||
self.keepalive_connections = ConnectionStore()
|
||||
@ -166,12 +170,9 @@ class ConnectionPool(Dispatcher):
|
||||
await self.max_connections.acquire(timeout=pool_timeout)
|
||||
connection = HTTPConnection(
|
||||
origin,
|
||||
verify=self.verify,
|
||||
cert=self.cert,
|
||||
http2=self.http2,
|
||||
ssl=self.ssl,
|
||||
backend=self.backend,
|
||||
release_func=self.release_connection,
|
||||
trust_env=self.trust_env,
|
||||
uds=self.uds,
|
||||
)
|
||||
logger.trace(f"new_connection connection={connection!r}")
|
||||
|
||||
@ -4,7 +4,14 @@ import warnings
|
||||
from base64 import b64encode
|
||||
|
||||
from ..backends.base import ConcurrencyBackend
|
||||
from ..config import DEFAULT_POOL_LIMITS, CertTypes, PoolLimits, Timeout, VerifyTypes
|
||||
from ..config import (
|
||||
DEFAULT_POOL_LIMITS,
|
||||
CertTypes,
|
||||
PoolLimits,
|
||||
SSLConfig,
|
||||
Timeout,
|
||||
VerifyTypes,
|
||||
)
|
||||
from ..exceptions import ProxyError
|
||||
from ..models import URL, Headers, HeaderTypes, Origin, Request, Response, URLTypes
|
||||
from ..utils import get_logger
|
||||
@ -55,6 +62,10 @@ class HTTPProxy(ConnectionPool):
|
||||
proxy_mode = proxy_mode.value
|
||||
assert proxy_mode in ("DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY")
|
||||
|
||||
self.tunnel_ssl = SSLConfig(
|
||||
verify=verify, cert=cert, trust_env=trust_env, http2=False
|
||||
)
|
||||
|
||||
super(HTTPProxy, self).__init__(
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
@ -137,11 +148,8 @@ class HTTPProxy(ConnectionPool):
|
||||
|
||||
connection = HTTPConnection(
|
||||
self.proxy_url.origin,
|
||||
verify=self.verify,
|
||||
cert=self.cert,
|
||||
ssl=self.tunnel_ssl,
|
||||
backend=self.backend,
|
||||
http2=False, # Short-lived 'connection'
|
||||
trust_env=self.trust_env,
|
||||
release_func=self.release_connection,
|
||||
)
|
||||
self.active_connections.add(connection)
|
||||
|
||||
@ -29,27 +29,6 @@ def test_proxies_parameter(proxies, expected_proxies):
|
||||
assert len(expected_proxies) == len(client.proxies)
|
||||
|
||||
|
||||
def test_proxies_has_same_properties_as_dispatch():
|
||||
client = httpx.AsyncClient(
|
||||
proxies="http://127.0.0.1",
|
||||
verify="/path/to/verify",
|
||||
cert="/path/to/cert",
|
||||
trust_env=False,
|
||||
timeout=30,
|
||||
)
|
||||
pool = client.dispatch
|
||||
proxy = client.proxies["all"]
|
||||
|
||||
assert isinstance(proxy, httpx.HTTPProxy)
|
||||
|
||||
for prop in [
|
||||
"verify",
|
||||
"cert",
|
||||
"pool_limits",
|
||||
]:
|
||||
assert getattr(pool, prop) == getattr(proxy, prop)
|
||||
|
||||
|
||||
PROXY_URL = "http://[::1]"
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
from httpx.config import SSLConfig
|
||||
from httpx.dispatch.connection import HTTPConnection
|
||||
|
||||
|
||||
@ -35,7 +36,8 @@ async def test_https_get_with_ssl(https_server, ca_cert_pem_file):
|
||||
"""
|
||||
An HTTPS request, with SSL configuration set on the client.
|
||||
"""
|
||||
async with HTTPConnection(origin=https_server.url, verify=ca_cert_pem_file) as conn:
|
||||
ssl = SSLConfig(verify=ca_cert_pem_file)
|
||||
async with HTTPConnection(origin=https_server.url, ssl=ssl) as conn:
|
||||
response = await conn.request("GET", https_server.url)
|
||||
await response.aread()
|
||||
assert response.status_code == 200
|
||||
|
||||
Loading…
Reference in New Issue
Block a user