Concurrency autodetection (#585)

* Simplify HTTP version config, and switch HTTP/2 off by default

* HTTP/2 docs

* HTTP/2 interlinking in docs

* Add concurrency auto-detection

* Add sniffio
This commit is contained in:
Tom Christie 2019-12-02 19:26:16 +00:00 committed by GitHub
parent 30229f1652
commit 3cbe7315e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 113 additions and 38 deletions

View File

@ -113,6 +113,7 @@ The httpx project relies on these excellent libraries:
* `hstspreload` - determines whether IDNA-encoded host should be only accessed via HTTPS.
* `idna` - Internationalized domain name support.
* `rfc3986` - URL parsing & normalization.
* `sniffio` - Async library autodetection.
* `brotlipy` - Decoding for "brotli" compressed responses. *(Optional)*
A huge amount of credit is due to `requests` for the API layout that

View File

@ -5,7 +5,6 @@ from types import TracebackType
import hstspreload
from .auth import BasicAuth
from .concurrency.asyncio import AsyncioBackend
from .concurrency.base import ConcurrencyBackend
from .config import (
DEFAULT_MAX_REDIRECTS,
@ -95,7 +94,8 @@ class Client:
* **app** - *(optional)* An ASGI application to send requests to,
rather than sending actual network requests.
* **backend** - *(optional)* A concurrency backend to use when issuing
async requests.
async requests. Either 'auto', 'asyncio', 'trio', or a `ConcurrencyBackend`
instance. Defaults to 'auto', for autodetection.
* **trust_env** - *(optional)* Enables or disables usage of environment
variables for configuration.
* **uds** - *(optional)* A path to a Unix domain socket to connect through.
@ -118,15 +118,12 @@ class Client:
base_url: URLTypes = None,
dispatch: Dispatcher = None,
app: typing.Callable = None,
backend: ConcurrencyBackend = None,
backend: typing.Union[str, ConcurrencyBackend] = "auto",
trust_env: bool = True,
uds: str = None,
):
if backend is None:
backend = AsyncioBackend()
if app is not None:
dispatch = ASGIDispatch(app=app, backend=backend)
dispatch = ASGIDispatch(app=app)
if dispatch is None:
dispatch = ConnectionPool(
@ -155,7 +152,6 @@ class Client:
self.max_redirects = max_redirects
self.trust_env = trust_env
self.dispatch = dispatch
self.concurrency_backend = backend
self.netrc = NetRCInfo()
if proxies is None and trust_env:
@ -834,7 +830,7 @@ def _proxies_to_dispatchers(
timeout: TimeoutTypes,
http_2: bool,
pool_limits: PoolLimits,
backend: ConcurrencyBackend,
backend: typing.Union[str, ConcurrencyBackend],
trust_env: bool,
) -> typing.Dict[str, Dispatcher]:
def _proxy_from_url(url: URLTypes) -> Dispatcher:

59
httpx/concurrency/auto.py Normal file
View File

@ -0,0 +1,59 @@
import ssl
import typing
import sniffio
from ..config import PoolLimits, TimeoutConfig
from .base import (
BaseBackgroundManager,
BaseEvent,
BasePoolSemaphore,
BaseSocketStream,
ConcurrencyBackend,
lookup_backend,
)
class AutoBackend(ConcurrencyBackend):
@property
def backend(self) -> ConcurrencyBackend:
if not hasattr(self, "_backend_implementation"):
backend = sniffio.current_async_library()
if backend not in ("asyncio", "trio"):
raise RuntimeError(f"Unsupported concurrency backend {backend!r}")
self._backend_implementation = lookup_backend(backend)
return self._backend_implementation
async def open_tcp_stream(
self,
hostname: str,
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> BaseSocketStream:
return await self.backend.open_tcp_stream(hostname, port, ssl_context, timeout)
async def open_uds_stream(
self,
path: str,
hostname: typing.Optional[str],
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> BaseSocketStream:
return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout)
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
return self.backend.get_semaphore(limits)
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
return await self.backend.run_in_threadpool(func, *args, **kwargs)
def create_event(self) -> BaseEvent:
return self.backend.create_event()
def background_manager(
self, coroutine: typing.Callable, *args: typing.Any
) -> BaseBackgroundManager:
return self.backend.background_manager(coroutine, *args)

View File

@ -5,6 +5,28 @@ from types import TracebackType
from ..config import PoolLimits, TimeoutConfig
def lookup_backend(
backend: typing.Union[str, "ConcurrencyBackend"] = "auto"
) -> "ConcurrencyBackend":
if not isinstance(backend, str):
return backend
if backend == "auto":
from .auto import AutoBackend
return AutoBackend()
elif backend == "asyncio":
from .asyncio import AsyncioBackend
return AsyncioBackend()
elif backend == "trio":
from .trio import TrioBackend
return TrioBackend()
raise RuntimeError(f"Unknown or unsupported concurrency backend {backend!r}")
class TimeoutFlag:
"""
A timeout flag holds a state of either read-timeout or write-timeout mode.

View File

@ -1,7 +1,5 @@
import typing
from ..concurrency.asyncio import AsyncioBackend
from ..concurrency.base import ConcurrencyBackend
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import Request, Response
from .base import Dispatcher
@ -49,13 +47,11 @@ class ASGIDispatch(Dispatcher):
raise_app_exceptions: bool = True,
root_path: str = "",
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
backend: ConcurrencyBackend = None,
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
self.client = client
self.backend = AsyncioBackend() if backend is None else backend
async def send(
self,

View File

@ -2,8 +2,7 @@ import functools
import ssl
import typing
from ..concurrency.asyncio import AsyncioBackend
from ..concurrency.base import ConcurrencyBackend
from ..concurrency.base import ConcurrencyBackend, lookup_backend
from ..config import (
DEFAULT_TIMEOUT_CONFIG,
CertTypes,
@ -34,7 +33,7 @@ class HTTPConnection(Dispatcher):
trust_env: bool = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
http_2: bool = False,
backend: ConcurrencyBackend = None,
backend: typing.Union[str, ConcurrencyBackend] = "auto",
release_func: typing.Optional[ReleaseCallback] = None,
uds: typing.Optional[str] = None,
):
@ -42,7 +41,7 @@ class HTTPConnection(Dispatcher):
self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env)
self.timeout = TimeoutConfig(timeout)
self.http_2 = http_2
self.backend = AsyncioBackend() if backend is None else backend
self.backend = lookup_backend(backend)
self.release_func = release_func
self.uds = uds
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
@ -104,13 +103,11 @@ class HTTPConnection(Dispatcher):
if http_version == "HTTP/2":
self.h2_connection = HTTP2Connection(
stream, self.backend, on_release=on_release
stream, backend=self.backend, on_release=on_release
)
else:
assert http_version == "HTTP/1.1"
self.h11_connection = HTTP11Connection(
stream, self.backend, on_release=on_release
)
self.h11_connection = HTTP11Connection(stream, on_release=on_release)
async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
if not self.origin.is_ssl:

View File

@ -1,7 +1,6 @@
import typing
from ..concurrency.asyncio import AsyncioBackend
from ..concurrency.base import ConcurrencyBackend
from ..concurrency.base import BasePoolSemaphore, ConcurrencyBackend, lookup_backend
from ..config import (
DEFAULT_POOL_LIMITS,
DEFAULT_TIMEOUT_CONFIG,
@ -88,7 +87,7 @@ class ConnectionPool(Dispatcher):
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
http_2: bool = False,
backend: ConcurrencyBackend = None,
backend: typing.Union[str, ConcurrencyBackend] = "auto",
uds: typing.Optional[str] = None,
):
self.verify = verify
@ -103,8 +102,15 @@ class ConnectionPool(Dispatcher):
self.keepalive_connections = ConnectionStore()
self.active_connections = ConnectionStore()
self.backend = AsyncioBackend() if backend is None else backend
self.max_connections = self.backend.get_semaphore(pool_limits)
self.backend = lookup_backend(backend)
@property
def max_connections(self) -> BasePoolSemaphore:
# We do this lazily, to make sure backend autodetection always
# runs within an async context.
if not hasattr(self, "_max_connections"):
self._max_connections = self.backend.get_semaphore(self.pool_limits)
return self._max_connections
@property
def num_connections(self) -> int:

View File

@ -2,7 +2,7 @@ import typing
import h11
from ..concurrency.base import BaseSocketStream, ConcurrencyBackend, TimeoutFlag
from ..concurrency.base import BaseSocketStream, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..exceptions import ConnectionClosed, ProtocolError
from ..models import Request, Response
@ -33,11 +33,9 @@ class HTTP11Connection:
def __init__(
self,
stream: BaseSocketStream,
backend: ConcurrencyBackend,
on_release: typing.Optional[OnReleaseCallback] = None,
):
self.stream = stream
self.backend = backend
self.on_release = on_release
self.h11_state = h11.Connection(our_role=h11.CLIENT)
self.timeout_flag = TimeoutFlag()

View File

@ -10,6 +10,7 @@ from ..concurrency.base import (
BaseSocketStream,
ConcurrencyBackend,
TimeoutFlag,
lookup_backend,
)
from ..config import TimeoutConfig, TimeoutTypes
from ..exceptions import ProtocolError
@ -25,11 +26,11 @@ class HTTP2Connection:
def __init__(
self,
stream: BaseSocketStream,
backend: ConcurrencyBackend,
backend: typing.Union[str, ConcurrencyBackend] = "auto",
on_release: typing.Callable = None,
):
self.stream = stream
self.backend = backend
self.backend = lookup_backend(backend)
self.on_release = on_release
self.h2_state = h2.connection.H2Connection()
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]

View File

@ -1,4 +1,5 @@
import enum
import typing
from base64 import b64encode
import h11
@ -47,7 +48,7 @@ class HTTPProxy(ConnectionPool):
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
http_2: bool = False,
backend: ConcurrencyBackend = None,
backend: typing.Union[str, ConcurrencyBackend] = "auto",
):
super(HTTPProxy, self).__init__(
@ -207,9 +208,7 @@ class HTTPProxy(ConnectionPool):
)
else:
assert http_version == "HTTP/1.1"
connection.h11_connection = HTTP11Connection(
stream, self.backend, on_release=on_release
)
connection.h11_connection = HTTP11Connection(stream, on_release=on_release)
def should_forward_origin(self, origin: Origin) -> bool:
"""Determines if the given origin should

View File

@ -14,7 +14,7 @@ combine_as_imports = True
force_grid_wrap = 0
include_trailing_comma = True
known_first_party = httpx,httpxprof,tests
known_third_party = brotli,certifi,chardet,click,cryptography,h11,h2,hstspreload,pytest,rfc3986,setuptools,tqdm,trio,trustme,uvicorn
known_third_party = brotli,certifi,chardet,click,cryptography,h11,h2,hstspreload,pytest,rfc3986,setuptools,sniffio,tqdm,trio,trustme,uvicorn
line_length = 88
multi_line_output = 3

View File

@ -52,12 +52,13 @@ setup(
zip_safe=False,
install_requires=[
"certifi",
"hstspreload",
"chardet==3.*",
"h11==0.8.*",
"h2==3.*",
"hstspreload>=2019.8.27",
"idna==2.*",
"rfc3986==1.*",
"sniffio==1.*",
],
classifiers=[
"Development Status :: 3 - Alpha",

View File

@ -48,7 +48,6 @@ def test_proxies_has_same_properties_as_dispatch():
"cert",
"timeout",
"pool_limits",
"backend",
]:
assert getattr(pool, prop) == getattr(proxy, prop)