Drop run and run_in_threadpool (#710)
* Drop run and run_in_threadpool * Fix server restart errors * Re-introduce 'sleep' as a concurrency test utility * Simpler test concurrency utils Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
This commit is contained in:
parent
bd57b650a8
commit
6ac49dacdd
@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
@ -182,15 +181,6 @@ class AsyncioBackend(ConcurrencyBackend):
|
||||
ssl_monkey_patch()
|
||||
SSL_MONKEY_PATCH_APPLIED = True
|
||||
|
||||
@property
|
||||
def loop(self) -> asyncio.AbstractEventLoop:
|
||||
if not hasattr(self, "_loop"):
|
||||
try:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
return self._loop
|
||||
|
||||
async def open_tcp_stream(
|
||||
self,
|
||||
hostname: str,
|
||||
@ -233,25 +223,6 @@ class AsyncioBackend(ConcurrencyBackend):
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.time()
|
||||
|
||||
async def run_in_threadpool(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
if kwargs:
|
||||
# loop.run_in_executor doesn't accept 'kwargs', so bind them in here
|
||||
func = functools.partial(func, **kwargs)
|
||||
return await self.loop.run_in_executor(None, func, *args)
|
||||
|
||||
def run(
|
||||
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
loop = self.loop
|
||||
if loop.is_running():
|
||||
self._loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return self.loop.run_until_complete(coroutine(*args, **kwargs))
|
||||
finally:
|
||||
self._loop = loop
|
||||
|
||||
def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
|
||||
return Semaphore(max_value, exc_class)
|
||||
|
||||
|
||||
@ -44,11 +44,6 @@ class AutoBackend(ConcurrencyBackend):
|
||||
def time(self) -> float:
|
||||
return self.backend.time()
|
||||
|
||||
async def run_in_threadpool(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
return await self.backend.run_in_threadpool(func, *args, **kwargs)
|
||||
|
||||
def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
|
||||
return self.backend.create_semaphore(max_value, exc_class)
|
||||
|
||||
|
||||
@ -114,16 +114,6 @@ class ConcurrencyBackend:
|
||||
def time(self) -> float:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def run_in_threadpool(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def run(
|
||||
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import functools
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
@ -120,20 +119,6 @@ class TrioBackend(ConcurrencyBackend):
|
||||
|
||||
raise ConnectTimeout()
|
||||
|
||||
async def run_in_threadpool(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
return await trio.to_thread.run_sync(
|
||||
functools.partial(func, **kwargs) if kwargs else func, *args
|
||||
)
|
||||
|
||||
def run(
|
||||
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
return trio.run(
|
||||
functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
|
||||
)
|
||||
|
||||
def time(self) -> float:
|
||||
return trio.current_time()
|
||||
|
||||
|
||||
@ -4,54 +4,34 @@ required as part of the ConcurrencyBackend API.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import typing
|
||||
|
||||
import sniffio
|
||||
import trio
|
||||
|
||||
from httpx.backends.asyncio import AsyncioBackend
|
||||
from httpx.backends.auto import AutoBackend
|
||||
from httpx.backends.trio import TrioBackend
|
||||
|
||||
async def sleep(seconds: float):
|
||||
if sniffio.current_async_library() == "trio":
|
||||
await trio.sleep(seconds)
|
||||
else:
|
||||
await asyncio.sleep(seconds)
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
async def run_concurrently(backend, *coroutines: typing.Callable[[], typing.Awaitable]):
|
||||
raise NotImplementedError # pragma: no cover
|
||||
async def run_concurrently(*coroutines):
|
||||
if sniffio.current_async_library() == "trio":
|
||||
async with trio.open_nursery() as nursery:
|
||||
for coroutine in coroutines:
|
||||
nursery.start_soon(coroutine)
|
||||
else:
|
||||
coros = (coroutine() for coroutine in coroutines)
|
||||
await asyncio.gather(*coros)
|
||||
|
||||
|
||||
@run_concurrently.register(AutoBackend)
|
||||
async def _run_concurrently_auto(backend, *coroutines):
|
||||
await run_concurrently(backend.backend, *coroutines)
|
||||
|
||||
|
||||
@run_concurrently.register(AsyncioBackend)
|
||||
async def _run_concurrently_asyncio(backend, *coroutines):
|
||||
coros = (coroutine() for coroutine in coroutines)
|
||||
await asyncio.gather(*coros)
|
||||
|
||||
|
||||
@run_concurrently.register(TrioBackend)
|
||||
async def _run_concurrently_trio(backend, *coroutines):
|
||||
async with trio.open_nursery() as nursery:
|
||||
for coroutine in coroutines:
|
||||
nursery.start_soon(coroutine)
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def get_cipher(backend, stream):
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
|
||||
@get_cipher.register(AutoBackend)
|
||||
def _get_cipher_auto(backend, stream):
|
||||
return get_cipher(backend.backend, stream)
|
||||
|
||||
|
||||
@get_cipher.register(AsyncioBackend)
|
||||
def _get_cipher_asyncio(backend, stream):
|
||||
return stream.stream_writer.get_extra_info("cipher", default=None)
|
||||
|
||||
|
||||
@get_cipher.register(TrioBackend)
|
||||
def get_trio_cipher(backend, stream):
|
||||
return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None
|
||||
def get_cipher(stream):
|
||||
if sniffio.current_async_library() == "trio":
|
||||
return (
|
||||
stream.stream.cipher()
|
||||
if isinstance(stream.stream, trio.SSLStream)
|
||||
else None
|
||||
)
|
||||
else:
|
||||
return stream.stream_writer.get_extra_info("cipher", default=None)
|
||||
|
||||
@ -18,8 +18,7 @@ from uvicorn.config import Config
|
||||
from uvicorn.main import Server
|
||||
|
||||
from httpx import URL
|
||||
from httpx.backends.asyncio import AsyncioBackend
|
||||
from httpx.backends.base import lookup_backend
|
||||
from tests.concurrency import sleep
|
||||
|
||||
ENVIRONMENT_VARIABLES = {
|
||||
"SSL_CERT_FILE",
|
||||
@ -108,7 +107,7 @@ async def slow_response(scope, receive, send):
|
||||
delay_ms = float(delay_ms_str)
|
||||
except ValueError:
|
||||
delay_ms = 100
|
||||
await asyncio.sleep(delay_ms / 1000.0)
|
||||
await sleep(delay_ms / 1000.0)
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
@ -252,15 +251,15 @@ class TestServer(Server):
|
||||
await asyncio.wait(tasks)
|
||||
|
||||
async def restart(self) -> None:
|
||||
# Ensure we are in an asyncio environment.
|
||||
assert asyncio.get_event_loop() is not None
|
||||
# This may be called from a different thread than the one the server is
|
||||
# running on. For this reason, we use an event to coordinate with the server
|
||||
# instead of calling shutdown()/startup() directly.
|
||||
self.restart_requested.set()
|
||||
# This coroutine may be called from a different thread than the one the
|
||||
# server is running on, and from an async environment that's not asyncio.
|
||||
# For this reason, we use an event to coordinate with the server
|
||||
# instead of calling shutdown()/startup() directly, and should not make
|
||||
# any asyncio-specific operations.
|
||||
self.started = False
|
||||
self.restart_requested.set()
|
||||
while not self.started:
|
||||
await asyncio.sleep(0.5)
|
||||
await sleep(0.2)
|
||||
|
||||
async def watch_restarts(self):
|
||||
while True:
|
||||
@ -277,22 +276,6 @@ class TestServer(Server):
|
||||
await self.startup()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def restart():
|
||||
"""Restart the running server from an async test function.
|
||||
|
||||
This fixture deals with possible differences between the environment of the
|
||||
test function and that of the server.
|
||||
"""
|
||||
asyncio_backend = AsyncioBackend()
|
||||
|
||||
async def restart(server):
|
||||
backend = lookup_backend()
|
||||
await backend.run_in_threadpool(asyncio_backend.run, server.restart)
|
||||
|
||||
return restart
|
||||
|
||||
|
||||
def serve_in_thread(server: Server):
|
||||
thread = threading.Thread(target=server.run)
|
||||
thread.start()
|
||||
|
||||
@ -167,7 +167,7 @@ async def test_premature_response_close(server):
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_keepalive_connection_closed_by_server_is_reestablished(server, restart):
|
||||
async def test_keepalive_connection_closed_by_server_is_reestablished(server):
|
||||
"""
|
||||
Upon keep-alive connection closed by remote a new connection
|
||||
should be reestablished.
|
||||
@ -177,7 +177,7 @@ async def test_keepalive_connection_closed_by_server_is_reestablished(server, re
|
||||
await response.aread()
|
||||
|
||||
# Shutdown the server to close the keep-alive connection
|
||||
await restart(server)
|
||||
await server.restart()
|
||||
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
@ -186,9 +186,7 @@ async def test_keepalive_connection_closed_by_server_is_reestablished(server, re
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
|
||||
server, restart
|
||||
):
|
||||
async def test_keepalive_http2_connection_closed_by_server_is_reestablished(server):
|
||||
"""
|
||||
Upon keep-alive connection closed by remote a new connection
|
||||
should be reestablished.
|
||||
@ -198,7 +196,7 @@ async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
|
||||
await response.aread()
|
||||
|
||||
# Shutdown the server to close the keep-alive connection
|
||||
await restart(server)
|
||||
await server.restart()
|
||||
|
||||
response = await http.request("GET", server.url)
|
||||
await response.aread()
|
||||
@ -207,7 +205,7 @@ async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_connection_closed_free_semaphore_on_acquire(server, restart):
|
||||
async def test_connection_closed_free_semaphore_on_acquire(server):
|
||||
"""
|
||||
Verify that max_connections semaphore is released
|
||||
properly on a disconnected connection.
|
||||
@ -217,7 +215,7 @@ async def test_connection_closed_free_semaphore_on_acquire(server, restart):
|
||||
await response.aread()
|
||||
|
||||
# Close the connection so we're forced to recycle it
|
||||
await restart(server)
|
||||
await server.restart()
|
||||
|
||||
response = await http.request("GET", server.url)
|
||||
assert response.status_code == 200
|
||||
|
||||
@ -41,11 +41,11 @@ async def test_start_tls_on_tcp_socket_stream(https_server):
|
||||
|
||||
try:
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert get_cipher(backend, stream) is None
|
||||
assert get_cipher(stream) is None
|
||||
|
||||
stream = await stream.start_tls(https_server.url.host, ctx, timeout)
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert get_cipher(backend, stream) is not None
|
||||
assert get_cipher(stream) is not None
|
||||
|
||||
await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
|
||||
|
||||
@ -68,11 +68,11 @@ async def test_start_tls_on_uds_socket_stream(https_uds_server):
|
||||
|
||||
try:
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert get_cipher(backend, stream) is None
|
||||
assert get_cipher(stream) is None
|
||||
|
||||
stream = await stream.start_tls(https_uds_server.url.host, ctx, timeout)
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert get_cipher(backend, stream) is not None
|
||||
assert get_cipher(stream) is not None
|
||||
|
||||
await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
|
||||
|
||||
@ -96,7 +96,7 @@ async def test_concurrent_read(server):
|
||||
try:
|
||||
await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
|
||||
await run_concurrently(
|
||||
backend, lambda: stream.read(10, timeout), lambda: stream.read(10, timeout)
|
||||
lambda: stream.read(10, timeout), lambda: stream.read(10, timeout)
|
||||
)
|
||||
finally:
|
||||
await stream.close()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user