Add trio concurrency backend (#276)
This commit is contained in:
parent
12752466ae
commit
08355c62f5
@ -81,7 +81,7 @@ class BaseClient:
|
||||
if param_count == 2:
|
||||
dispatch = WSGIDispatch(app=app)
|
||||
else:
|
||||
dispatch = ASGIDispatch(app=app)
|
||||
dispatch = ASGIDispatch(app=app, backend=backend)
|
||||
|
||||
self.trust_env = True if trust_env is None else trust_env
|
||||
|
||||
|
||||
@ -112,6 +112,20 @@ class TCPStream(BaseTCPStream):
|
||||
raise WriteTimeout() from None
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
# Counter-intuitively, what we really want to know here is whether the socket is
|
||||
# *readable*, i.e. whether it would return immediately with empty bytes if we
|
||||
# called `.recv()` on it, indicating that the other end has closed the socket.
|
||||
# See: https://github.com/encode/httpx/pull/143#issuecomment-515181778
|
||||
#
|
||||
# As it turns out, asyncio checks for readability in the background
|
||||
# (see: https://github.com/encode/httpx/pull/276#discussion_r322000402),
|
||||
# so checking for EOF or readability here would yield the same result.
|
||||
#
|
||||
# At the cost of rigour, we check for EOF instead of readability because asyncio
|
||||
# does not expose any public API to check for readability.
|
||||
# (For a solution that uses private asyncio APIs, see:
|
||||
# https://github.com/encode/httpx/pull/143#issuecomment-515202982)
|
||||
|
||||
return self.stream_reader.at_eof()
|
||||
|
||||
async def close(self) -> None:
|
||||
|
||||
@ -187,3 +187,10 @@ class BaseBackgroundManager:
|
||||
traceback: TracebackType = None,
|
||||
) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def close(self, exception: BaseException = None) -> None:
|
||||
if exception is None:
|
||||
await self.__aexit__(None, None, None)
|
||||
else:
|
||||
traceback = exception.__traceback__ # type: ignore
|
||||
await self.__aexit__(type(exception), exception, traceback)
|
||||
|
||||
255
httpx/concurrency/trio.py
Normal file
255
httpx/concurrency/trio.py
Normal file
@ -0,0 +1,255 @@
|
||||
import functools
|
||||
import math
|
||||
import ssl
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
import trio
|
||||
|
||||
from ..config import PoolLimits, TimeoutConfig
|
||||
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
|
||||
from .base import (
|
||||
BaseBackgroundManager,
|
||||
BaseEvent,
|
||||
BasePoolSemaphore,
|
||||
BaseQueue,
|
||||
BaseTCPStream,
|
||||
ConcurrencyBackend,
|
||||
TimeoutFlag,
|
||||
)
|
||||
|
||||
|
||||
def _or_inf(value: typing.Optional[float]) -> float:
|
||||
return value if value is not None else float("inf")
|
||||
|
||||
|
||||
class TCPStream(BaseTCPStream):
|
||||
def __init__(
|
||||
self,
|
||||
stream: typing.Union[trio.SocketStream, trio.SSLStream],
|
||||
timeout: TimeoutConfig,
|
||||
) -> None:
|
||||
self.stream = stream
|
||||
self.timeout = timeout
|
||||
self.write_buffer = b""
|
||||
self.write_lock = trio.Lock()
|
||||
|
||||
def get_http_version(self) -> str:
|
||||
if not isinstance(self.stream, trio.SSLStream):
|
||||
return "HTTP/1.1"
|
||||
|
||||
ident = self.stream.selected_alpn_protocol()
|
||||
if ident is None:
|
||||
return "HTTP/1.1"
|
||||
|
||||
return "HTTP/2" if ident == "h2" else "HTTP/1.1"
|
||||
|
||||
async def read(
|
||||
self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
|
||||
) -> bytes:
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
while True:
|
||||
# Check our flag at the first possible moment, and use a fine
|
||||
# grained retry loop if we're not yet in read-timeout mode.
|
||||
should_raise = flag is None or flag.raise_on_read_timeout
|
||||
read_timeout = _or_inf(timeout.read_timeout if should_raise else 0.01)
|
||||
|
||||
with trio.move_on_after(read_timeout):
|
||||
return await self.stream.receive_some(max_bytes=n)
|
||||
|
||||
if should_raise:
|
||||
raise ReadTimeout() from None
|
||||
|
||||
def is_connection_dropped(self) -> bool:
|
||||
# Adapted from: https://github.com/encode/httpx/pull/143#issuecomment-515202982
|
||||
stream = self.stream
|
||||
|
||||
# Peek through any SSLStream wrappers to get the underlying SocketStream.
|
||||
while hasattr(stream, "transport_stream"):
|
||||
stream = stream.transport_stream
|
||||
assert isinstance(stream, trio.SocketStream)
|
||||
|
||||
# Counter-intuitively, what we really want to know here is whether the socket is
|
||||
# *readable*, i.e. whether it would return immediately with empty bytes if we
|
||||
# called `.recv()` on it, indicating that the other end has closed the socket.
|
||||
# See: https://github.com/encode/httpx/pull/143#issuecomment-515181778
|
||||
return stream.socket.is_readable()
|
||||
|
||||
def write_no_block(self, data: bytes) -> None:
|
||||
self.write_buffer += data
|
||||
|
||||
async def write(
|
||||
self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
|
||||
) -> None:
|
||||
if self.write_buffer:
|
||||
previous_data = self.write_buffer
|
||||
# Reset before recursive call, otherwise we'll go through
|
||||
# this branch indefinitely.
|
||||
self.write_buffer = b""
|
||||
try:
|
||||
await self.write(previous_data, timeout=timeout, flag=flag)
|
||||
except WriteTimeout:
|
||||
self.writer_buffer = previous_data
|
||||
raise
|
||||
|
||||
if not data:
|
||||
return
|
||||
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
write_timeout = _or_inf(timeout.write_timeout)
|
||||
|
||||
while True:
|
||||
with trio.move_on_after(write_timeout):
|
||||
async with self.write_lock:
|
||||
await self.stream.send_all(data)
|
||||
break
|
||||
# We check our flag at the first possible moment, in order to
|
||||
# allow us to suppress write timeouts, if we've since
|
||||
# switched over to read-timeout mode.
|
||||
should_raise = flag is None or flag.raise_on_write_timeout
|
||||
if should_raise:
|
||||
raise WriteTimeout() from None
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.stream.aclose()
|
||||
|
||||
|
||||
class PoolSemaphore(BasePoolSemaphore):
|
||||
def __init__(self, pool_limits: PoolLimits):
|
||||
self.pool_limits = pool_limits
|
||||
|
||||
@property
|
||||
def semaphore(self) -> typing.Optional[trio.Semaphore]:
|
||||
if not hasattr(self, "_semaphore"):
|
||||
max_connections = self.pool_limits.hard_limit
|
||||
if max_connections is None:
|
||||
self._semaphore = None
|
||||
else:
|
||||
self._semaphore = trio.Semaphore(
|
||||
max_connections, max_value=max_connections
|
||||
)
|
||||
return self._semaphore
|
||||
|
||||
async def acquire(self) -> None:
|
||||
if self.semaphore is None:
|
||||
return
|
||||
|
||||
timeout = _or_inf(self.pool_limits.pool_timeout)
|
||||
|
||||
with trio.move_on_after(timeout):
|
||||
await self.semaphore.acquire()
|
||||
return
|
||||
|
||||
raise PoolTimeout()
|
||||
|
||||
def release(self) -> None:
|
||||
if self.semaphore is None:
|
||||
return
|
||||
|
||||
self.semaphore.release()
|
||||
|
||||
|
||||
class TrioBackend(ConcurrencyBackend):
|
||||
async def open_tcp_stream(
|
||||
self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
ssl_context: typing.Optional[ssl.SSLContext],
|
||||
timeout: TimeoutConfig,
|
||||
) -> TCPStream:
|
||||
connect_timeout = _or_inf(timeout.connect_timeout)
|
||||
|
||||
with trio.move_on_after(connect_timeout) as cancel_scope:
|
||||
stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port)
|
||||
if ssl_context is not None:
|
||||
stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
|
||||
await stream.do_handshake()
|
||||
|
||||
if cancel_scope.cancelled_caught:
|
||||
raise ConnectTimeout()
|
||||
|
||||
return TCPStream(stream=stream, timeout=timeout)
|
||||
|
||||
async def run_in_threadpool(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
return await trio.to_thread.run_sync(
|
||||
functools.partial(func, **kwargs) if kwargs else func, *args
|
||||
)
|
||||
|
||||
def run(
|
||||
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
return trio.run(
|
||||
functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
|
||||
)
|
||||
|
||||
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
|
||||
return PoolSemaphore(limits)
|
||||
|
||||
def create_queue(self, max_size: int) -> BaseQueue:
|
||||
return Queue(max_size=max_size)
|
||||
|
||||
def create_event(self) -> BaseEvent:
|
||||
return Event()
|
||||
|
||||
def background_manager(
|
||||
self, coroutine: typing.Callable, *args: typing.Any
|
||||
) -> "BackgroundManager":
|
||||
return BackgroundManager(coroutine, *args)
|
||||
|
||||
|
||||
class Queue(BaseQueue):
|
||||
def __init__(self, max_size: int) -> None:
|
||||
self.send_channel, self.receive_channel = trio.open_memory_channel(math.inf)
|
||||
|
||||
async def get(self) -> typing.Any:
|
||||
return await self.receive_channel.receive()
|
||||
|
||||
async def put(self, value: typing.Any) -> None:
|
||||
await self.send_channel.send(value)
|
||||
|
||||
|
||||
class Event(BaseEvent):
|
||||
def __init__(self) -> None:
|
||||
self._event = trio.Event()
|
||||
|
||||
def set(self) -> None:
|
||||
self._event.set()
|
||||
|
||||
def is_set(self) -> bool:
|
||||
return self._event.is_set()
|
||||
|
||||
async def wait(self) -> None:
|
||||
await self._event.wait()
|
||||
|
||||
def clear(self) -> None:
|
||||
# trio.Event.clear() was deprecated in Trio 0.12.
|
||||
# https://github.com/python-trio/trio/issues/637
|
||||
self._event = trio.Event()
|
||||
|
||||
|
||||
class BackgroundManager(BaseBackgroundManager):
|
||||
def __init__(self, coroutine: typing.Callable, *args: typing.Any) -> None:
|
||||
self.coroutine = coroutine
|
||||
self.args = args
|
||||
self.nursery_manager = trio.open_nursery()
|
||||
self.nursery: typing.Optional[trio.Nursery] = None
|
||||
|
||||
async def __aenter__(self) -> "BackgroundManager":
|
||||
self.nursery = await self.nursery_manager.__aenter__()
|
||||
self.nursery.start_soon(self.coroutine, *self.args)
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Type[BaseException] = None,
|
||||
exc_value: BaseException = None,
|
||||
traceback: TracebackType = None,
|
||||
) -> None:
|
||||
assert self.nursery is not None
|
||||
await self.nursery_manager.__aexit__(exc_type, exc_value, traceback)
|
||||
@ -130,6 +130,7 @@ class ASGIDispatch(AsyncDispatcher):
|
||||
await response_started_or_failed.wait()
|
||||
|
||||
if app_exc is not None and self.raise_app_exceptions:
|
||||
await background.close(app_exc)
|
||||
raise app_exc
|
||||
|
||||
assert status_code is not None, "application did not return a response."
|
||||
@ -138,7 +139,7 @@ class ASGIDispatch(AsyncDispatcher):
|
||||
async def on_close() -> None:
|
||||
nonlocal response_body
|
||||
await response_body.drain()
|
||||
await background.__aexit__(None, None, None)
|
||||
await background.close(app_exc)
|
||||
if app_exc is not None and self.raise_app_exceptions:
|
||||
raise app_exc
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ import h2.events
|
||||
|
||||
from ..concurrency.base import BaseEvent, BaseTCPStream, ConcurrencyBackend, TimeoutFlag
|
||||
from ..config import TimeoutConfig, TimeoutTypes
|
||||
from ..exceptions import ProtocolError
|
||||
from ..models import AsyncRequest, AsyncResponse
|
||||
from ..utils import get_logger
|
||||
|
||||
@ -187,6 +188,10 @@ class HTTP2Connection:
|
||||
logger.debug(
|
||||
f"receive_event stream_id={event_stream_id} event={event!r}"
|
||||
)
|
||||
|
||||
if hasattr(event, "error_code"):
|
||||
raise ProtocolError(event)
|
||||
|
||||
if isinstance(event, h2.events.WindowUpdated):
|
||||
if event_stream_id == 0:
|
||||
for window_update_event in self.window_update_received.values():
|
||||
|
||||
@ -11,7 +11,7 @@ combine_as_imports = True
|
||||
force_grid_wrap = 0
|
||||
include_trailing_comma = True
|
||||
known_first_party = httpx,tests
|
||||
known_third_party = brotli,certifi,chardet,cryptography,h11,h2,hstspreload,nox,pytest,requests,rfc3986,setuptools,trustme,uvicorn
|
||||
known_third_party = brotli,certifi,chardet,cryptography,h11,h2,hstspreload,nox,pytest,requests,rfc3986,setuptools,trio,trustme,uvicorn
|
||||
line_length = 88
|
||||
multi_line_output = 3
|
||||
|
||||
|
||||
1
setup.py
1
setup.py
@ -59,6 +59,7 @@ setup(
|
||||
"idna==2.*",
|
||||
"rfc3986==1.*",
|
||||
],
|
||||
extras_require={"trio": ["trio"]},
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Environment :: Web Environment",
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
-e .
|
||||
-e .[trio]
|
||||
|
||||
# Optional
|
||||
brotlipy==0.7.*
|
||||
@ -11,6 +11,7 @@ isort
|
||||
mypy
|
||||
pytest
|
||||
pytest-asyncio
|
||||
pytest-trio
|
||||
pytest-cov
|
||||
trustme
|
||||
uvicorn
|
||||
|
||||
@ -17,3 +17,15 @@ async def sleep(backend, seconds: int):
|
||||
@sleep.register(AsyncioBackend)
|
||||
async def _sleep_asyncio(backend, seconds: int):
|
||||
await asyncio.sleep(seconds)
|
||||
|
||||
|
||||
try:
|
||||
import trio
|
||||
from httpx.concurrency.trio import TrioBackend
|
||||
except ImportError: # pragma: no cover
|
||||
pass
|
||||
else:
|
||||
|
||||
@sleep.register(TrioBackend)
|
||||
async def _sleep_trio(backend, seconds: int):
|
||||
await trio.sleep(seconds)
|
||||
|
||||
@ -47,7 +47,17 @@ def clean_environ() -> typing.Dict[str, typing.Any]:
|
||||
os.environ.update(original_environ)
|
||||
|
||||
|
||||
@pytest.fixture(params=[pytest.param(AsyncioBackend, marks=pytest.mark.asyncio)])
|
||||
backend_params = [pytest.param(AsyncioBackend, marks=pytest.mark.asyncio)]
|
||||
|
||||
try:
|
||||
from httpx.concurrency.trio import TrioBackend
|
||||
except ImportError: # pragma: no cover
|
||||
pass
|
||||
else:
|
||||
backend_params.append(pytest.param(TrioBackend, marks=pytest.mark.trio))
|
||||
|
||||
|
||||
@pytest.fixture(params=backend_params)
|
||||
def backend(request):
|
||||
backend_cls = request.param
|
||||
return backend_cls()
|
||||
|
||||
@ -168,7 +168,9 @@ async def test_connection_closed_free_semaphore_on_acquire(server, restart, back
|
||||
Verify that max_connections semaphore is released
|
||||
properly on a disconnected connection.
|
||||
"""
|
||||
async with httpx.ConnectionPool(pool_limits=httpx.PoolLimits(hard_limit=1)) as http:
|
||||
async with httpx.ConnectionPool(
|
||||
pool_limits=httpx.PoolLimits(hard_limit=1), backend=backend
|
||||
) as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.read()
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ async def test_connect_timeout(server, backend):
|
||||
|
||||
|
||||
async def test_pool_timeout(server, backend):
|
||||
pool_limits = PoolLimits(hard_limit=1, pool_timeout=1e-6)
|
||||
pool_limits = PoolLimits(hard_limit=1, pool_timeout=1e-4)
|
||||
|
||||
async with AsyncClient(pool_limits=pool_limits, backend=backend) as client:
|
||||
response = await client.get(server.url, stream=True)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user