Refactor tests in the light of backend auto-detection (#615)
* Refactor tests in the light of backend auto-detection * Test passing explicit backend separately * Drop 'backend=backend' * Fix usage of asyncio.run() on 3.6
This commit is contained in:
parent
f57bb2f142
commit
ab41a5d5c3
@ -7,7 +7,7 @@ import httpx
|
||||
|
||||
async def test_get(server, backend):
|
||||
url = server.url
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
response = await client.get(url)
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world!"
|
||||
@ -20,7 +20,7 @@ async def test_get(server, backend):
|
||||
async def test_build_request(server, backend):
|
||||
url = server.url.copy_with(path="/echo_headers")
|
||||
headers = {"Custom-header": "value"}
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
request = client.build_request("GET", url)
|
||||
request.headers.update(headers)
|
||||
response = await client.send(request)
|
||||
@ -48,20 +48,20 @@ async def test_get_no_backend(server):
|
||||
|
||||
async def test_post(server, backend):
|
||||
url = server.url
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
response = await client.post(url, data=b"Hello, world!")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
async def test_post_json(server, backend):
|
||||
url = server.url
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
response = await client.post(url, json={"text": "Hello, world!"})
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
async def test_stream_response(server, backend):
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
async with client.stream("GET", server.url) as response:
|
||||
body = await response.read()
|
||||
|
||||
@ -71,7 +71,7 @@ async def test_stream_response(server, backend):
|
||||
|
||||
|
||||
async def test_access_content_stream_response(server, backend):
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
async with client.stream("GET", server.url) as response:
|
||||
pass
|
||||
|
||||
@ -85,13 +85,13 @@ async def test_stream_request(server, backend):
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
response = await client.request("POST", server.url, data=hello_world())
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
async def test_raise_for_status(server, backend):
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
for status_code in (200, 400, 404, 500, 505):
|
||||
response = await client.request(
|
||||
"GET", server.url.copy_with(path=f"/status/{status_code}")
|
||||
@ -106,33 +106,33 @@ async def test_raise_for_status(server, backend):
|
||||
|
||||
|
||||
async def test_options(server, backend):
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
response = await client.options(server.url)
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world!"
|
||||
|
||||
|
||||
async def test_head(server, backend):
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
response = await client.head(server.url)
|
||||
assert response.status_code == 200
|
||||
assert response.text == ""
|
||||
|
||||
|
||||
async def test_put(server, backend):
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
response = await client.put(server.url, data=b"Hello, world!")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
async def test_patch(server, backend):
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
response = await client.patch(server.url, data=b"Hello, world!")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
async def test_delete(server, backend):
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
response = await client.delete(server.url)
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world!"
|
||||
@ -142,7 +142,7 @@ async def test_100_continue(server, backend):
|
||||
headers = {"Expect": "100-continue"}
|
||||
data = b"Echo request body"
|
||||
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
async with httpx.Client() as client:
|
||||
response = await client.post(
|
||||
server.url.copy_with(path="/echo_body"), headers=headers, data=data
|
||||
)
|
||||
@ -155,8 +155,22 @@ async def test_uds(uds_server, backend):
|
||||
url = uds_server.url
|
||||
uds = uds_server.config.uds
|
||||
assert uds is not None
|
||||
async with httpx.Client(backend=backend, uds=uds) as client:
|
||||
async with httpx.Client(uds=uds) as client:
|
||||
response = await client.get(url)
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world!"
|
||||
assert response.encoding == "iso-8859-1"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"backend",
|
||||
[
|
||||
pytest.param("asyncio", marks=pytest.mark.asyncio),
|
||||
pytest.param("trio", marks=pytest.mark.trio),
|
||||
],
|
||||
)
|
||||
async def test_explicit_backend(server, backend):
|
||||
async with httpx.Client(backend=backend) as client:
|
||||
response = await client.get(server.url)
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world!"
|
||||
|
||||
@ -105,7 +105,7 @@ class MockDispatch(Dispatcher):
|
||||
|
||||
|
||||
async def test_no_redirect(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
url = "https://example.com/no_redirect"
|
||||
response = await client.get(url)
|
||||
assert response.status_code == 200
|
||||
@ -114,7 +114,7 @@ async def test_no_redirect(backend):
|
||||
|
||||
|
||||
async def test_redirect_301(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
response = await client.post("https://example.org/redirect_301")
|
||||
assert response.status_code == codes.OK
|
||||
assert response.url == URL("https://example.org/")
|
||||
@ -122,7 +122,7 @@ async def test_redirect_301(backend):
|
||||
|
||||
|
||||
async def test_redirect_302(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
response = await client.post("https://example.org/redirect_302")
|
||||
assert response.status_code == codes.OK
|
||||
assert response.url == URL("https://example.org/")
|
||||
@ -130,7 +130,7 @@ async def test_redirect_302(backend):
|
||||
|
||||
|
||||
async def test_redirect_303(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
response = await client.get("https://example.org/redirect_303")
|
||||
assert response.status_code == codes.OK
|
||||
assert response.url == URL("https://example.org/")
|
||||
@ -138,7 +138,7 @@ async def test_redirect_303(backend):
|
||||
|
||||
|
||||
async def test_disallow_redirects(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
response = await client.post(
|
||||
"https://example.org/redirect_303", allow_redirects=False
|
||||
)
|
||||
@ -155,7 +155,7 @@ async def test_disallow_redirects(backend):
|
||||
|
||||
|
||||
async def test_relative_redirect(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
response = await client.get("https://example.org/relative_redirect")
|
||||
assert response.status_code == codes.OK
|
||||
assert response.url == URL("https://example.org/")
|
||||
@ -163,7 +163,7 @@ async def test_relative_redirect(backend):
|
||||
|
||||
|
||||
async def test_no_scheme_redirect(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
response = await client.get("https://example.org/no_scheme_redirect")
|
||||
assert response.status_code == codes.OK
|
||||
assert response.url == URL("https://example.org/")
|
||||
@ -171,7 +171,7 @@ async def test_no_scheme_redirect(backend):
|
||||
|
||||
|
||||
async def test_fragment_redirect(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
response = await client.get("https://example.org/relative_redirect#fragment")
|
||||
assert response.status_code == codes.OK
|
||||
assert response.url == URL("https://example.org/#fragment")
|
||||
@ -179,7 +179,7 @@ async def test_fragment_redirect(backend):
|
||||
|
||||
|
||||
async def test_multiple_redirects(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
response = await client.get("https://example.org/multiple_redirects?count=20")
|
||||
assert response.status_code == codes.OK
|
||||
assert response.url == URL("https://example.org/multiple_redirects")
|
||||
@ -195,13 +195,13 @@ async def test_multiple_redirects(backend):
|
||||
|
||||
|
||||
async def test_too_many_redirects(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
with pytest.raises(TooManyRedirects):
|
||||
await client.get("https://example.org/multiple_redirects?count=21")
|
||||
|
||||
|
||||
async def test_too_many_redirects_calling_next(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
url = "https://example.org/multiple_redirects?count=21"
|
||||
response = await client.get(url, allow_redirects=False)
|
||||
with pytest.raises(TooManyRedirects):
|
||||
@ -210,13 +210,13 @@ async def test_too_many_redirects_calling_next(backend):
|
||||
|
||||
|
||||
async def test_redirect_loop(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
with pytest.raises(RedirectLoop):
|
||||
await client.get("https://example.org/redirect_loop")
|
||||
|
||||
|
||||
async def test_redirect_loop_calling_next(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
url = "https://example.org/redirect_loop"
|
||||
response = await client.get(url, allow_redirects=False)
|
||||
with pytest.raises(RedirectLoop):
|
||||
@ -225,7 +225,7 @@ async def test_redirect_loop_calling_next(backend):
|
||||
|
||||
|
||||
async def test_cross_domain_redirect(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
url = "https://example.com/cross_domain"
|
||||
headers = {"Authorization": "abc"}
|
||||
response = await client.get(url, headers=headers)
|
||||
@ -234,7 +234,7 @@ async def test_cross_domain_redirect(backend):
|
||||
|
||||
|
||||
async def test_same_domain_redirect(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
url = "https://example.org/cross_domain"
|
||||
headers = {"Authorization": "abc"}
|
||||
response = await client.get(url, headers=headers)
|
||||
@ -246,7 +246,7 @@ async def test_body_redirect(backend):
|
||||
"""
|
||||
A 308 redirect should preserve the request body.
|
||||
"""
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
url = "https://example.org/redirect_body"
|
||||
data = b"Example request body"
|
||||
response = await client.post(url, data=data)
|
||||
@ -259,7 +259,7 @@ async def test_no_body_redirect(backend):
|
||||
"""
|
||||
A 303 redirect should remove the request body.
|
||||
"""
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
url = "https://example.org/redirect_no_body"
|
||||
data = b"Example request body"
|
||||
response = await client.post(url, data=data)
|
||||
@ -269,7 +269,7 @@ async def test_no_body_redirect(backend):
|
||||
|
||||
|
||||
async def test_cannot_redirect_streaming_body(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
url = "https://example.org/redirect_body"
|
||||
|
||||
async def streaming_body():
|
||||
@ -280,7 +280,7 @@ async def test_cannot_redirect_streaming_body(backend):
|
||||
|
||||
|
||||
async def test_cross_subdomain_redirect(backend):
|
||||
client = Client(dispatch=MockDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockDispatch())
|
||||
url = "https://example.com/cross_subdomain"
|
||||
response = await client.get(url)
|
||||
assert response.url == URL("https://www.example.org/cross_subdomain")
|
||||
@ -326,7 +326,7 @@ class MockCookieDispatch(Dispatcher):
|
||||
|
||||
|
||||
async def test_redirect_cookie_behavior(backend):
|
||||
client = Client(dispatch=MockCookieDispatch(), backend=backend)
|
||||
client = Client(dispatch=MockCookieDispatch())
|
||||
|
||||
# The client is not logged in.
|
||||
response = await client.get("https://example.com/")
|
||||
|
||||
@ -10,6 +10,7 @@ import typing
|
||||
import trio
|
||||
|
||||
from httpx.concurrency.asyncio import AsyncioBackend
|
||||
from httpx.concurrency.auto import AutoBackend
|
||||
from httpx.concurrency.trio import TrioBackend
|
||||
|
||||
|
||||
@ -18,6 +19,11 @@ async def sleep(backend, seconds: int):
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
|
||||
@sleep.register(AutoBackend)
|
||||
async def _sleep_auto(backend, seconds: int):
|
||||
return await sleep(backend.backend, seconds=seconds)
|
||||
|
||||
|
||||
@sleep.register(AsyncioBackend)
|
||||
async def _sleep_asyncio(backend, seconds: int):
|
||||
await asyncio.sleep(seconds)
|
||||
@ -33,6 +39,11 @@ async def run_concurrently(backend, *coroutines: typing.Callable[[], typing.Awai
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
|
||||
@run_concurrently.register(AutoBackend)
|
||||
async def _run_concurrently_auto(backend, *coroutines):
|
||||
await run_concurrently(backend.backend, *coroutines)
|
||||
|
||||
|
||||
@run_concurrently.register(AsyncioBackend)
|
||||
async def _run_concurrently_asyncio(backend, *coroutines):
|
||||
coros = (coroutine() for coroutine in coroutines)
|
||||
@ -44,3 +55,23 @@ async def _run_concurrently_trio(backend, *coroutines):
|
||||
async with trio.open_nursery() as nursery:
|
||||
for coroutine in coroutines:
|
||||
nursery.start_soon(coroutine)
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def get_cipher(backend, stream):
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
|
||||
@get_cipher.register(AutoBackend)
|
||||
def _get_cipher_auto(backend, stream):
|
||||
return get_cipher(backend.backend, stream)
|
||||
|
||||
|
||||
@get_cipher.register(AsyncioBackend)
|
||||
def _get_cipher_asyncio(backend, stream):
|
||||
return stream.stream_writer.get_extra_info("cipher", default=None)
|
||||
|
||||
|
||||
@get_cipher.register(TrioBackend)
|
||||
def get_trio_cipher(backend, stream):
|
||||
return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None
|
||||
|
||||
@ -19,7 +19,7 @@ from uvicorn.main import Server
|
||||
|
||||
from httpx import URL
|
||||
from httpx.concurrency.asyncio import AsyncioBackend
|
||||
from httpx.concurrency.trio import TrioBackend
|
||||
from httpx.concurrency.base import lookup_backend
|
||||
|
||||
ENVIRONMENT_VARIABLES = {
|
||||
"SSL_CERT_FILE",
|
||||
@ -51,13 +51,17 @@ def clean_environ() -> typing.Dict[str, typing.Any]:
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
pytest.param(AsyncioBackend, marks=pytest.mark.asyncio),
|
||||
pytest.param(TrioBackend, marks=pytest.mark.trio),
|
||||
# pytest uses the marks to set up the specified async environment and run
|
||||
# 'async def' test functions. The "auto" backend should then auto-detect
|
||||
# the environment it's running in.
|
||||
# Passing the backend explicitly, e.g. `backend="asyncio"`,
|
||||
# is tested separately.
|
||||
pytest.param("auto", marks=pytest.mark.asyncio),
|
||||
pytest.param("auto", marks=pytest.mark.trio),
|
||||
]
|
||||
)
|
||||
def backend(request):
|
||||
backend_cls = request.param
|
||||
return backend_cls()
|
||||
return request.param
|
||||
|
||||
|
||||
async def app(scope, receive, send):
|
||||
@ -271,9 +275,13 @@ def restart(backend):
|
||||
This fixture deals with possible differences between the environment of the
|
||||
test function and that of the server.
|
||||
"""
|
||||
asyncio_backend = AsyncioBackend()
|
||||
backend_implementation = lookup_backend(backend)
|
||||
|
||||
async def restart(server):
|
||||
await backend.run_in_threadpool(AsyncioBackend().run, server.restart)
|
||||
await backend_implementation.run_in_threadpool(
|
||||
asyncio_backend.run, server.restart
|
||||
)
|
||||
|
||||
return restart
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ async def test_keepalive_connections(server, backend):
|
||||
"""
|
||||
Connections should default to staying in a keep-alive state.
|
||||
"""
|
||||
async with ConnectionPool(backend=backend) as http:
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.read()
|
||||
assert len(http.active_connections) == 0
|
||||
@ -22,7 +22,7 @@ async def test_differing_connection_keys(server, backend):
|
||||
"""
|
||||
Connections to differing connection keys should result in multiple connections.
|
||||
"""
|
||||
async with ConnectionPool(backend=backend) as http:
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.read()
|
||||
assert len(http.active_connections) == 0
|
||||
@ -40,7 +40,7 @@ async def test_soft_limit(server, backend):
|
||||
"""
|
||||
pool_limits = httpx.PoolLimits(soft_limit=1)
|
||||
|
||||
async with ConnectionPool(pool_limits=pool_limits, backend=backend) as http:
|
||||
async with ConnectionPool(pool_limits=pool_limits) as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.read()
|
||||
assert len(http.active_connections) == 0
|
||||
@ -56,7 +56,7 @@ async def test_streaming_response_holds_connection(server, backend):
|
||||
"""
|
||||
A streaming request should hold the connection open until the response is read.
|
||||
"""
|
||||
async with ConnectionPool(backend=backend) as http:
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
assert len(http.active_connections) == 1
|
||||
assert len(http.keepalive_connections) == 0
|
||||
@ -71,7 +71,7 @@ async def test_multiple_concurrent_connections(server, backend):
|
||||
"""
|
||||
Multiple conncurrent requests should open multiple conncurrent connections.
|
||||
"""
|
||||
async with ConnectionPool(backend=backend) as http:
|
||||
async with ConnectionPool() as http:
|
||||
response_a = await http.request("GET", server.url)
|
||||
assert len(http.active_connections) == 1
|
||||
assert len(http.keepalive_connections) == 0
|
||||
@ -94,7 +94,7 @@ async def test_close_connections(server, backend):
|
||||
Using a `Connection: close` header should close the connection.
|
||||
"""
|
||||
headers = [(b"connection", b"close")]
|
||||
async with ConnectionPool(backend=backend) as http:
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url, headers=headers)
|
||||
await response.read()
|
||||
assert len(http.active_connections) == 0
|
||||
@ -105,7 +105,7 @@ async def test_standard_response_close(server, backend):
|
||||
"""
|
||||
A standard close should keep the connection open.
|
||||
"""
|
||||
async with ConnectionPool(backend=backend) as http:
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.read()
|
||||
await response.close()
|
||||
@ -117,7 +117,7 @@ async def test_premature_response_close(server, backend):
|
||||
"""
|
||||
A premature close should close the connection.
|
||||
"""
|
||||
async with ConnectionPool(backend=backend) as http:
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.close()
|
||||
assert len(http.active_connections) == 0
|
||||
@ -131,7 +131,7 @@ async def test_keepalive_connection_closed_by_server_is_reestablished(
|
||||
Upon keep-alive connection closed by remote a new connection
|
||||
should be reestablished.
|
||||
"""
|
||||
async with ConnectionPool(backend=backend) as http:
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.read()
|
||||
|
||||
@ -151,7 +151,7 @@ async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
|
||||
Upon keep-alive connection closed by remote a new connection
|
||||
should be reestablished.
|
||||
"""
|
||||
async with ConnectionPool(backend=backend) as http:
|
||||
async with ConnectionPool() as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.read()
|
||||
|
||||
@ -169,9 +169,7 @@ async def test_connection_closed_free_semaphore_on_acquire(server, restart, back
|
||||
Verify that max_connections semaphore is released
|
||||
properly on a disconnected connection.
|
||||
"""
|
||||
async with ConnectionPool(
|
||||
pool_limits=httpx.PoolLimits(hard_limit=1), backend=backend
|
||||
) as http:
|
||||
async with ConnectionPool(pool_limits=httpx.PoolLimits(hard_limit=1)) as http:
|
||||
response = await http.request("GET", server.url)
|
||||
await response.read()
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from httpx.dispatch.connection import HTTPConnection
|
||||
|
||||
|
||||
async def test_get(server, backend):
|
||||
async with HTTPConnection(origin=server.url, backend=backend) as conn:
|
||||
async with HTTPConnection(origin=server.url) as conn:
|
||||
response = await conn.request("GET", server.url)
|
||||
await response.read()
|
||||
assert response.status_code == 200
|
||||
@ -13,14 +13,14 @@ async def test_get(server, backend):
|
||||
|
||||
|
||||
async def test_post(server, backend):
|
||||
async with HTTPConnection(origin=server.url, backend=backend) as conn:
|
||||
async with HTTPConnection(origin=server.url) as conn:
|
||||
response = await conn.request("GET", server.url, data=b"Hello, world!")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
async def test_premature_close(server, backend):
|
||||
with pytest.raises(httpx.ConnectionClosed):
|
||||
async with HTTPConnection(origin=server.url, backend=backend) as conn:
|
||||
async with HTTPConnection(origin=server.url) as conn:
|
||||
response = await conn.request(
|
||||
"GET", server.url.copy_with(path="/premature_close")
|
||||
)
|
||||
@ -31,9 +31,7 @@ async def test_https_get_with_ssl_defaults(https_server, ca_cert_pem_file, backe
|
||||
"""
|
||||
An HTTPS request, with default SSL configuration set on the client.
|
||||
"""
|
||||
async with HTTPConnection(
|
||||
origin=https_server.url, verify=ca_cert_pem_file, backend=backend
|
||||
) as conn:
|
||||
async with HTTPConnection(origin=https_server.url, verify=ca_cert_pem_file) as conn:
|
||||
response = await conn.request("GET", https_server.url)
|
||||
await response.read()
|
||||
assert response.status_code == 200
|
||||
@ -44,7 +42,7 @@ async def test_https_get_with_sll_overrides(https_server, ca_cert_pem_file, back
|
||||
"""
|
||||
An HTTPS request, with SSL configuration set on the request.
|
||||
"""
|
||||
async with HTTPConnection(origin=https_server.url, backend=backend) as conn:
|
||||
async with HTTPConnection(origin=https_server.url) as conn:
|
||||
response = await conn.request("GET", https_server.url, verify=ca_cert_pem_file)
|
||||
await response.read()
|
||||
assert response.status_code == 200
|
||||
|
||||
@ -104,7 +104,7 @@ async def test_http2_reconnect():
|
||||
|
||||
|
||||
async def test_http2_settings_in_handshake(backend):
|
||||
backend = MockHTTP2Backend(app=app, backend=backend)
|
||||
backend = MockHTTP2Backend(app=app)
|
||||
|
||||
async with Client(backend=backend, http2=True) as client:
|
||||
await client.get("http://example.org")
|
||||
@ -139,7 +139,7 @@ async def test_http2_settings_in_handshake(backend):
|
||||
|
||||
|
||||
async def test_http2_live_request(backend):
|
||||
async with Client(backend=backend, http2=True) as client:
|
||||
async with Client(http2=True) as client:
|
||||
try:
|
||||
resp = await client.get("https://nghttp2.org/httpbin/anything")
|
||||
except TimeoutException:
|
||||
|
||||
@ -19,7 +19,6 @@ async def test_proxy_tunnel_success(backend):
|
||||
b"\r\n",
|
||||
]
|
||||
),
|
||||
backend=backend,
|
||||
)
|
||||
async with httpx.HTTPProxy(
|
||||
proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode="TUNNEL_ONLY",
|
||||
@ -53,7 +52,6 @@ async def test_proxy_tunnel_non_2xx_response(backend, status_code):
|
||||
b"\r\n",
|
||||
]
|
||||
),
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.ProxyError) as e:
|
||||
@ -105,7 +103,6 @@ async def test_proxy_tunnel_start_tls(backend):
|
||||
b"\r\n",
|
||||
]
|
||||
),
|
||||
backend=backend,
|
||||
)
|
||||
async with httpx.HTTPProxy(
|
||||
proxy_url="http://127.0.0.1:8000", backend=raw_io, proxy_mode="TUNNEL_ONLY",
|
||||
@ -155,7 +152,6 @@ async def test_proxy_forwarding(backend, proxy_mode):
|
||||
b"\r\n"
|
||||
]
|
||||
),
|
||||
backend=backend,
|
||||
)
|
||||
async with httpx.HTTPProxy(
|
||||
proxy_url="http://127.0.0.1:8000",
|
||||
|
||||
@ -6,15 +6,14 @@ import h2.connection
|
||||
import h2.events
|
||||
|
||||
from httpx import Request, Timeout
|
||||
from httpx.concurrency.asyncio import AsyncioBackend
|
||||
from httpx.concurrency.base import BaseSocketStream
|
||||
from httpx.concurrency.base import BaseSocketStream, lookup_backend
|
||||
from tests.concurrency import sleep
|
||||
|
||||
|
||||
class MockHTTP2Backend:
|
||||
def __init__(self, app, backend=None):
|
||||
def __init__(self, app, backend="auto"):
|
||||
self.app = app
|
||||
self.backend = AsyncioBackend() if backend is None else backend
|
||||
self.backend = lookup_backend(backend)
|
||||
self.server = None
|
||||
|
||||
async def open_tcp_stream(
|
||||
@ -168,8 +167,8 @@ class MockHTTP2Server(BaseSocketStream):
|
||||
|
||||
|
||||
class MockRawSocketBackend:
|
||||
def __init__(self, data_to_send=b"", backend=None):
|
||||
self.backend = AsyncioBackend() if backend is None else backend
|
||||
def __init__(self, data_to_send=b"", backend="auto"):
|
||||
self.backend = lookup_backend(backend)
|
||||
self.data_to_send = data_to_send
|
||||
self.received_data = []
|
||||
self.stream = MockRawSocketStream(self)
|
||||
|
||||
@ -1,19 +1,14 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from httpx import Timeout
|
||||
from httpx.concurrency.asyncio import AsyncioBackend
|
||||
from httpx.concurrency.base import lookup_backend
|
||||
from httpx.concurrency.trio import TrioBackend
|
||||
from httpx.config import SSLConfig
|
||||
from tests.concurrency import run_concurrently, sleep
|
||||
|
||||
|
||||
def get_asyncio_cipher(stream):
|
||||
return stream.stream_writer.get_extra_info("cipher", default=None)
|
||||
|
||||
|
||||
def get_trio_cipher(stream):
|
||||
return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None
|
||||
from tests.concurrency import get_cipher, run_concurrently, sleep
|
||||
|
||||
|
||||
async def read_response(stream, timeout: Timeout, should_contain: bytes) -> bytes:
|
||||
@ -34,14 +29,8 @@ async def read_response(stream, timeout: Timeout, should_contain: bytes) -> byte
|
||||
return response
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"backend, get_cipher",
|
||||
[
|
||||
pytest.param(AsyncioBackend(), get_asyncio_cipher, marks=pytest.mark.asyncio),
|
||||
pytest.param(TrioBackend(), get_trio_cipher, marks=pytest.mark.trio),
|
||||
],
|
||||
)
|
||||
async def test_start_tls_on_tcp_socket_stream(https_server, backend, get_cipher):
|
||||
async def test_start_tls_on_tcp_socket_stream(https_server, backend):
|
||||
backend = lookup_backend(backend)
|
||||
ctx = SSLConfig().load_ssl_context_no_verify()
|
||||
timeout = Timeout(5)
|
||||
|
||||
@ -51,11 +40,11 @@ async def test_start_tls_on_tcp_socket_stream(https_server, backend, get_cipher)
|
||||
|
||||
try:
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert get_cipher(stream) is None
|
||||
assert get_cipher(backend, stream) is None
|
||||
|
||||
stream = await stream.start_tls(https_server.url.host, ctx, timeout)
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert get_cipher(stream) is not None
|
||||
assert get_cipher(backend, stream) is not None
|
||||
|
||||
await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
|
||||
|
||||
@ -66,14 +55,8 @@ async def test_start_tls_on_tcp_socket_stream(https_server, backend, get_cipher)
|
||||
await stream.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"backend, get_cipher",
|
||||
[
|
||||
pytest.param(AsyncioBackend(), get_asyncio_cipher, marks=pytest.mark.asyncio),
|
||||
pytest.param(TrioBackend(), get_trio_cipher, marks=pytest.mark.trio),
|
||||
],
|
||||
)
|
||||
async def test_start_tls_on_uds_socket_stream(https_uds_server, backend, get_cipher):
|
||||
async def test_start_tls_on_uds_socket_stream(https_uds_server, backend):
|
||||
backend = lookup_backend(backend)
|
||||
ctx = SSLConfig().load_ssl_context_no_verify()
|
||||
timeout = Timeout(5)
|
||||
|
||||
@ -83,11 +66,11 @@ async def test_start_tls_on_uds_socket_stream(https_uds_server, backend, get_cip
|
||||
|
||||
try:
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert get_cipher(stream) is None
|
||||
assert get_cipher(backend, stream) is None
|
||||
|
||||
stream = await stream.start_tls(https_uds_server.url.host, ctx, timeout)
|
||||
assert stream.is_connection_dropped() is False
|
||||
assert get_cipher(stream) is not None
|
||||
assert get_cipher(backend, stream) is not None
|
||||
|
||||
await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
|
||||
|
||||
@ -102,6 +85,7 @@ async def test_concurrent_read(server, backend):
|
||||
"""
|
||||
Regression test for: https://github.com/encode/httpx/issues/527
|
||||
"""
|
||||
backend = lookup_backend(backend)
|
||||
stream = await backend.open_tcp_stream(
|
||||
server.url.host, server.url.port, ssl_context=None, timeout=Timeout(5)
|
||||
)
|
||||
@ -116,6 +100,7 @@ async def test_concurrent_read(server, backend):
|
||||
|
||||
|
||||
async def test_fork(backend):
|
||||
backend = lookup_backend(backend)
|
||||
ok_counter = 0
|
||||
|
||||
async def ok(delay: int) -> None:
|
||||
@ -159,3 +144,23 @@ async def test_fork(backend):
|
||||
# No 'match', since we can't know which will be raised first.
|
||||
with pytest.raises(RuntimeError):
|
||||
await backend.fork(fail, ["My bad", 0], fail, ["Oops", 0])
|
||||
|
||||
|
||||
def test_lookup_backend():
|
||||
assert isinstance(lookup_backend("asyncio"), AsyncioBackend)
|
||||
assert isinstance(lookup_backend("trio"), TrioBackend)
|
||||
assert isinstance(lookup_backend(AsyncioBackend()), AsyncioBackend)
|
||||
|
||||
async def get_backend_from_auto():
|
||||
auto_backend = lookup_backend("auto")
|
||||
return auto_backend.backend
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
backend = loop.run_until_complete(get_backend_from_auto())
|
||||
assert isinstance(backend, AsyncioBackend)
|
||||
|
||||
backend = trio.run(get_backend_from_auto)
|
||||
assert isinstance(backend, TrioBackend)
|
||||
|
||||
with pytest.raises(Exception, match="unknownio"):
|
||||
lookup_backend("unknownio")
|
||||
|
||||
@ -6,7 +6,7 @@ import httpx
|
||||
async def test_read_timeout(server, backend):
|
||||
timeout = httpx.Timeout(read_timeout=1e-6)
|
||||
|
||||
async with httpx.Client(timeout=timeout, backend=backend) as client:
|
||||
async with httpx.Client(timeout=timeout) as client:
|
||||
with pytest.raises(httpx.ReadTimeout):
|
||||
await client.get(server.url.copy_with(path="/slow_response"))
|
||||
|
||||
@ -14,7 +14,7 @@ async def test_read_timeout(server, backend):
|
||||
async def test_write_timeout(server, backend):
|
||||
timeout = httpx.Timeout(write_timeout=1e-6)
|
||||
|
||||
async with httpx.Client(timeout=timeout, backend=backend) as client:
|
||||
async with httpx.Client(timeout=timeout) as client:
|
||||
with pytest.raises(httpx.WriteTimeout):
|
||||
data = b"*" * 1024 * 1024 * 100
|
||||
await client.put(server.url.copy_with(path="/slow_response"), data=data)
|
||||
@ -23,7 +23,7 @@ async def test_write_timeout(server, backend):
|
||||
async def test_connect_timeout(server, backend):
|
||||
timeout = httpx.Timeout(connect_timeout=1e-6)
|
||||
|
||||
async with httpx.Client(timeout=timeout, backend=backend) as client:
|
||||
async with httpx.Client(timeout=timeout) as client:
|
||||
with pytest.raises(httpx.ConnectTimeout):
|
||||
# See https://stackoverflow.com/questions/100841/
|
||||
await client.get("http://10.255.255.1/")
|
||||
@ -33,9 +33,7 @@ async def test_pool_timeout(server, backend):
|
||||
pool_limits = httpx.PoolLimits(hard_limit=1)
|
||||
timeout = httpx.Timeout(pool_timeout=1e-4)
|
||||
|
||||
async with httpx.Client(
|
||||
pool_limits=pool_limits, timeout=timeout, backend=backend
|
||||
) as client:
|
||||
async with httpx.Client(pool_limits=pool_limits, timeout=timeout) as client:
|
||||
async with client.stream("GET", server.url):
|
||||
with pytest.raises(httpx.PoolTimeout):
|
||||
await client.get("http://localhost:8000/")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user