httpx/tests/concurrency.py
Florimond Manca ab41a5d5c3
Refactor tests in the light of backend auto-detection (#615)
* Refactor tests in the light of backend auto-detection

* Test passing explicit backend separately

* Drop 'backend=backend'

* Fix usage of asyncio.run() on 3.6
2019-12-07 15:17:35 +01:00

78 lines
2.1 KiB
Python

"""
This module contains concurrency utilities that are only used in tests, thus not
required as part of the ConcurrencyBackend API.
"""
import asyncio
import functools
import typing
import trio
from httpx.concurrency.asyncio import AsyncioBackend
from httpx.concurrency.auto import AutoBackend
from httpx.concurrency.trio import TrioBackend
@functools.singledispatch
async def sleep(backend, seconds: int):
    """
    Sleep for ``seconds`` using the primitive native to ``backend``.

    Dispatches on the type of ``backend``; per-backend implementations are
    registered with ``sleep.register``.

    Raises:
        NotImplementedError: no implementation is registered for
            ``type(backend)``. The message names the offending type so a
            misconfigured test points straight at the cause.
    """
    raise NotImplementedError(  # pragma: no cover
        f"sleep() is not implemented for {type(backend).__name__}"
    )
@sleep.register(AutoBackend)
async def _sleep_auto(backend, seconds: int):
    """Resolve the auto-detected backend and forward the sleep to it."""
    concrete = backend.backend
    result = await sleep(concrete, seconds=seconds)
    return result
@sleep.register(AsyncioBackend)
async def _sleep_asyncio(backend, seconds: int):
    """Sleep via ``asyncio.sleep``; ``backend`` only selects this implementation."""
    await asyncio.sleep(delay=seconds)
@sleep.register(TrioBackend)
async def _sleep_trio(backend, seconds: int):
    """Sleep via ``trio.sleep``; ``backend`` only selects this implementation."""
    await trio.sleep(seconds=seconds)
@functools.singledispatch
async def run_concurrently(backend, *coroutines: typing.Callable[[], typing.Awaitable]):
    """
    Run the zero-argument coroutine functions in ``coroutines`` concurrently
    using the primitives native to ``backend``.

    Dispatches on the type of ``backend``; per-backend implementations are
    registered with ``run_concurrently.register``.

    Raises:
        NotImplementedError: no implementation is registered for
            ``type(backend)``. The message names the offending type.
    """
    raise NotImplementedError(  # pragma: no cover
        f"run_concurrently() is not implemented for {type(backend).__name__}"
    )
@run_concurrently.register(AutoBackend)
async def _run_concurrently_auto(backend, *coroutines):
    """Resolve the auto-detected backend and run ``coroutines`` on it."""
    concrete = backend.backend
    await run_concurrently(concrete, *coroutines)
@run_concurrently.register(AsyncioBackend)
async def _run_concurrently_asyncio(backend, *coroutines):
    """
    Run each zero-argument coroutine function in ``coroutines`` concurrently
    on the asyncio event loop and wait for all of them to finish.

    If any of them raises (or this call itself is interrupted), the siblings
    that are still running are cancelled before the exception propagates.
    This matches the trio nursery implementation instead of leaving orphaned
    tasks running past the end of the test.
    """
    tasks = []
    try:
        for coroutine in coroutines:
            # ``ensure_future`` rather than ``asyncio.create_task`` keeps
            # Python 3.6 compatibility; keeping the task handles lets us
            # cancel them below.
            tasks.append(asyncio.ensure_future(coroutine()))
        await asyncio.gather(*tasks)
    except BaseException:
        # ``gather`` propagates the first error but does not cancel the other
        # awaitables, so do it explicitly. ``cancel()`` on a finished task
        # is a harmless no-op.
        for task in tasks:
            task.cancel()
        raise
@run_concurrently.register(TrioBackend)
async def _run_concurrently_trio(backend, *coroutines):
    """
    Start every zero-argument coroutine function in ``coroutines`` as a child
    task of a single trio nursery.

    Leaving the ``async with`` block waits for the children (trio nursery
    semantics).
    """
    async with trio.open_nursery() as nursery:
        spawn = nursery.start_soon
        for async_fn in coroutines:
            spawn(async_fn)
@functools.singledispatch
def get_cipher(backend, stream):
    """
    Return the cipher in use on ``stream``, looked up in the way that
    ``backend``'s stream type exposes it, or ``None`` when it is unavailable.

    Dispatches on the type of ``backend``; per-backend implementations are
    registered with ``get_cipher.register``.

    Raises:
        NotImplementedError: no implementation is registered for
            ``type(backend)``. The message names the offending type.
    """
    raise NotImplementedError(  # pragma: no cover
        f"get_cipher() is not implemented for {type(backend).__name__}"
    )
@get_cipher.register(AutoBackend)
def _get_cipher_auto(backend, stream):
    """Resolve the auto-detected backend and look up ``stream``'s cipher with it."""
    concrete = backend.backend
    return get_cipher(concrete, stream)
@get_cipher.register(AsyncioBackend)
def _get_cipher_asyncio(backend, stream):
    """
    Return the ``"cipher"`` extra info reported by ``stream``'s asyncio
    writer, or ``None`` when the writer does not report one.
    """
    writer = stream.stream_writer
    return writer.get_extra_info("cipher", default=None)
@get_cipher.register(TrioBackend)
def get_trio_cipher(backend, stream):
    """
    Return ``stream.stream.cipher()`` when the wrapped trio stream is a
    ``trio.SSLStream``; return ``None`` for any other (e.g. plain TCP) stream.
    """
    inner = stream.stream
    if not isinstance(inner, trio.SSLStream):
        return None
    return inner.cipher()