Add trio concurrency backend (#276)

This commit is contained in:
Florimond Manca 2019-09-21 18:10:20 +02:00 committed by Seth Michael Larson
parent 12752466ae
commit 08355c62f5
13 changed files with 315 additions and 7 deletions

View File

@ -81,7 +81,7 @@ class BaseClient:
if param_count == 2:
dispatch = WSGIDispatch(app=app)
else:
dispatch = ASGIDispatch(app=app)
dispatch = ASGIDispatch(app=app, backend=backend)
self.trust_env = True if trust_env is None else trust_env

View File

@ -112,6 +112,20 @@ class TCPStream(BaseTCPStream):
raise WriteTimeout() from None
def is_connection_dropped(self) -> bool:
# Counter-intuitively, what we really want to know here is whether the socket is
# *readable*, i.e. whether it would return immediately with empty bytes if we
# called `.recv()` on it, indicating that the other end has closed the socket.
# See: https://github.com/encode/httpx/pull/143#issuecomment-515181778
#
# As it turns out, asyncio checks for readability in the background
# (see: https://github.com/encode/httpx/pull/276#discussion_r322000402),
# so checking for EOF or readability here would yield the same result.
#
# At the cost of rigour, we check for EOF instead of readability because asyncio
# does not expose any public API to check for readability.
# (For a solution that uses private asyncio APIs, see:
# https://github.com/encode/httpx/pull/143#issuecomment-515202982)
return self.stream_reader.at_eof()
async def close(self) -> None:

View File

@ -187,3 +187,10 @@ class BaseBackgroundManager:
traceback: TracebackType = None,
) -> None:
raise NotImplementedError() # pragma: no cover
async def close(self, exception: BaseException = None) -> None:
if exception is None:
await self.__aexit__(None, None, None)
else:
traceback = exception.__traceback__ # type: ignore
await self.__aexit__(type(exception), exception, traceback)

255
httpx/concurrency/trio.py Normal file
View File

@ -0,0 +1,255 @@
import functools
import math
import ssl
import typing
from types import TracebackType
import trio
from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
BaseBackgroundManager,
BaseEvent,
BasePoolSemaphore,
BaseQueue,
BaseTCPStream,
ConcurrencyBackend,
TimeoutFlag,
)
def _or_inf(value: typing.Optional[float]) -> float:
return value if value is not None else float("inf")
class TCPStream(BaseTCPStream):
def __init__(
self,
stream: typing.Union[trio.SocketStream, trio.SSLStream],
timeout: TimeoutConfig,
) -> None:
self.stream = stream
self.timeout = timeout
self.write_buffer = b""
self.write_lock = trio.Lock()
def get_http_version(self) -> str:
if not isinstance(self.stream, trio.SSLStream):
return "HTTP/1.1"
ident = self.stream.selected_alpn_protocol()
if ident is None:
return "HTTP/1.1"
return "HTTP/2" if ident == "h2" else "HTTP/1.1"
async def read(
self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
) -> bytes:
if timeout is None:
timeout = self.timeout
while True:
# Check our flag at the first possible moment, and use a fine
# grained retry loop if we're not yet in read-timeout mode.
should_raise = flag is None or flag.raise_on_read_timeout
read_timeout = _or_inf(timeout.read_timeout if should_raise else 0.01)
with trio.move_on_after(read_timeout):
return await self.stream.receive_some(max_bytes=n)
if should_raise:
raise ReadTimeout() from None
def is_connection_dropped(self) -> bool:
# Adapted from: https://github.com/encode/httpx/pull/143#issuecomment-515202982
stream = self.stream
# Peek through any SSLStream wrappers to get the underlying SocketStream.
while hasattr(stream, "transport_stream"):
stream = stream.transport_stream
assert isinstance(stream, trio.SocketStream)
# Counter-intuitively, what we really want to know here is whether the socket is
# *readable*, i.e. whether it would return immediately with empty bytes if we
# called `.recv()` on it, indicating that the other end has closed the socket.
# See: https://github.com/encode/httpx/pull/143#issuecomment-515181778
return stream.socket.is_readable()
def write_no_block(self, data: bytes) -> None:
self.write_buffer += data
async def write(
self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
) -> None:
if self.write_buffer:
previous_data = self.write_buffer
# Reset before recursive call, otherwise we'll go through
# this branch indefinitely.
self.write_buffer = b""
try:
await self.write(previous_data, timeout=timeout, flag=flag)
except WriteTimeout:
self.writer_buffer = previous_data
raise
if not data:
return
if timeout is None:
timeout = self.timeout
write_timeout = _or_inf(timeout.write_timeout)
while True:
with trio.move_on_after(write_timeout):
async with self.write_lock:
await self.stream.send_all(data)
break
# We check our flag at the first possible moment, in order to
# allow us to suppress write timeouts, if we've since
# switched over to read-timeout mode.
should_raise = flag is None or flag.raise_on_write_timeout
if should_raise:
raise WriteTimeout() from None
async def close(self) -> None:
await self.stream.aclose()
class PoolSemaphore(BasePoolSemaphore):
def __init__(self, pool_limits: PoolLimits):
self.pool_limits = pool_limits
@property
def semaphore(self) -> typing.Optional[trio.Semaphore]:
if not hasattr(self, "_semaphore"):
max_connections = self.pool_limits.hard_limit
if max_connections is None:
self._semaphore = None
else:
self._semaphore = trio.Semaphore(
max_connections, max_value=max_connections
)
return self._semaphore
async def acquire(self) -> None:
if self.semaphore is None:
return
timeout = _or_inf(self.pool_limits.pool_timeout)
with trio.move_on_after(timeout):
await self.semaphore.acquire()
return
raise PoolTimeout()
def release(self) -> None:
if self.semaphore is None:
return
self.semaphore.release()
class TrioBackend(ConcurrencyBackend):
async def open_tcp_stream(
self,
hostname: str,
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> TCPStream:
connect_timeout = _or_inf(timeout.connect_timeout)
with trio.move_on_after(connect_timeout) as cancel_scope:
stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port)
if ssl_context is not None:
stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
await stream.do_handshake()
if cancel_scope.cancelled_caught:
raise ConnectTimeout()
return TCPStream(stream=stream, timeout=timeout)
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
return await trio.to_thread.run_sync(
functools.partial(func, **kwargs) if kwargs else func, *args
)
def run(
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
return trio.run(
functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
)
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
return PoolSemaphore(limits)
def create_queue(self, max_size: int) -> BaseQueue:
return Queue(max_size=max_size)
def create_event(self) -> BaseEvent:
return Event()
def background_manager(
self, coroutine: typing.Callable, *args: typing.Any
) -> "BackgroundManager":
return BackgroundManager(coroutine, *args)
class Queue(BaseQueue):
def __init__(self, max_size: int) -> None:
self.send_channel, self.receive_channel = trio.open_memory_channel(math.inf)
async def get(self) -> typing.Any:
return await self.receive_channel.receive()
async def put(self, value: typing.Any) -> None:
await self.send_channel.send(value)
class Event(BaseEvent):
def __init__(self) -> None:
self._event = trio.Event()
def set(self) -> None:
self._event.set()
def is_set(self) -> bool:
return self._event.is_set()
async def wait(self) -> None:
await self._event.wait()
def clear(self) -> None:
# trio.Event.clear() was deprecated in Trio 0.12.
# https://github.com/python-trio/trio/issues/637
self._event = trio.Event()
class BackgroundManager(BaseBackgroundManager):
def __init__(self, coroutine: typing.Callable, *args: typing.Any) -> None:
self.coroutine = coroutine
self.args = args
self.nursery_manager = trio.open_nursery()
self.nursery: typing.Optional[trio.Nursery] = None
async def __aenter__(self) -> "BackgroundManager":
self.nursery = await self.nursery_manager.__aenter__()
self.nursery.start_soon(self.coroutine, *self.args)
return self
async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
assert self.nursery is not None
await self.nursery_manager.__aexit__(exc_type, exc_value, traceback)

View File

@ -130,6 +130,7 @@ class ASGIDispatch(AsyncDispatcher):
await response_started_or_failed.wait()
if app_exc is not None and self.raise_app_exceptions:
await background.close(app_exc)
raise app_exc
assert status_code is not None, "application did not return a response."
@ -138,7 +139,7 @@ class ASGIDispatch(AsyncDispatcher):
async def on_close() -> None:
nonlocal response_body
await response_body.drain()
await background.__aexit__(None, None, None)
await background.close(app_exc)
if app_exc is not None and self.raise_app_exceptions:
raise app_exc

View File

@ -6,6 +6,7 @@ import h2.events
from ..concurrency.base import BaseEvent, BaseTCPStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..exceptions import ProtocolError
from ..models import AsyncRequest, AsyncResponse
from ..utils import get_logger
@ -187,6 +188,10 @@ class HTTP2Connection:
logger.debug(
f"receive_event stream_id={event_stream_id} event={event!r}"
)
if hasattr(event, "error_code"):
raise ProtocolError(event)
if isinstance(event, h2.events.WindowUpdated):
if event_stream_id == 0:
for window_update_event in self.window_update_received.values():

View File

@ -11,7 +11,7 @@ combine_as_imports = True
force_grid_wrap = 0
include_trailing_comma = True
known_first_party = httpx,tests
known_third_party = brotli,certifi,chardet,cryptography,h11,h2,hstspreload,nox,pytest,requests,rfc3986,setuptools,trustme,uvicorn
known_third_party = brotli,certifi,chardet,cryptography,h11,h2,hstspreload,nox,pytest,requests,rfc3986,setuptools,trio,trustme,uvicorn
line_length = 88
multi_line_output = 3

View File

@ -59,6 +59,7 @@ setup(
"idna==2.*",
"rfc3986==1.*",
],
extras_require={"trio": ["trio"]},
classifiers=[
"Development Status :: 3 - Alpha",
"Environment :: Web Environment",

View File

@ -1,4 +1,4 @@
-e .
-e .[trio]
# Optional
brotlipy==0.7.*
@ -11,6 +11,7 @@ isort
mypy
pytest
pytest-asyncio
pytest-trio
pytest-cov
trustme
uvicorn

View File

@ -17,3 +17,15 @@ async def sleep(backend, seconds: int):
@sleep.register(AsyncioBackend)
async def _sleep_asyncio(backend, seconds: int):
await asyncio.sleep(seconds)
try:
import trio
from httpx.concurrency.trio import TrioBackend
except ImportError: # pragma: no cover
pass
else:
@sleep.register(TrioBackend)
async def _sleep_trio(backend, seconds: int):
await trio.sleep(seconds)

View File

@ -47,7 +47,17 @@ def clean_environ() -> typing.Dict[str, typing.Any]:
os.environ.update(original_environ)
@pytest.fixture(params=[pytest.param(AsyncioBackend, marks=pytest.mark.asyncio)])
backend_params = [pytest.param(AsyncioBackend, marks=pytest.mark.asyncio)]
try:
from httpx.concurrency.trio import TrioBackend
except ImportError: # pragma: no cover
pass
else:
backend_params.append(pytest.param(TrioBackend, marks=pytest.mark.trio))
@pytest.fixture(params=backend_params)
def backend(request):
backend_cls = request.param
return backend_cls()

View File

@ -168,7 +168,9 @@ async def test_connection_closed_free_semaphore_on_acquire(server, restart, back
Verify that max_connections semaphore is released
properly on a disconnected connection.
"""
async with httpx.ConnectionPool(pool_limits=httpx.PoolLimits(hard_limit=1)) as http:
async with httpx.ConnectionPool(
pool_limits=httpx.PoolLimits(hard_limit=1), backend=backend
) as http:
response = await http.request("GET", server.url)
await response.read()

View File

@ -38,7 +38,7 @@ async def test_connect_timeout(server, backend):
async def test_pool_timeout(server, backend):
pool_limits = PoolLimits(hard_limit=1, pool_timeout=1e-6)
pool_limits = PoolLimits(hard_limit=1, pool_timeout=1e-4)
async with AsyncClient(pool_limits=pool_limits, backend=backend) as client:
response = await client.get(server.url, stream=True)