Concurrency autodetection (#585)
* Simplify HTTP version config, and switch HTTP/2 off by default * HTTP/2 docs * HTTP/2 interlinking in docs * Add concurrency auto-detection * Add sniffio
This commit is contained in:
parent
30229f1652
commit
3cbe7315e8
@ -113,6 +113,7 @@ The httpx project relies on these excellent libraries:
|
||||
* `hstspreload` - determines whether IDNA-encoded host should be only accessed via HTTPS.
|
||||
* `idna` - Internationalized domain name support.
|
||||
* `rfc3986` - URL parsing & normalization.
|
||||
* `sniffio` - Async library autodetection.
|
||||
* `brotlipy` - Decoding for "brotli" compressed responses. *(Optional)*
|
||||
|
||||
A huge amount of credit is due to `requests` for the API layout that
|
||||
|
||||
@ -5,7 +5,6 @@ from types import TracebackType
|
||||
import hstspreload
|
||||
|
||||
from .auth import BasicAuth
|
||||
from .concurrency.asyncio import AsyncioBackend
|
||||
from .concurrency.base import ConcurrencyBackend
|
||||
from .config import (
|
||||
DEFAULT_MAX_REDIRECTS,
|
||||
@ -95,7 +94,8 @@ class Client:
|
||||
* **app** - *(optional)* An ASGI application to send requests to,
|
||||
rather than sending actual network requests.
|
||||
* **backend** - *(optional)* A concurrency backend to use when issuing
|
||||
async requests.
|
||||
async requests. Either 'auto', 'asyncio', 'trio', or a `ConcurrencyBackend`
|
||||
instance. Defaults to 'auto', for autodetection.
|
||||
* **trust_env** - *(optional)* Enables or disables usage of environment
|
||||
variables for configuration.
|
||||
* **uds** - *(optional)* A path to a Unix domain socket to connect through.
|
||||
@ -118,15 +118,12 @@ class Client:
|
||||
base_url: URLTypes = None,
|
||||
dispatch: Dispatcher = None,
|
||||
app: typing.Callable = None,
|
||||
backend: ConcurrencyBackend = None,
|
||||
backend: typing.Union[str, ConcurrencyBackend] = "auto",
|
||||
trust_env: bool = True,
|
||||
uds: str = None,
|
||||
):
|
||||
if backend is None:
|
||||
backend = AsyncioBackend()
|
||||
|
||||
if app is not None:
|
||||
dispatch = ASGIDispatch(app=app, backend=backend)
|
||||
dispatch = ASGIDispatch(app=app)
|
||||
|
||||
if dispatch is None:
|
||||
dispatch = ConnectionPool(
|
||||
@ -155,7 +152,6 @@ class Client:
|
||||
self.max_redirects = max_redirects
|
||||
self.trust_env = trust_env
|
||||
self.dispatch = dispatch
|
||||
self.concurrency_backend = backend
|
||||
self.netrc = NetRCInfo()
|
||||
|
||||
if proxies is None and trust_env:
|
||||
@ -834,7 +830,7 @@ def _proxies_to_dispatchers(
|
||||
timeout: TimeoutTypes,
|
||||
http_2: bool,
|
||||
pool_limits: PoolLimits,
|
||||
backend: ConcurrencyBackend,
|
||||
backend: typing.Union[str, ConcurrencyBackend],
|
||||
trust_env: bool,
|
||||
) -> typing.Dict[str, Dispatcher]:
|
||||
def _proxy_from_url(url: URLTypes) -> Dispatcher:
|
||||
|
||||
59
httpx/concurrency/auto.py
Normal file
59
httpx/concurrency/auto.py
Normal file
@ -0,0 +1,59 @@
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
import sniffio
|
||||
|
||||
from ..config import PoolLimits, TimeoutConfig
|
||||
from .base import (
|
||||
BaseBackgroundManager,
|
||||
BaseEvent,
|
||||
BasePoolSemaphore,
|
||||
BaseSocketStream,
|
||||
ConcurrencyBackend,
|
||||
lookup_backend,
|
||||
)
|
||||
|
||||
|
||||
class AutoBackend(ConcurrencyBackend):
|
||||
@property
|
||||
def backend(self) -> ConcurrencyBackend:
|
||||
if not hasattr(self, "_backend_implementation"):
|
||||
backend = sniffio.current_async_library()
|
||||
if backend not in ("asyncio", "trio"):
|
||||
raise RuntimeError(f"Unsupported concurrency backend {backend!r}")
|
||||
self._backend_implementation = lookup_backend(backend)
|
||||
return self._backend_implementation
|
||||
|
||||
async def open_tcp_stream(
|
||||
self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
ssl_context: typing.Optional[ssl.SSLContext],
|
||||
timeout: TimeoutConfig,
|
||||
) -> BaseSocketStream:
|
||||
return await self.backend.open_tcp_stream(hostname, port, ssl_context, timeout)
|
||||
|
||||
async def open_uds_stream(
|
||||
self,
|
||||
path: str,
|
||||
hostname: typing.Optional[str],
|
||||
ssl_context: typing.Optional[ssl.SSLContext],
|
||||
timeout: TimeoutConfig,
|
||||
) -> BaseSocketStream:
|
||||
return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout)
|
||||
|
||||
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
|
||||
return self.backend.get_semaphore(limits)
|
||||
|
||||
async def run_in_threadpool(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
return await self.backend.run_in_threadpool(func, *args, **kwargs)
|
||||
|
||||
def create_event(self) -> BaseEvent:
|
||||
return self.backend.create_event()
|
||||
|
||||
def background_manager(
|
||||
self, coroutine: typing.Callable, *args: typing.Any
|
||||
) -> BaseBackgroundManager:
|
||||
return self.backend.background_manager(coroutine, *args)
|
||||
@ -5,6 +5,28 @@ from types import TracebackType
|
||||
from ..config import PoolLimits, TimeoutConfig
|
||||
|
||||
|
||||
def lookup_backend(
|
||||
backend: typing.Union[str, "ConcurrencyBackend"] = "auto"
|
||||
) -> "ConcurrencyBackend":
|
||||
if not isinstance(backend, str):
|
||||
return backend
|
||||
|
||||
if backend == "auto":
|
||||
from .auto import AutoBackend
|
||||
|
||||
return AutoBackend()
|
||||
elif backend == "asyncio":
|
||||
from .asyncio import AsyncioBackend
|
||||
|
||||
return AsyncioBackend()
|
||||
elif backend == "trio":
|
||||
from .trio import TrioBackend
|
||||
|
||||
return TrioBackend()
|
||||
|
||||
raise RuntimeError(f"Unknown or unsupported concurrency backend {backend!r}")
|
||||
|
||||
|
||||
class TimeoutFlag:
|
||||
"""
|
||||
A timeout flag holds a state of either read-timeout or write-timeout mode.
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import typing
|
||||
|
||||
from ..concurrency.asyncio import AsyncioBackend
|
||||
from ..concurrency.base import ConcurrencyBackend
|
||||
from ..config import CertTypes, TimeoutTypes, VerifyTypes
|
||||
from ..models import Request, Response
|
||||
from .base import Dispatcher
|
||||
@ -49,13 +47,11 @@ class ASGIDispatch(Dispatcher):
|
||||
raise_app_exceptions: bool = True,
|
||||
root_path: str = "",
|
||||
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
|
||||
backend: ConcurrencyBackend = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.raise_app_exceptions = raise_app_exceptions
|
||||
self.root_path = root_path
|
||||
self.client = client
|
||||
self.backend = AsyncioBackend() if backend is None else backend
|
||||
|
||||
async def send(
|
||||
self,
|
||||
|
||||
@ -2,8 +2,7 @@ import functools
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
from ..concurrency.asyncio import AsyncioBackend
|
||||
from ..concurrency.base import ConcurrencyBackend
|
||||
from ..concurrency.base import ConcurrencyBackend, lookup_backend
|
||||
from ..config import (
|
||||
DEFAULT_TIMEOUT_CONFIG,
|
||||
CertTypes,
|
||||
@ -34,7 +33,7 @@ class HTTPConnection(Dispatcher):
|
||||
trust_env: bool = None,
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
http_2: bool = False,
|
||||
backend: ConcurrencyBackend = None,
|
||||
backend: typing.Union[str, ConcurrencyBackend] = "auto",
|
||||
release_func: typing.Optional[ReleaseCallback] = None,
|
||||
uds: typing.Optional[str] = None,
|
||||
):
|
||||
@ -42,7 +41,7 @@ class HTTPConnection(Dispatcher):
|
||||
self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env)
|
||||
self.timeout = TimeoutConfig(timeout)
|
||||
self.http_2 = http_2
|
||||
self.backend = AsyncioBackend() if backend is None else backend
|
||||
self.backend = lookup_backend(backend)
|
||||
self.release_func = release_func
|
||||
self.uds = uds
|
||||
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
|
||||
@ -104,13 +103,11 @@ class HTTPConnection(Dispatcher):
|
||||
|
||||
if http_version == "HTTP/2":
|
||||
self.h2_connection = HTTP2Connection(
|
||||
stream, self.backend, on_release=on_release
|
||||
stream, backend=self.backend, on_release=on_release
|
||||
)
|
||||
else:
|
||||
assert http_version == "HTTP/1.1"
|
||||
self.h11_connection = HTTP11Connection(
|
||||
stream, self.backend, on_release=on_release
|
||||
)
|
||||
self.h11_connection = HTTP11Connection(stream, on_release=on_release)
|
||||
|
||||
async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
|
||||
if not self.origin.is_ssl:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import typing
|
||||
|
||||
from ..concurrency.asyncio import AsyncioBackend
|
||||
from ..concurrency.base import ConcurrencyBackend
|
||||
from ..concurrency.base import BasePoolSemaphore, ConcurrencyBackend, lookup_backend
|
||||
from ..config import (
|
||||
DEFAULT_POOL_LIMITS,
|
||||
DEFAULT_TIMEOUT_CONFIG,
|
||||
@ -88,7 +87,7 @@ class ConnectionPool(Dispatcher):
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
http_2: bool = False,
|
||||
backend: ConcurrencyBackend = None,
|
||||
backend: typing.Union[str, ConcurrencyBackend] = "auto",
|
||||
uds: typing.Optional[str] = None,
|
||||
):
|
||||
self.verify = verify
|
||||
@ -103,8 +102,15 @@ class ConnectionPool(Dispatcher):
|
||||
self.keepalive_connections = ConnectionStore()
|
||||
self.active_connections = ConnectionStore()
|
||||
|
||||
self.backend = AsyncioBackend() if backend is None else backend
|
||||
self.max_connections = self.backend.get_semaphore(pool_limits)
|
||||
self.backend = lookup_backend(backend)
|
||||
|
||||
@property
|
||||
def max_connections(self) -> BasePoolSemaphore:
|
||||
# We do this lazily, to make sure backend autodetection always
|
||||
# runs within an async context.
|
||||
if not hasattr(self, "_max_connections"):
|
||||
self._max_connections = self.backend.get_semaphore(self.pool_limits)
|
||||
return self._max_connections
|
||||
|
||||
@property
|
||||
def num_connections(self) -> int:
|
||||
|
||||
@ -2,7 +2,7 @@ import typing
|
||||
|
||||
import h11
|
||||
|
||||
from ..concurrency.base import BaseSocketStream, ConcurrencyBackend, TimeoutFlag
|
||||
from ..concurrency.base import BaseSocketStream, TimeoutFlag
|
||||
from ..config import TimeoutConfig, TimeoutTypes
|
||||
from ..exceptions import ConnectionClosed, ProtocolError
|
||||
from ..models import Request, Response
|
||||
@ -33,11 +33,9 @@ class HTTP11Connection:
|
||||
def __init__(
|
||||
self,
|
||||
stream: BaseSocketStream,
|
||||
backend: ConcurrencyBackend,
|
||||
on_release: typing.Optional[OnReleaseCallback] = None,
|
||||
):
|
||||
self.stream = stream
|
||||
self.backend = backend
|
||||
self.on_release = on_release
|
||||
self.h11_state = h11.Connection(our_role=h11.CLIENT)
|
||||
self.timeout_flag = TimeoutFlag()
|
||||
|
||||
@ -10,6 +10,7 @@ from ..concurrency.base import (
|
||||
BaseSocketStream,
|
||||
ConcurrencyBackend,
|
||||
TimeoutFlag,
|
||||
lookup_backend,
|
||||
)
|
||||
from ..config import TimeoutConfig, TimeoutTypes
|
||||
from ..exceptions import ProtocolError
|
||||
@ -25,11 +26,11 @@ class HTTP2Connection:
|
||||
def __init__(
|
||||
self,
|
||||
stream: BaseSocketStream,
|
||||
backend: ConcurrencyBackend,
|
||||
backend: typing.Union[str, ConcurrencyBackend] = "auto",
|
||||
on_release: typing.Callable = None,
|
||||
):
|
||||
self.stream = stream
|
||||
self.backend = backend
|
||||
self.backend = lookup_backend(backend)
|
||||
self.on_release = on_release
|
||||
self.h2_state = h2.connection.H2Connection()
|
||||
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import enum
|
||||
import typing
|
||||
from base64 import b64encode
|
||||
|
||||
import h11
|
||||
@ -47,7 +48,7 @@ class HTTPProxy(ConnectionPool):
|
||||
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
http_2: bool = False,
|
||||
backend: ConcurrencyBackend = None,
|
||||
backend: typing.Union[str, ConcurrencyBackend] = "auto",
|
||||
):
|
||||
|
||||
super(HTTPProxy, self).__init__(
|
||||
@ -207,9 +208,7 @@ class HTTPProxy(ConnectionPool):
|
||||
)
|
||||
else:
|
||||
assert http_version == "HTTP/1.1"
|
||||
connection.h11_connection = HTTP11Connection(
|
||||
stream, self.backend, on_release=on_release
|
||||
)
|
||||
connection.h11_connection = HTTP11Connection(stream, on_release=on_release)
|
||||
|
||||
def should_forward_origin(self, origin: Origin) -> bool:
|
||||
"""Determines if the given origin should
|
||||
|
||||
@ -14,7 +14,7 @@ combine_as_imports = True
|
||||
force_grid_wrap = 0
|
||||
include_trailing_comma = True
|
||||
known_first_party = httpx,httpxprof,tests
|
||||
known_third_party = brotli,certifi,chardet,click,cryptography,h11,h2,hstspreload,pytest,rfc3986,setuptools,tqdm,trio,trustme,uvicorn
|
||||
known_third_party = brotli,certifi,chardet,click,cryptography,h11,h2,hstspreload,pytest,rfc3986,setuptools,sniffio,tqdm,trio,trustme,uvicorn
|
||||
line_length = 88
|
||||
multi_line_output = 3
|
||||
|
||||
|
||||
3
setup.py
3
setup.py
@ -52,12 +52,13 @@ setup(
|
||||
zip_safe=False,
|
||||
install_requires=[
|
||||
"certifi",
|
||||
"hstspreload",
|
||||
"chardet==3.*",
|
||||
"h11==0.8.*",
|
||||
"h2==3.*",
|
||||
"hstspreload>=2019.8.27",
|
||||
"idna==2.*",
|
||||
"rfc3986==1.*",
|
||||
"sniffio==1.*",
|
||||
],
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
|
||||
@ -48,7 +48,6 @@ def test_proxies_has_same_properties_as_dispatch():
|
||||
"cert",
|
||||
"timeout",
|
||||
"pool_limits",
|
||||
"backend",
|
||||
]:
|
||||
assert getattr(pool, prop) == getattr(proxy, prop)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user