Enforce that sync client uses asyncio-based backend (#232)

This commit is contained in:
Florimond Manca 2019-08-18 16:41:37 +02:00 committed by Seth Michael Larson
parent caba78568e
commit 9c7d23ce29
2 changed files with 34 additions and 0 deletions

View File

@ -70,6 +70,8 @@ class BaseClient:
if backend is None:
backend = AsyncioBackend()
self.check_concurrency_backend(backend)
if app is not None:
param_count = len(inspect.signature(app).parameters)
assert param_count in (2, 3)
@ -108,6 +110,9 @@ class BaseClient:
self.concurrency_backend = backend
self.trust_env = True if trust_env is None else trust_env
def check_concurrency_backend(self, backend: ConcurrencyBackend) -> None:
pass # pragma: no cover
def merge_url(self, url: URLTypes) -> URL:
url = self.base_url.join(relative_url=url)
if url.scheme == "http" and hstspreload.in_hsts_preload(url.host):
@ -623,6 +628,19 @@ class AsyncClient(BaseClient):
class Client(BaseClient):
def check_concurrency_backend(self, backend: ConcurrencyBackend) -> None:
# Iterating over response content allocates an async environment on each step.
# This is relatively cheap on asyncio, but cannot be guaranteed for all
# concurrency backends.
# The sync client performs I/O on its own, so it doesn't need to support
# arbitrary concurrency backends.
# Therefore, we kept the `backend` parameter (for testing/mocking), but enforce
# that the concurrency backend derives from the asyncio one.
if not isinstance(backend, AsyncioBackend):
raise ValueError(
"'Client' only supports asyncio-based concurrency backends"
)
def _async_request_data(
self, data: RequestData = None
) -> typing.Optional[AsyncRequestData]:

View File

@ -158,3 +158,19 @@ def test_merge_url():
assert url.scheme == "https"
assert url.is_ssl
class DerivedFromAsyncioBackend(httpx.AsyncioBackend):
pass
class AnyBackend:
pass
def test_client_backend_must_be_asyncio_based():
httpx.Client(backend=httpx.AsyncioBackend())
httpx.Client(backend=DerivedFromAsyncioBackend())
with pytest.raises(ValueError):
httpx.Client(backend=AnyBackend())