Merge branch 'protocol-support' of https://github.com/encode/httpx into protocol-support

This commit is contained in:
Tom Christie 2019-08-20 12:03:55 +01:00
commit cf098f8f4c
6 changed files with 73 additions and 20 deletions

View File

@ -6,6 +6,8 @@ from .config import (
USER_AGENT,
CertTypes,
PoolLimits,
ProtocolConfig,
ProtocolTypes,
SSLConfig,
TimeoutConfig,
TimeoutTypes,

View File

@ -14,7 +14,7 @@ import ssl
import typing
from types import TracebackType
from .config import PoolLimits, TimeoutConfig
from .config import PoolLimits, ProtocolConfig, TimeoutConfig
from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .interfaces import (
BaseBackgroundManager,
@ -202,6 +202,7 @@ class AsyncioBackend(ConcurrencyBackend):
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
protocols: ProtocolConfig
) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore

View File

@ -9,6 +9,7 @@ from .__version__ import __version__
CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
VerifyTypes = typing.Union[str, bool, ssl.SSLContext]
TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
ProtocolTypes = typing.Union[str, typing.List[str], typing.Tuple[str], "ProtocolConfig"]
USER_AGENT = f"python-httpx/{__version__}"
@ -72,27 +73,29 @@ class SSLConfig:
return self
return SSLConfig(cert=cert, verify=verify)
def load_ssl_context(self) -> ssl.SSLContext:
def load_ssl_context(self, protocols: 'ProtocolConfig'=None) -> ssl.SSLContext:
protocols = ProtocolConfig() if protocols is None else protocols
if self.ssl_context is None:
self.ssl_context = (
self.load_ssl_context_verify()
self.load_ssl_context_verify(protocols=protocols)
if self.verify
else self.load_ssl_context_no_verify()
else self.load_ssl_context_no_verify(protocols=protocols)
)
assert self.ssl_context is not None
return self.ssl_context
def load_ssl_context_no_verify(self) -> ssl.SSLContext:
def load_ssl_context_no_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
"""
Return an SSL context for unverified connections.
"""
context = self._create_default_ssl_context()
context = self._create_default_ssl_context(protocols=protocols)
context.verify_mode = ssl.CERT_NONE
context.check_hostname = False
return context
def load_ssl_context_verify(self) -> ssl.SSLContext:
def load_ssl_context_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
"""
Return an SSL context for verified connections.
"""
@ -106,7 +109,7 @@ class SSLConfig:
"invalid path: {}".format(self.verify)
)
context = self._create_default_ssl_context()
context = self._create_default_ssl_context(protocols=protocols)
context.verify_mode = ssl.CERT_REQUIRED
context.check_hostname = True
@ -133,7 +136,7 @@ class SSLConfig:
return context
def _create_default_ssl_context(self) -> ssl.SSLContext:
def _create_default_ssl_context(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
"""
Creates the default SSLContext object that's used for both verified
and unverified connections.
@ -147,9 +150,9 @@ class SSLConfig:
context.set_ciphers(DEFAULT_CIPHERS)
if ssl.HAS_ALPN:
context.set_alpn_protocols(["h2", "http/1.1"])
context.set_alpn_protocols(protocols.protocol_ident_strings)
if ssl.HAS_NPN: # pragma: no cover
context.set_npn_protocols(["h2", "http/1.1"])
context.set_npn_protocols(protocols.protocol_ident_strings)
return context
@ -223,6 +226,40 @@ class TimeoutConfig:
)
class ProtocolConfig:
"""
Configure which HTTP protocol versions are supported.
"""
def __init__(self, protocols: ProtocolTypes = None):
if protocols is None:
protocols = ['HTTP/1.1', 'HTTP/2']
if isinstance(protocols, str):
self.protocols = set([protocol])
elif isinstance(protocols, ProtocolConfig):
self.protocols = protocols.protocols
else:
self.protocols = set(sorted(protocols))
for protocol in self.protocols:
if protocol not in ('HTTP/1.1', 'HTTP/2'):
raise ValueError(f"Unsupported protocol value {protocol!r}")
@property
def protocol_ident_strings(self) -> typing.List[str]:
mapping = {"HTTP/1.1": "http/1.1", "HTTP/2": "h2"}
return [mapping[protocol] for protocol in self.protocols]
def __repr__(self) -> str:
class_name = self.__class__.__name__
if len(self.protocols) == 1:
value = self.protocols[0]
return f"{class_name}(protocols={value!r})"
value = list(self.protocols)
return f"{class_name}(protocols={value!r})"
class PoolLimits:
"""
Limits on the number of connections in a connection pool.

View File

@ -1,10 +1,13 @@
import functools
import typing
import ssl
from ..concurrency import AsyncioBackend
from ..config import (
DEFAULT_TIMEOUT_CONFIG,
CertTypes,
ProtocolTypes,
ProtocolConfig,
SSLConfig,
TimeoutConfig,
TimeoutTypes,
@ -28,10 +31,12 @@ class HTTPConnection(AsyncDispatcher):
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
backend: ConcurrencyBackend = None,
release_func: typing.Optional[ReleaseCallback] = None,
protocols: ProtocolTypes = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = SSLConfig(cert=cert, verify=verify)
self.timeout = TimeoutConfig(timeout)
self.protocols = ProtocolConfig(protocols)
self.backend = AsyncioBackend() if backend is None else backend
self.release_func = release_func
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
@ -43,9 +48,10 @@ class HTTPConnection(AsyncDispatcher):
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
protocols: ProtocolTypes = None,
) -> AsyncResponse:
if self.h11_connection is None and self.h2_connection is None:
await self.connect(verify=verify, cert=cert, timeout=timeout)
await self.connect(verify=verify, cert=cert, timeout=timeout, protocols=protocols)
if self.h2_connection is not None:
response = await self.h2_connection.send(request, timeout=timeout)
@ -60,18 +66,15 @@ class HTTPConnection(AsyncDispatcher):
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
protocols: ProtocolTypes = None,
) -> None:
ssl = self.ssl.with_overrides(verify=verify, cert=cert)
timeout = self.timeout if timeout is None else TimeoutConfig(timeout)
protocols = self.protocols if protocols is None else ProtocolConfig(protocols)
host = self.origin.host
port = self.origin.port
# Run the SSL loading in a threadpool, since it makes disk accesses.
ssl_context = (
await self.backend.run_in_threadpool(ssl.load_ssl_context)
if self.origin.is_ssl
else None
)
ssl_context = await self.get_ssl_context(ssl, protocols)
if self.release_func is None:
on_release = None
@ -79,7 +82,7 @@ class HTTPConnection(AsyncDispatcher):
on_release = functools.partial(self.release_func, self)
reader, writer, protocol = await self.backend.connect(
host, port, ssl_context, timeout
host, port, ssl_context, timeout, protocols
)
if protocol == Protocol.HTTP_2:
self.h2_connection = HTTP2Connection(
@ -90,6 +93,13 @@ class HTTPConnection(AsyncDispatcher):
reader, writer, self.backend, on_release=on_release
)
async def get_ssl_context(self, ssl: SSLConfig, protocols: ProtocolConfig) -> typing.Optional[ssl.SSLContext]:
if not self.origin.is_ssl:
return None
# Run the SSL loading in a threadpool, since it may makes disk accesses.
return await self.backend.run_in_threadpool(ssl.load_ssl_context, protocols)
async def close(self) -> None:
if self.h2_connection is not None:
await self.h2_connection.close()

View File

@ -3,7 +3,7 @@ import ssl
import typing
from types import TracebackType
from .config import CertTypes, PoolLimits, TimeoutConfig, TimeoutTypes, VerifyTypes
from .config import CertTypes, PoolLimits, ProtocolConfig, TimeoutConfig, TimeoutTypes, VerifyTypes
from .models import (
AsyncRequest,
AsyncRequestData,
@ -172,6 +172,7 @@ class ConcurrencyBackend:
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
protocols: ProtocolConfig
) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
raise NotImplementedError() # pragma: no cover

View File

@ -13,6 +13,7 @@ from httpx import (
Protocol,
Request,
TimeoutConfig,
ProtocolConfig
)
@ -27,6 +28,7 @@ class MockHTTP2Backend(AsyncioBackend):
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
protocols: ProtocolConfig
) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
self.server = MockHTTP2Server(self.app)
return self.server, self.server, Protocol.HTTP_2