'Protocols' -> 'HTTPVersions'
This commit is contained in:
parent
cf098f8f4c
commit
f72fa68802
@ -6,8 +6,8 @@ from .config import (
|
||||
USER_AGENT,
|
||||
CertTypes,
|
||||
PoolLimits,
|
||||
ProtocolConfig,
|
||||
ProtocolTypes,
|
||||
HTTPVersionConfig,
|
||||
HTTPVersionTypes,
|
||||
SSLConfig,
|
||||
TimeoutConfig,
|
||||
TimeoutTypes,
|
||||
|
||||
@ -14,7 +14,7 @@ import ssl
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from .config import PoolLimits, ProtocolConfig, TimeoutConfig
|
||||
from .config import PoolLimits, HTTPVersionConfig, TimeoutConfig
|
||||
from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
|
||||
from .interfaces import (
|
||||
BaseBackgroundManager,
|
||||
@ -202,7 +202,6 @@ class AsyncioBackend(ConcurrencyBackend):
|
||||
port: int,
|
||||
ssl_context: typing.Optional[ssl.SSLContext],
|
||||
timeout: TimeoutConfig,
|
||||
protocols: ProtocolConfig
|
||||
) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
|
||||
try:
|
||||
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
|
||||
|
||||
@ -9,7 +9,7 @@ from .__version__ import __version__
|
||||
CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
|
||||
VerifyTypes = typing.Union[str, bool, ssl.SSLContext]
|
||||
TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
|
||||
ProtocolTypes = typing.Union[str, typing.List[str], typing.Tuple[str], "ProtocolConfig"]
|
||||
HTTPVersionTypes = typing.Union[str, typing.List[str], typing.Tuple[str], "HTTPVersionConfig"]
|
||||
|
||||
|
||||
USER_AGENT = f"python-httpx/{__version__}"
|
||||
@ -73,29 +73,29 @@ class SSLConfig:
|
||||
return self
|
||||
return SSLConfig(cert=cert, verify=verify)
|
||||
|
||||
def load_ssl_context(self, protocols: 'ProtocolConfig'=None) -> ssl.SSLContext:
|
||||
protocols = ProtocolConfig() if protocols is None else protocols
|
||||
def load_ssl_context(self, http_versions: 'HTTPVersionConfig'=None) -> ssl.SSLContext:
|
||||
http_versions = HTTPVersionConfig() if http_versions is None else http_versions
|
||||
|
||||
if self.ssl_context is None:
|
||||
self.ssl_context = (
|
||||
self.load_ssl_context_verify(protocols=protocols)
|
||||
self.load_ssl_context_verify(http_versions=http_versions)
|
||||
if self.verify
|
||||
else self.load_ssl_context_no_verify(protocols=protocols)
|
||||
else self.load_ssl_context_no_verify(http_versions=http_versions)
|
||||
)
|
||||
|
||||
assert self.ssl_context is not None
|
||||
return self.ssl_context
|
||||
|
||||
def load_ssl_context_no_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
|
||||
def load_ssl_context_no_verify(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext:
|
||||
"""
|
||||
Return an SSL context for unverified connections.
|
||||
"""
|
||||
context = self._create_default_ssl_context(protocols=protocols)
|
||||
context = self._create_default_ssl_context(http_versions=http_versions)
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
context.check_hostname = False
|
||||
return context
|
||||
|
||||
def load_ssl_context_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
|
||||
def load_ssl_context_verify(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext:
|
||||
"""
|
||||
Return an SSL context for verified connections.
|
||||
"""
|
||||
@ -109,7 +109,7 @@ class SSLConfig:
|
||||
"invalid path: {}".format(self.verify)
|
||||
)
|
||||
|
||||
context = self._create_default_ssl_context(protocols=protocols)
|
||||
context = self._create_default_ssl_context(http_versions=http_versions)
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
context.check_hostname = True
|
||||
|
||||
@ -136,7 +136,7 @@ class SSLConfig:
|
||||
|
||||
return context
|
||||
|
||||
def _create_default_ssl_context(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
|
||||
def _create_default_ssl_context(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext:
|
||||
"""
|
||||
Creates the default SSLContext object that's used for both verified
|
||||
and unverified connections.
|
||||
@ -150,9 +150,9 @@ class SSLConfig:
|
||||
context.set_ciphers(DEFAULT_CIPHERS)
|
||||
|
||||
if ssl.HAS_ALPN:
|
||||
context.set_alpn_protocols(protocols.protocol_ident_strings)
|
||||
context.set_alpn_protocols(http_versions.alpn_strings)
|
||||
if ssl.HAS_NPN: # pragma: no cover
|
||||
context.set_npn_protocols(protocols.protocol_ident_strings)
|
||||
context.set_npn_protocols(http_versions.alpn_strings)
|
||||
|
||||
return context
|
||||
|
||||
@ -226,38 +226,38 @@ class TimeoutConfig:
|
||||
)
|
||||
|
||||
|
||||
class ProtocolConfig:
|
||||
class HTTPVersionConfig:
|
||||
"""
|
||||
Configure which HTTP protocol versions are supported.
|
||||
"""
|
||||
|
||||
def __init__(self, protocols: ProtocolTypes = None):
|
||||
if protocols is None:
|
||||
protocols = ['HTTP/1.1', 'HTTP/2']
|
||||
def __init__(self, http_versions: HTTPVersionTypes = None):
|
||||
if http_versions is None:
|
||||
http_versions = ['HTTP/1.1', 'HTTP/2']
|
||||
|
||||
if isinstance(protocols, str):
|
||||
self.protocols = set([protocol])
|
||||
elif isinstance(protocols, ProtocolConfig):
|
||||
self.protocols = protocols.protocols
|
||||
if isinstance(http_versions, str):
|
||||
self.http_versions = set([http_versions])
|
||||
elif isinstance(http_versions, HTTPVersionConfig):
|
||||
self.http_versions = http_versions.http_versions
|
||||
else:
|
||||
self.protocols = set(sorted(protocols))
|
||||
self.http_versions = set(sorted(http_versions))
|
||||
|
||||
for protocol in self.protocols:
|
||||
if protocol not in ('HTTP/1.1', 'HTTP/2'):
|
||||
for version in self.http_versions:
|
||||
if version not in ('HTTP/1.1', 'HTTP/2'):
|
||||
raise ValueError(f"Unsupported protocol value {protocol!r}")
|
||||
|
||||
@property
|
||||
def protocol_ident_strings(self) -> typing.List[str]:
|
||||
def alpn_strings(self) -> typing.List[str]:
|
||||
"""
|
||||
Returns a list of supported ALPN identifiers. (One or more of "http/1.1", "h2").
|
||||
"""
|
||||
mapping = {"HTTP/1.1": "http/1.1", "HTTP/2": "h2"}
|
||||
return [mapping[protocol] for protocol in self.protocols]
|
||||
return [mapping[version] for version in self.http_versions]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
if len(self.protocols) == 1:
|
||||
value = self.protocols[0]
|
||||
return f"{class_name}(protocols={value!r})"
|
||||
value = list(self.protocols)
|
||||
return f"{class_name}(protocols={value!r})"
|
||||
value = list(self.http_versions)
|
||||
return f"{class_name}({value!r})"
|
||||
|
||||
|
||||
class PoolLimits:
|
||||
|
||||
@ -6,8 +6,8 @@ from ..concurrency import AsyncioBackend
|
||||
from ..config import (
|
||||
DEFAULT_TIMEOUT_CONFIG,
|
||||
CertTypes,
|
||||
ProtocolTypes,
|
||||
ProtocolConfig,
|
||||
HTTPVersionTypes,
|
||||
HTTPVersionConfig,
|
||||
SSLConfig,
|
||||
TimeoutConfig,
|
||||
TimeoutTypes,
|
||||
@ -31,12 +31,12 @@ class HTTPConnection(AsyncDispatcher):
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
backend: ConcurrencyBackend = None,
|
||||
release_func: typing.Optional[ReleaseCallback] = None,
|
||||
protocols: ProtocolTypes = None,
|
||||
http_versions: HTTPVersionTypes = None,
|
||||
):
|
||||
self.origin = Origin(origin) if isinstance(origin, str) else origin
|
||||
self.ssl = SSLConfig(cert=cert, verify=verify)
|
||||
self.timeout = TimeoutConfig(timeout)
|
||||
self.protocols = ProtocolConfig(protocols)
|
||||
self.http_versions = HTTPVersionConfig(http_versions)
|
||||
self.backend = AsyncioBackend() if backend is None else backend
|
||||
self.release_func = release_func
|
||||
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
|
||||
@ -48,10 +48,10 @@ class HTTPConnection(AsyncDispatcher):
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
protocols: ProtocolTypes = None,
|
||||
http_versions: HTTPVersionTypes = None,
|
||||
) -> AsyncResponse:
|
||||
if self.h11_connection is None and self.h2_connection is None:
|
||||
await self.connect(verify=verify, cert=cert, timeout=timeout, protocols=protocols)
|
||||
await self.connect(verify=verify, cert=cert, timeout=timeout, http_versions=http_versions)
|
||||
|
||||
if self.h2_connection is not None:
|
||||
response = await self.h2_connection.send(request, timeout=timeout)
|
||||
@ -66,15 +66,15 @@ class HTTPConnection(AsyncDispatcher):
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
protocols: ProtocolTypes = None,
|
||||
http_versions: HTTPVersionTypes = None,
|
||||
) -> None:
|
||||
ssl = self.ssl.with_overrides(verify=verify, cert=cert)
|
||||
timeout = self.timeout if timeout is None else TimeoutConfig(timeout)
|
||||
protocols = self.protocols if protocols is None else ProtocolConfig(protocols)
|
||||
http_versions = self.http_versions if http_versions is None else HTTPVersionConfig(http_versions)
|
||||
|
||||
host = self.origin.host
|
||||
port = self.origin.port
|
||||
ssl_context = await self.get_ssl_context(ssl, protocols)
|
||||
ssl_context = await self.get_ssl_context(ssl, http_versions)
|
||||
|
||||
if self.release_func is None:
|
||||
on_release = None
|
||||
@ -82,7 +82,7 @@ class HTTPConnection(AsyncDispatcher):
|
||||
on_release = functools.partial(self.release_func, self)
|
||||
|
||||
reader, writer, protocol = await self.backend.connect(
|
||||
host, port, ssl_context, timeout, protocols
|
||||
host, port, ssl_context, timeout
|
||||
)
|
||||
if protocol == Protocol.HTTP_2:
|
||||
self.h2_connection = HTTP2Connection(
|
||||
@ -93,12 +93,12 @@ class HTTPConnection(AsyncDispatcher):
|
||||
reader, writer, self.backend, on_release=on_release
|
||||
)
|
||||
|
||||
async def get_ssl_context(self, ssl: SSLConfig, protocols: ProtocolConfig) -> typing.Optional[ssl.SSLContext]:
|
||||
async def get_ssl_context(self, ssl: SSLConfig, http_versions: HTTPVersionConfig) -> typing.Optional[ssl.SSLContext]:
|
||||
if not self.origin.is_ssl:
|
||||
return None
|
||||
|
||||
# Run the SSL loading in a threadpool, since it may makes disk accesses.
|
||||
return await self.backend.run_in_threadpool(ssl.load_ssl_context, protocols)
|
||||
return await self.backend.run_in_threadpool(ssl.load_ssl_context, http_versions)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.h2_connection is not None:
|
||||
|
||||
@ -3,7 +3,7 @@ import ssl
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from .config import CertTypes, PoolLimits, ProtocolConfig, TimeoutConfig, TimeoutTypes, VerifyTypes
|
||||
from .config import CertTypes, PoolLimits, HTTPVersionConfig, TimeoutConfig, TimeoutTypes, VerifyTypes
|
||||
from .models import (
|
||||
AsyncRequest,
|
||||
AsyncRequestData,
|
||||
@ -171,8 +171,7 @@ class ConcurrencyBackend:
|
||||
hostname: str,
|
||||
port: int,
|
||||
ssl_context: typing.Optional[ssl.SSLContext],
|
||||
timeout: TimeoutConfig,
|
||||
protocols: ProtocolConfig
|
||||
timeout: TimeoutConfig
|
||||
) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ from httpx import (
|
||||
Protocol,
|
||||
Request,
|
||||
TimeoutConfig,
|
||||
ProtocolConfig
|
||||
HTTPVersionConfig
|
||||
)
|
||||
|
||||
|
||||
@ -27,8 +27,7 @@ class MockHTTP2Backend(AsyncioBackend):
|
||||
hostname: str,
|
||||
port: int,
|
||||
ssl_context: typing.Optional[ssl.SSLContext],
|
||||
timeout: TimeoutConfig,
|
||||
protocols: ProtocolConfig
|
||||
timeout: TimeoutConfig
|
||||
) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
|
||||
self.server = MockHTTP2Server(self.app)
|
||||
return self.server, self.server, Protocol.HTTP_2
|
||||
|
||||
Loading…
Reference in New Issue
Block a user