restore original ASGITransport class code and split streaming-enabled version into separate ASGIStreamingTransport class

This commit is contained in:
Jean Hominal 2025-12-28 22:38:21 +01:00
parent 485d2f341f
commit 6a240f4690
4 changed files with 199 additions and 66 deletions

View File

@ -30,6 +30,7 @@ __all__ = [
"__description__",
"__title__",
"__version__",
"ASGIStreamingTransport",
"ASGITransport",
"AsyncBaseTransport",
"AsyncByteStream",

View File

@ -5,6 +5,7 @@ from .mock import *
from .wsgi import *
__all__ = [
"ASGIStreamingTransport",
"ASGITransport",
"AsyncBaseTransport",
"BaseTransport",

View File

@ -33,7 +33,7 @@ _ASGIApp = typing.Callable[
[typing.MutableMapping[str, typing.Any], _Receive, _Send], typing.Awaitable[None]
]
__all__ = ["ASGITransport"]
__all__ = ["ASGITransport", "ASGIStreamingTransport"]
def is_running_trio() -> bool:
@ -98,6 +98,141 @@ def get_end_of_stream_error_type() -> type[anyio.EndOfStream | trio.EndOfChannel
class ASGIResponseStream(AsyncByteStream):
def __init__(self, body: list[bytes]) -> None:
self._body = body
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
yield b"".join(self._body)
class ASGITransport(AsyncBaseTransport):
"""
A custom AsyncTransport that handles sending requests directly to an ASGI app.
```python
transport = httpx.ASGITransport(
app=app,
root_path="/submount",
client=("1.2.3.4", 123)
)
client = httpx.AsyncClient(transport=transport)
```
Arguments:
* `app` - The ASGI application.
* `raise_app_exceptions` - Boolean indicating if exceptions in the application
should be raised. Default to `True`. Can be set to `False` for use cases
such as testing the content of a client 500 response.
* `root_path` - The root path on which the ASGI application should be mounted.
* `client` - A two-tuple indicating the client IP and port of incoming requests.
```
"""
def __init__(
self,
app: _ASGIApp,
raise_app_exceptions: bool = True,
root_path: str = "",
client: tuple[str, int] = ("127.0.0.1", 123),
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
self.client = client
async def handle_async_request(
self,
request: Request,
) -> Response:
assert isinstance(request.stream, AsyncByteStream)
# ASGI scope.
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": request.method,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"scheme": request.url.scheme,
"path": request.url.path,
"raw_path": request.url.raw_path.split(b"?")[0],
"query_string": request.url.query,
"server": (request.url.host, request.url.port),
"client": self.client,
"root_path": self.root_path,
}
# Request.
request_body_chunks = request.stream.__aiter__()
request_complete = False
# Response.
status_code = None
response_headers = None
body_parts = []
response_started = False
response_complete = create_event()
# ASGI callables.
async def receive() -> dict[str, typing.Any]:
nonlocal request_complete
if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}
try:
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}
async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
nonlocal status_code, response_headers, response_started
if message["type"] == "http.response.start":
assert not response_started
status_code = message["status"]
response_headers = message.get("headers", [])
response_started = True
elif message["type"] == "http.response.body":
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)
if body and request.method != "HEAD":
body_parts.append(body)
if not more_body:
response_complete.set()
try:
await self.app(scope, receive, send)
except Exception: # noqa: PIE-786
if self.raise_app_exceptions:
raise
response_complete.set()
if status_code is None:
status_code = 500
if response_headers is None:
response_headers = {}
assert response_complete.is_set()
assert status_code is not None
assert response_headers is not None
stream = ASGIResponseStream(body_parts)
return Response(status_code, headers=response_headers, stream=stream)
class ASGIStreamingResponseStream(AsyncByteStream):
def __init__(
self,
ignore_body: bool,
@ -124,18 +259,19 @@ class ASGIResponseStream(AsyncByteStream):
await self._asgi_generator.aclose()
class ASGITransport(AsyncBaseTransport):
class ASGIStreamingTransport(AsyncBaseTransport):
"""
A custom AsyncTransport that handles sending requests directly to an ASGI app.
An equivalent of ASGITransport that operates by running app in a sub-task and
streaming response events as soon as they arrive.
```python
transport = httpx.ASGITransport(
app=app,
root_path="/submount",
client=("1.2.3.4", 123)
)
client = httpx.AsyncClient(transport=transport)
```
It is used in the same way, with the same arguments having the same signification,
as ASGITransport.
The main observable differences between the two implementations will be as follows:
* As the application callable is invoked in a sub-task, any context variables that
are set by the app will not propagate to the caller;
* The streaming mode of operation means that a response will generally be returned
to the AsyncClient caller before the application has fully run;
Arguments:
@ -145,9 +281,6 @@ class ASGITransport(AsyncBaseTransport):
such as testing the content of a client 500 response.
* `root_path` - The root path on which the ASGI application should be mounted.
* `client` - A two-tuple indicating the client IP and port of incoming requests.
* `streaming` - Set to `True` to enable streaming of response content. Default to
`False`, as activating this feature means that the ASGI `app` will run in a
sub-task, which has observable side effects for context variables.
```
"""
@ -157,14 +290,11 @@ class ASGITransport(AsyncBaseTransport):
raise_app_exceptions: bool = True,
root_path: str = "",
client: tuple[str, int] = ("127.0.0.1", 123),
*,
streaming: bool = False,
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
self.client = client
self.streaming = streaming
async def handle_async_request(
self,
@ -177,7 +307,7 @@ class ASGITransport(AsyncBaseTransport):
return Response(
status_code=message["status"],
headers=message.get("headers", []),
stream=ASGIResponseStream(
stream=ASGIStreamingResponseStream(
ignore_body=request.method == "HEAD",
asgi_generator=asgi_generator,
),
@ -214,9 +344,8 @@ class ASGITransport(AsyncBaseTransport):
response_complete = create_event()
# ASGI response messages stream
stream_size = 0 if self.streaming else float("inf")
response_message_send_stream, response_message_recv_stream = (
create_memory_object_stream(stream_size)
create_memory_object_stream(0)
)
# ASGI app exception
@ -256,11 +385,8 @@ class ASGITransport(AsyncBaseTransport):
async with contextlib.AsyncExitStack() as exit_stack:
exit_stack.callback(response_complete.set)
if self.streaming:
task_group = await exit_stack.enter_async_context(create_task_group())
task_group.start_soon(run_app)
else:
await run_app()
task_group = await exit_stack.enter_async_context(create_task_group())
task_group.start_soon(run_app)
async with response_message_recv_stream:
try:

View File

@ -14,9 +14,17 @@ test_asgi_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar(
)
@pytest.fixture(params=[False, True], ids=["no_streaming", "with_streaming"])
def streaming(request):
return request.param
@pytest.fixture(
params=[httpx.ASGITransport, httpx.ASGIStreamingTransport],
ids=["ASGITransport", "ASGIStreamingTransport"],
)
def asgi_transport_class(
request: pytest.FixtureRequest,
) -> type[typing.Union[httpx.ASGITransport, httpx.ASGIStreamingTransport]]:
return typing.cast(
type[typing.Union[httpx.ASGITransport, httpx.ASGIStreamingTransport]],
request.param,
)
def run_in_task_group(
@ -106,8 +114,8 @@ async def raise_exc_after_response(scope, receive, send):
@pytest.mark.anyio
async def test_asgi_transport(streaming):
async with httpx.ASGITransport(app=hello_world, streaming=streaming) as transport:
async def test_asgi_transport(asgi_transport_class):
async with asgi_transport_class(app=hello_world) as transport:
request = httpx.Request("GET", "http://www.example.com/")
response = await transport.handle_async_request(request)
await response.aread()
@ -116,8 +124,8 @@ async def test_asgi_transport(streaming):
@pytest.mark.anyio
async def test_asgi_transport_no_body(streaming):
async with httpx.ASGITransport(app=echo_body, streaming=streaming) as transport:
async def test_asgi_transport_no_body(asgi_transport_class):
async with asgi_transport_class(app=echo_body) as transport:
request = httpx.Request("GET", "http://www.example.com/")
response = await transport.handle_async_request(request)
await response.aread()
@ -126,8 +134,8 @@ async def test_asgi_transport_no_body(streaming):
@pytest.mark.anyio
async def test_asgi(streaming):
transport = httpx.ASGITransport(app=hello_world, streaming=streaming)
async def test_asgi(asgi_transport_class):
transport = asgi_transport_class(app=hello_world)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")
@ -136,8 +144,8 @@ async def test_asgi(streaming):
@pytest.mark.anyio
async def test_asgi_urlencoded_path(streaming):
transport = httpx.ASGITransport(app=echo_path, streaming=streaming)
async def test_asgi_urlencoded_path(asgi_transport_class):
transport = asgi_transport_class(app=echo_path)
async with httpx.AsyncClient(transport=transport) as client:
url = httpx.URL("http://www.example.org/").copy_with(path="/user@example.org")
response = await client.get(url)
@ -147,8 +155,8 @@ async def test_asgi_urlencoded_path(streaming):
@pytest.mark.anyio
async def test_asgi_raw_path(streaming):
transport = httpx.ASGITransport(app=echo_raw_path, streaming=streaming)
async def test_asgi_raw_path(asgi_transport_class):
transport = asgi_transport_class(app=echo_raw_path)
async with httpx.AsyncClient(transport=transport) as client:
url = httpx.URL("http://www.example.org/").copy_with(path="/user@example.org")
response = await client.get(url)
@ -158,11 +166,13 @@ async def test_asgi_raw_path(streaming):
@pytest.mark.anyio
async def test_asgi_raw_path_should_not_include_querystring_portion(streaming):
async def test_asgi_raw_path_should_not_include_querystring_portion(
asgi_transport_class,
):
"""
See https://github.com/encode/httpx/issues/2810
"""
transport = httpx.ASGITransport(app=echo_raw_path, streaming=streaming)
transport = asgi_transport_class(app=echo_raw_path)
async with httpx.AsyncClient(transport=transport) as client:
url = httpx.URL("http://www.example.org/path?query")
response = await client.get(url)
@ -172,8 +182,8 @@ async def test_asgi_raw_path_should_not_include_querystring_portion(streaming):
@pytest.mark.anyio
async def test_asgi_upload(streaming):
transport = httpx.ASGITransport(app=echo_body, streaming=streaming)
async def test_asgi_upload(asgi_transport_class):
transport = asgi_transport_class(app=echo_body)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.post("http://www.example.org/", content=b"example")
@ -182,8 +192,8 @@ async def test_asgi_upload(streaming):
@pytest.mark.anyio
async def test_asgi_headers(streaming):
transport = httpx.ASGITransport(app=echo_headers, streaming=streaming)
async def test_asgi_headers(asgi_transport_class):
transport = asgi_transport_class(app=echo_headers)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")
@ -200,33 +210,31 @@ async def test_asgi_headers(streaming):
@pytest.mark.anyio
async def test_asgi_exc(streaming):
transport = httpx.ASGITransport(app=raise_exc, streaming=streaming)
async def test_asgi_exc(asgi_transport_class):
transport = asgi_transport_class(app=raise_exc)
async with httpx.AsyncClient(transport=transport) as client:
with pytest.raises(RuntimeError):
await client.get("http://www.example.org/")
@pytest.mark.anyio
async def test_asgi_exc_after_response_start(streaming):
transport = httpx.ASGITransport(
app=raise_exc_after_response_start, streaming=streaming
)
async def test_asgi_exc_after_response_start(asgi_transport_class):
transport = asgi_transport_class(app=raise_exc_after_response_start)
async with httpx.AsyncClient(transport=transport) as client:
with pytest.raises(RuntimeError):
await client.get("http://www.example.org/")
@pytest.mark.anyio
async def test_asgi_exc_after_response(streaming):
transport = httpx.ASGITransport(app=raise_exc_after_response, streaming=streaming)
async def test_asgi_exc_after_response(asgi_transport_class):
transport = asgi_transport_class(app=raise_exc_after_response)
async with httpx.AsyncClient(transport=transport) as client:
with pytest.raises(RuntimeError):
await client.get("http://www.example.org/")
@pytest.mark.anyio
async def test_asgi_disconnect_after_response_complete(streaming):
async def test_asgi_disconnect_after_response_complete(asgi_transport_class):
disconnect = False
async def read_body(scope, receive, send):
@ -252,7 +260,7 @@ async def test_asgi_disconnect_after_response_complete(streaming):
message = await receive()
disconnect = message.get("type") == "http.disconnect"
transport = httpx.ASGITransport(app=read_body, streaming=streaming)
transport = asgi_transport_class(app=read_body)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.post("http://www.example.org/", content=b"example")
@ -261,10 +269,8 @@ async def test_asgi_disconnect_after_response_complete(streaming):
@pytest.mark.anyio
async def test_asgi_exc_no_raise(streaming):
transport = httpx.ASGITransport(
app=raise_exc, raise_app_exceptions=False, streaming=streaming
)
async def test_asgi_exc_no_raise(asgi_transport_class):
transport = asgi_transport_class(app=raise_exc, raise_app_exceptions=False)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")
@ -272,11 +278,10 @@ async def test_asgi_exc_no_raise(streaming):
@pytest.mark.anyio
async def test_asgi_exc_no_raise_after_response_start(streaming):
transport = httpx.ASGITransport(
async def test_asgi_exc_no_raise_after_response_start(asgi_transport_class):
transport = asgi_transport_class(
app=raise_exc_after_response_start,
raise_app_exceptions=False,
streaming=streaming,
)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")
@ -285,9 +290,9 @@ async def test_asgi_exc_no_raise_after_response_start(streaming):
@pytest.mark.anyio
async def test_asgi_exc_no_raise_after_response(streaming):
async def test_asgi_exc_no_raise_after_response(asgi_transport_class):
transport = httpx.ASGITransport(
app=raise_exc_after_response, raise_app_exceptions=False, streaming=streaming
app=raise_exc_after_response, raise_app_exceptions=False
)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")
@ -313,7 +318,7 @@ async def test_asgi_app_runs_in_same_context_as_caller():
await send({"type": "http.response.body", "body": output})
transport = httpx.ASGITransport(
app=set_contextvar_in_app, raise_app_exceptions=False, streaming=False
app=set_contextvar_in_app, raise_app_exceptions=False
)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")
@ -345,7 +350,7 @@ async def test_asgi_stream_returns_before_waiting_for_body(send_in_sub_task):
send_response_body_after_event
)
transport = httpx.ASGITransport(app=send_response_body_after_event, streaming=True)
transport = httpx.ASGIStreamingTransport(app=send_response_body_after_event)
async with httpx.AsyncClient(transport=transport) as client:
with anyio.fail_after(0.1):
async with client.stream("GET", "http://www.example.org/") as response:
@ -384,7 +389,7 @@ async def test_asgi_stream_allows_iterative_streaming(send_in_sub_task):
send_response_body_after_event
)
transport = httpx.ASGITransport(app=send_response_body_after_event, streaming=True)
transport = httpx.ASGIStreamingTransport(app=send_response_body_after_event)
async with httpx.AsyncClient(transport=transport) as client:
with anyio.fail_after(0.1):
async with client.stream("GET", "http://www.example.org/") as response: