Drop run and run_in_threadpool (#710)

* Drop run and run_in_threadpool

* Fix server restart errors

* Re-introduce 'sleep' as a concurrency test utility

* Simpler test concurrency utils

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
This commit is contained in:
Tom Christie 2020-01-06 11:14:43 +00:00 committed by GitHub
parent bd57b650a8
commit 6ac49dacdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 44 additions and 142 deletions

View File

@ -1,5 +1,4 @@
import asyncio
import functools
import ssl
import typing
@ -182,15 +181,6 @@ class AsyncioBackend(ConcurrencyBackend):
ssl_monkey_patch()
SSL_MONKEY_PATCH_APPLIED = True
@property
def loop(self) -> asyncio.AbstractEventLoop:
if not hasattr(self, "_loop"):
try:
self._loop = asyncio.get_event_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
return self._loop
async def open_tcp_stream(
self,
hostname: str,
@ -233,25 +223,6 @@ class AsyncioBackend(ConcurrencyBackend):
loop = asyncio.get_event_loop()
return loop.time()
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
if kwargs:
# loop.run_in_executor doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
return await self.loop.run_in_executor(None, func, *args)
def run(
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
loop = self.loop
if loop.is_running():
self._loop = asyncio.new_event_loop()
try:
return self.loop.run_until_complete(coroutine(*args, **kwargs))
finally:
self._loop = loop
def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
return Semaphore(max_value, exc_class)

View File

@ -44,11 +44,6 @@ class AutoBackend(ConcurrencyBackend):
def time(self) -> float:
return self.backend.time()
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
return await self.backend.run_in_threadpool(func, *args, **kwargs)
def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
return self.backend.create_semaphore(max_value, exc_class)

View File

@ -114,16 +114,6 @@ class ConcurrencyBackend:
def time(self) -> float:
raise NotImplementedError() # pragma: no cover
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
raise NotImplementedError() # pragma: no cover
def run(
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
raise NotImplementedError() # pragma: no cover
def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
raise NotImplementedError() # pragma: no cover

View File

@ -1,4 +1,3 @@
import functools
import ssl
import typing
@ -120,20 +119,6 @@ class TrioBackend(ConcurrencyBackend):
raise ConnectTimeout()
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
return await trio.to_thread.run_sync(
functools.partial(func, **kwargs) if kwargs else func, *args
)
def run(
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
return trio.run(
functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
)
def time(self) -> float:
return trio.current_time()

View File

@ -4,54 +4,34 @@ required as part of the ConcurrencyBackend API.
"""
import asyncio
import functools
import typing
import sniffio
import trio
from httpx.backends.asyncio import AsyncioBackend
from httpx.backends.auto import AutoBackend
from httpx.backends.trio import TrioBackend
async def sleep(seconds: float):
if sniffio.current_async_library() == "trio":
await trio.sleep(seconds)
else:
await asyncio.sleep(seconds)
@functools.singledispatch
async def run_concurrently(backend, *coroutines: typing.Callable[[], typing.Awaitable]):
raise NotImplementedError # pragma: no cover
async def run_concurrently(*coroutines):
if sniffio.current_async_library() == "trio":
async with trio.open_nursery() as nursery:
for coroutine in coroutines:
nursery.start_soon(coroutine)
else:
coros = (coroutine() for coroutine in coroutines)
await asyncio.gather(*coros)
@run_concurrently.register(AutoBackend)
async def _run_concurrently_auto(backend, *coroutines):
await run_concurrently(backend.backend, *coroutines)
@run_concurrently.register(AsyncioBackend)
async def _run_concurrently_asyncio(backend, *coroutines):
coros = (coroutine() for coroutine in coroutines)
await asyncio.gather(*coros)
@run_concurrently.register(TrioBackend)
async def _run_concurrently_trio(backend, *coroutines):
async with trio.open_nursery() as nursery:
for coroutine in coroutines:
nursery.start_soon(coroutine)
@functools.singledispatch
def get_cipher(backend, stream):
raise NotImplementedError # pragma: no cover
@get_cipher.register(AutoBackend)
def _get_cipher_auto(backend, stream):
return get_cipher(backend.backend, stream)
@get_cipher.register(AsyncioBackend)
def _get_cipher_asyncio(backend, stream):
return stream.stream_writer.get_extra_info("cipher", default=None)
@get_cipher.register(TrioBackend)
def get_trio_cipher(backend, stream):
return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None
def get_cipher(stream):
if sniffio.current_async_library() == "trio":
return (
stream.stream.cipher()
if isinstance(stream.stream, trio.SSLStream)
else None
)
else:
return stream.stream_writer.get_extra_info("cipher", default=None)

View File

@ -18,8 +18,7 @@ from uvicorn.config import Config
from uvicorn.main import Server
from httpx import URL
from httpx.backends.asyncio import AsyncioBackend
from httpx.backends.base import lookup_backend
from tests.concurrency import sleep
ENVIRONMENT_VARIABLES = {
"SSL_CERT_FILE",
@ -108,7 +107,7 @@ async def slow_response(scope, receive, send):
delay_ms = float(delay_ms_str)
except ValueError:
delay_ms = 100
await asyncio.sleep(delay_ms / 1000.0)
await sleep(delay_ms / 1000.0)
await send(
{
"type": "http.response.start",
@ -252,15 +251,15 @@ class TestServer(Server):
await asyncio.wait(tasks)
async def restart(self) -> None:
# Ensure we are in an asyncio environment.
assert asyncio.get_event_loop() is not None
# This may be called from a different thread than the one the server is
# running on. For this reason, we use an event to coordinate with the server
# instead of calling shutdown()/startup() directly.
self.restart_requested.set()
# This coroutine may be called from a different thread than the one the
# server is running on, and from an async environment that's not asyncio.
# For this reason, we use an event to coordinate with the server
# instead of calling shutdown()/startup() directly, and should not make
# any asyncio-specific operations.
self.started = False
self.restart_requested.set()
while not self.started:
await asyncio.sleep(0.5)
await sleep(0.2)
async def watch_restarts(self):
while True:
@ -277,22 +276,6 @@ class TestServer(Server):
await self.startup()
@pytest.fixture
def restart():
"""Restart the running server from an async test function.
This fixture deals with possible differences between the environment of the
test function and that of the server.
"""
asyncio_backend = AsyncioBackend()
async def restart(server):
backend = lookup_backend()
await backend.run_in_threadpool(asyncio_backend.run, server.restart)
return restart
def serve_in_thread(server: Server):
thread = threading.Thread(target=server.run)
thread.start()

View File

@ -167,7 +167,7 @@ async def test_premature_response_close(server):
@pytest.mark.usefixtures("async_environment")
async def test_keepalive_connection_closed_by_server_is_reestablished(server, restart):
async def test_keepalive_connection_closed_by_server_is_reestablished(server):
"""
Upon keep-alive connection closed by remote a new connection
should be reestablished.
@ -177,7 +177,7 @@ async def test_keepalive_connection_closed_by_server_is_reestablished(server, re
await response.aread()
# Shutdown the server to close the keep-alive connection
await restart(server)
await server.restart()
response = await http.request("GET", server.url)
await response.aread()
@ -186,9 +186,7 @@ async def test_keepalive_connection_closed_by_server_is_reestablished(server, re
@pytest.mark.usefixtures("async_environment")
async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
server, restart
):
async def test_keepalive_http2_connection_closed_by_server_is_reestablished(server):
"""
Upon keep-alive connection closed by remote a new connection
should be reestablished.
@ -198,7 +196,7 @@ async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
await response.aread()
# Shutdown the server to close the keep-alive connection
await restart(server)
await server.restart()
response = await http.request("GET", server.url)
await response.aread()
@ -207,7 +205,7 @@ async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
@pytest.mark.usefixtures("async_environment")
async def test_connection_closed_free_semaphore_on_acquire(server, restart):
async def test_connection_closed_free_semaphore_on_acquire(server):
"""
Verify that max_connections semaphore is released
properly on a disconnected connection.
@ -217,7 +215,7 @@ async def test_connection_closed_free_semaphore_on_acquire(server, restart):
await response.aread()
# Close the connection so we're forced to recycle it
await restart(server)
await server.restart()
response = await http.request("GET", server.url)
assert response.status_code == 200

View File

@ -41,11 +41,11 @@ async def test_start_tls_on_tcp_socket_stream(https_server):
try:
assert stream.is_connection_dropped() is False
assert get_cipher(backend, stream) is None
assert get_cipher(stream) is None
stream = await stream.start_tls(https_server.url.host, ctx, timeout)
assert stream.is_connection_dropped() is False
assert get_cipher(backend, stream) is not None
assert get_cipher(stream) is not None
await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
@ -68,11 +68,11 @@ async def test_start_tls_on_uds_socket_stream(https_uds_server):
try:
assert stream.is_connection_dropped() is False
assert get_cipher(backend, stream) is None
assert get_cipher(stream) is None
stream = await stream.start_tls(https_uds_server.url.host, ctx, timeout)
assert stream.is_connection_dropped() is False
assert get_cipher(backend, stream) is not None
assert get_cipher(stream) is not None
await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
@ -96,7 +96,7 @@ async def test_concurrent_read(server):
try:
await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
await run_concurrently(
backend, lambda: stream.read(10, timeout), lambda: stream.read(10, timeout)
lambda: stream.read(10, timeout), lambda: stream.read(10, timeout)
)
finally:
await stream.close()