This commit is contained in:
Jean Hominal 2026-02-23 13:58:30 -03:00 committed by GitHub
commit f428f07497
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 422 additions and 29 deletions

View File

@ -30,6 +30,7 @@ __all__ = [
"__description__",
"__title__",
"__version__",
"ASGIStreamingTransport",
"ASGITransport",
"AsyncBaseTransport",
"AsyncByteStream",

View File

@ -5,6 +5,7 @@ from .mock import *
from .wsgi import *
__all__ = [
"ASGIStreamingTransport",
"ASGITransport",
"AsyncBaseTransport",
"BaseTransport",

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import contextlib
import typing
from .._models import Request, Response
@ -9,21 +10,30 @@ from .base import AsyncBaseTransport
if typing.TYPE_CHECKING: # pragma: no cover
import asyncio
import anyio.abc
import anyio.streams.memory
import trio
Event = typing.Union[asyncio.Event, trio.Event]
MessageReceiveStream = typing.Union[
anyio.streams.memory.MemoryObjectReceiveStream["_Message"],
trio.MemoryReceiveChannel["_Message"],
]
MessageSendStream = typing.Union[
anyio.streams.memory.MemoryObjectSendStream["_Message"],
trio.MemorySendChannel["_Message"],
]
TaskGroup = typing.Union[anyio.abc.TaskGroup, trio.Nursery]
_Message = typing.MutableMapping[str, typing.Any]
_Receive = typing.Callable[[], typing.Awaitable[_Message]]
_Send = typing.Callable[
[typing.MutableMapping[str, typing.Any]], typing.Awaitable[None]
]
_Send = typing.Callable[[_Message], typing.Awaitable[None]]
_ASGIApp = typing.Callable[
[typing.MutableMapping[str, typing.Any], _Receive, _Send], typing.Awaitable[None]
]
__all__ = ["ASGITransport"]
__all__ = ["ASGITransport", "ASGIStreamingTransport"]
def is_running_trio() -> bool:
@ -52,6 +62,41 @@ def create_event() -> Event:
return asyncio.Event()
def create_memory_object_stream(
max_buffer_size: float,
) -> tuple[MessageSendStream, MessageReceiveStream]:
if is_running_trio():
import trio
return trio.open_memory_channel(max_buffer_size)
import anyio
return anyio.create_memory_object_stream(max_buffer_size)
def create_task_group() -> typing.AsyncContextManager[TaskGroup]:
if is_running_trio():
import trio
return trio.open_nursery()
import anyio
return anyio.create_task_group()
def get_end_of_stream_error_type() -> type[anyio.EndOfStream | trio.EndOfChannel]:
if is_running_trio():
import trio
return trio.EndOfChannel
import anyio
return anyio.EndOfStream
class ASGIResponseStream(AsyncByteStream):
def __init__(self, body: list[bytes]) -> None:
self._body = body
@ -185,3 +230,171 @@ class ASGITransport(AsyncBaseTransport):
stream = ASGIResponseStream(body_parts)
return Response(status_code, headers=response_headers, stream=stream)
class ASGIStreamingResponseStream(AsyncByteStream):
def __init__(
self,
ignore_body: bool,
asgi_generator: typing.AsyncGenerator[_Message, None],
) -> None:
self._ignore_body = ignore_body
self._asgi_generator = asgi_generator
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
more_body = True
try:
async for message in self._asgi_generator:
assert message["type"] != "http.response.start"
if message["type"] == "http.response.body":
assert more_body
chunk = message.get("body", b"")
more_body = message.get("more_body", False)
if chunk and not self._ignore_body:
yield chunk
finally:
await self.aclose()
async def aclose(self) -> None:
await self._asgi_generator.aclose()
class ASGIStreamingTransport(AsyncBaseTransport):
"""
An equivalent of ASGITransport that operates by running app in a sub-task and
streaming response events as soon as they arrive.
It is used in the same way, with the same arguments having the same signification,
as ASGITransport.
The main observable differences between the two implementations will be as follows:
* As the application callable is invoked in a sub-task, any context variables that
are set by the app will not propagate to the caller;
* The streaming mode of operation means that a response will generally be returned
to the AsyncClient caller before the application has fully run;
Arguments:
* `app` - The ASGI application.
* `raise_app_exceptions` - Boolean indicating if exceptions in the application
should be raised. Default to `True`. Can be set to `False` for use cases
such as testing the content of a client 500 response.
* `root_path` - The root path on which the ASGI application should be mounted.
* `client` - A two-tuple indicating the client IP and port of incoming requests.
```
"""
def __init__(
self,
app: _ASGIApp,
raise_app_exceptions: bool = True,
root_path: str = "",
client: tuple[str, int] = ("127.0.0.1", 123),
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
self.client = client
async def handle_async_request(
self,
request: Request,
) -> Response:
asgi_generator = self._stream_asgi_messages(request)
async for message in asgi_generator:
if message["type"] == "http.response.start":
return Response(
status_code=message["status"],
headers=message.get("headers", []),
stream=ASGIStreamingResponseStream(
ignore_body=request.method == "HEAD",
asgi_generator=asgi_generator,
),
)
else:
return Response(status_code=500, headers=[])
async def _stream_asgi_messages(
self, request: Request
) -> typing.AsyncGenerator[typing.MutableMapping[str, typing.Any]]:
assert isinstance(request.stream, AsyncByteStream)
# ASGI scope.
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": request.method,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"scheme": request.url.scheme,
"path": request.url.path,
"raw_path": request.url.raw_path.split(b"?")[0],
"query_string": request.url.query,
"server": (request.url.host, request.url.port),
"client": self.client,
"root_path": self.root_path,
}
# Request.
request_body_chunks = request.stream.__aiter__()
request_complete = False
# Response.
response_complete = create_event()
# ASGI response messages stream
response_message_send_stream, response_message_recv_stream = (
create_memory_object_stream(0)
)
# ASGI app exception
app_exception: Exception | None = None
# ASGI callables.
async def receive() -> _Message:
nonlocal request_complete
if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}
try:
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}
async def send(message: _Message) -> None:
await response_message_send_stream.send(message)
if message["type"] == "http.response.body" and not message.get(
"more_body", False
):
response_complete.set()
async def run_app() -> None:
nonlocal app_exception
try:
await self.app(scope, receive, send)
except Exception as ex:
app_exception = ex
finally:
await response_message_send_stream.aclose()
async with contextlib.AsyncExitStack() as exit_stack:
exit_stack.callback(response_complete.set)
task_group = await exit_stack.enter_async_context(create_task_group())
task_group.start_soon(run_app)
async with response_message_recv_stream:
try:
while True:
message = await response_message_recv_stream.receive()
yield message
except get_end_of_stream_error_type():
pass
if app_exception is not None and self.raise_app_exceptions:
raise app_exception

View File

@ -1,9 +1,43 @@
import json
from __future__ import annotations
import contextvars
import json
import typing
import anyio
import pytest
import httpx
test_asgi_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar(
"test_asgi_contextvar"
)
@pytest.fixture(
params=[httpx.ASGITransport, httpx.ASGIStreamingTransport],
ids=["ASGITransport", "ASGIStreamingTransport"],
)
def asgi_transport_class(
request: pytest.FixtureRequest,
) -> type[typing.Union[httpx.ASGITransport, httpx.ASGIStreamingTransport]]:
return typing.cast(
type[typing.Union[httpx.ASGITransport, httpx.ASGIStreamingTransport]],
request.param,
)
def run_in_task_group(
app: typing.Callable[..., typing.Awaitable[None]],
) -> typing.Callable[..., typing.Awaitable[None]]:
"""A decorator that runs an ASGI callable in a task group"""
async def wrapped_app(*args):
async with anyio.create_task_group() as task_group:
task_group.start_soon(app, *args)
return wrapped_app
async def hello_world(scope, receive, send):
status = 200
@ -60,6 +94,15 @@ async def raise_exc(scope, receive, send):
raise RuntimeError()
async def raise_exc_after_response_start(scope, receive, send):
status = 200
output = b"Hello, World!"
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]
await send({"type": "http.response.start", "status": status, "headers": headers})
raise RuntimeError()
async def raise_exc_after_response(scope, receive, send):
status = 200
output = b"Hello, World!"
@ -71,8 +114,8 @@ async def raise_exc_after_response(scope, receive, send):
@pytest.mark.anyio
async def test_asgi_transport():
async with httpx.ASGITransport(app=hello_world) as transport:
async def test_asgi_transport(asgi_transport_class):
async with asgi_transport_class(app=hello_world) as transport:
request = httpx.Request("GET", "http://www.example.com/")
response = await transport.handle_async_request(request)
await response.aread()
@ -81,8 +124,8 @@ async def test_asgi_transport():
@pytest.mark.anyio
async def test_asgi_transport_no_body():
async with httpx.ASGITransport(app=echo_body) as transport:
async def test_asgi_transport_no_body(asgi_transport_class):
async with asgi_transport_class(app=echo_body) as transport:
request = httpx.Request("GET", "http://www.example.com/")
response = await transport.handle_async_request(request)
await response.aread()
@ -91,8 +134,8 @@ async def test_asgi_transport_no_body():
@pytest.mark.anyio
async def test_asgi():
transport = httpx.ASGITransport(app=hello_world)
async def test_asgi(asgi_transport_class):
transport = asgi_transport_class(app=hello_world)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")
@ -101,8 +144,8 @@ async def test_asgi():
@pytest.mark.anyio
async def test_asgi_urlencoded_path():
transport = httpx.ASGITransport(app=echo_path)
async def test_asgi_urlencoded_path(asgi_transport_class):
transport = asgi_transport_class(app=echo_path)
async with httpx.AsyncClient(transport=transport) as client:
url = httpx.URL("http://www.example.org/").copy_with(path="/user@example.org")
response = await client.get(url)
@ -112,8 +155,8 @@ async def test_asgi_urlencoded_path():
@pytest.mark.anyio
async def test_asgi_raw_path():
transport = httpx.ASGITransport(app=echo_raw_path)
async def test_asgi_raw_path(asgi_transport_class):
transport = asgi_transport_class(app=echo_raw_path)
async with httpx.AsyncClient(transport=transport) as client:
url = httpx.URL("http://www.example.org/").copy_with(path="/user@example.org")
response = await client.get(url)
@ -123,11 +166,13 @@ async def test_asgi_raw_path():
@pytest.mark.anyio
async def test_asgi_raw_path_should_not_include_querystring_portion():
async def test_asgi_raw_path_should_not_include_querystring_portion(
asgi_transport_class,
):
"""
See https://github.com/encode/httpx/issues/2810
"""
transport = httpx.ASGITransport(app=echo_raw_path)
transport = asgi_transport_class(app=echo_raw_path)
async with httpx.AsyncClient(transport=transport) as client:
url = httpx.URL("http://www.example.org/path?query")
response = await client.get(url)
@ -137,8 +182,8 @@ async def test_asgi_raw_path_should_not_include_querystring_portion():
@pytest.mark.anyio
async def test_asgi_upload():
transport = httpx.ASGITransport(app=echo_body)
async def test_asgi_upload(asgi_transport_class):
transport = asgi_transport_class(app=echo_body)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.post("http://www.example.org/", content=b"example")
@ -147,8 +192,8 @@ async def test_asgi_upload():
@pytest.mark.anyio
async def test_asgi_headers():
transport = httpx.ASGITransport(app=echo_headers)
async def test_asgi_headers(asgi_transport_class):
transport = asgi_transport_class(app=echo_headers)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")
@ -165,23 +210,31 @@ async def test_asgi_headers():
@pytest.mark.anyio
async def test_asgi_exc():
transport = httpx.ASGITransport(app=raise_exc)
async def test_asgi_exc(asgi_transport_class):
transport = asgi_transport_class(app=raise_exc)
async with httpx.AsyncClient(transport=transport) as client:
with pytest.raises(RuntimeError):
await client.get("http://www.example.org/")
@pytest.mark.anyio
async def test_asgi_exc_after_response():
transport = httpx.ASGITransport(app=raise_exc_after_response)
async def test_asgi_exc_after_response_start(asgi_transport_class):
transport = asgi_transport_class(app=raise_exc_after_response_start)
async with httpx.AsyncClient(transport=transport) as client:
with pytest.raises(RuntimeError):
await client.get("http://www.example.org/")
@pytest.mark.anyio
async def test_asgi_disconnect_after_response_complete():
async def test_asgi_exc_after_response(asgi_transport_class):
transport = asgi_transport_class(app=raise_exc_after_response)
async with httpx.AsyncClient(transport=transport) as client:
with pytest.raises(RuntimeError):
await client.get("http://www.example.org/")
@pytest.mark.anyio
async def test_asgi_disconnect_after_response_complete(asgi_transport_class):
disconnect = False
async def read_body(scope, receive, send):
@ -207,7 +260,7 @@ async def test_asgi_disconnect_after_response_complete():
message = await receive()
disconnect = message.get("type") == "http.disconnect"
transport = httpx.ASGITransport(app=read_body)
transport = asgi_transport_class(app=read_body)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.post("http://www.example.org/", content=b"example")
@ -216,9 +269,134 @@ async def test_asgi_disconnect_after_response_complete():
@pytest.mark.anyio
async def test_asgi_exc_no_raise():
transport = httpx.ASGITransport(app=raise_exc, raise_app_exceptions=False)
async def test_asgi_exc_no_raise(asgi_transport_class):
transport = asgi_transport_class(app=raise_exc, raise_app_exceptions=False)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")
assert response.status_code == 500
@pytest.mark.anyio
async def test_asgi_exc_no_raise_after_response_start(asgi_transport_class):
transport = asgi_transport_class(
app=raise_exc_after_response_start,
raise_app_exceptions=False,
)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")
assert response.status_code == 200
@pytest.mark.anyio
async def test_asgi_exc_no_raise_after_response(asgi_transport_class):
transport = httpx.ASGITransport(
app=raise_exc_after_response, raise_app_exceptions=False
)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")
assert response.status_code == 200
@pytest.mark.anyio
async def test_asgi_app_runs_in_same_context_as_caller():
async def set_contextvar_in_app(scope, receive, send):
test_asgi_contextvar.set("value_from_app")
status = 200
output = b"Hello, World!"
headers = [
(b"content-type", "text/plain"),
(b"content-length", str(len(output))),
]
await send(
{"type": "http.response.start", "status": status, "headers": headers}
)
await send({"type": "http.response.body", "body": output})
transport = httpx.ASGITransport(
app=set_contextvar_in_app, raise_app_exceptions=False
)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")
assert response.status_code == 200
assert test_asgi_contextvar.get(None) == "value_from_app"
@pytest.mark.parametrize(
"send_in_sub_task",
[pytest.param(False, id="no_sub_task"), pytest.param(True, id="with_sub_task")],
)
@pytest.mark.anyio
async def test_asgi_stream_returns_before_waiting_for_body(send_in_sub_task):
start_response_body = anyio.Event()
async def send_response_body_after_event(scope, receive, send):
status = 200
headers = [(b"content-type", b"text/plain")]
await send(
{"type": "http.response.start", "status": status, "headers": headers}
)
await start_response_body.wait()
await send({"type": "http.response.body", "body": b"body", "more_body": False})
if send_in_sub_task:
send_response_body_after_event = run_in_task_group(
send_response_body_after_event
)
transport = httpx.ASGIStreamingTransport(app=send_response_body_after_event)
async with httpx.AsyncClient(transport=transport) as client:
with anyio.fail_after(0.1):
async with client.stream("GET", "http://www.example.org/") as response:
assert response.status_code == 200
start_response_body.set()
await response.aread()
assert response.text == "body"
@pytest.mark.parametrize(
"send_in_sub_task",
[pytest.param(False, id="no_sub_task"), pytest.param(True, id="with_sub_task")],
)
@pytest.mark.anyio
async def test_asgi_stream_allows_iterative_streaming(send_in_sub_task):
stream_events = [anyio.Event() for i in range(4)]
async def send_response_body_after_event(scope, receive, send):
status = 200
headers = [(b"content-type", b"text/plain")]
await send(
{"type": "http.response.start", "status": status, "headers": headers}
)
for e in stream_events:
await e.wait()
await send(
{
"type": "http.response.body",
"body": b"chunk",
"more_body": e is not stream_events[-1],
}
)
if send_in_sub_task:
send_response_body_after_event = run_in_task_group(
send_response_body_after_event
)
transport = httpx.ASGIStreamingTransport(app=send_response_body_after_event)
async with httpx.AsyncClient(transport=transport) as client:
with anyio.fail_after(0.1):
async with client.stream("GET", "http://www.example.org/") as response:
assert response.status_code == 200
iterator = response.aiter_raw()
for e in stream_events:
e.set()
assert await iterator.__anext__() == b"chunk"
with pytest.raises(StopAsyncIteration):
await iterator.__anext__()