Compare commits
2 Commits
main
...
add-traile
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
23ef7d5ddb | ||
|
|
6cf4765b40 |
@ -171,6 +171,48 @@ async def app(scope, receive, send):
|
||||
})
|
||||
```
|
||||
|
||||
### Sending trailers
|
||||
|
||||
HTTP trailers are additional headers sent after the response body. Uvicorn supports the
|
||||
[HTTP Trailers ASGI extension][http-trailers]. To send trailers, set `"trailers": True` on the
|
||||
`http.response.start` message, then send one or more `http.response.trailers` messages after the
|
||||
body completes. Set `more_trailers` to `False` on the last trailers message.
|
||||
|
||||
Trailers are only emitted on the wire when the client sends a `TE: trailers` request header. When
|
||||
the client does not advertise support, trailers sent by the application are silently dropped.
|
||||
|
||||
Applications should also announce the trailer field names in advance via the `Trailer` response
|
||||
header, per [RFC 7230][rfc-7230-trailer].
|
||||
|
||||
```python
|
||||
async def app(scope, receive, send):
|
||||
assert scope['type'] == 'http'
|
||||
await send({
|
||||
'type': 'http.response.start',
|
||||
'status': 200,
|
||||
'headers': [
|
||||
[b'content-type', b'text/plain'],
|
||||
[b'trailer', b'x-app-status'],
|
||||
],
|
||||
'trailers': True,
|
||||
})
|
||||
await send({
|
||||
'type': 'http.response.body',
|
||||
'body': b'Hello, world!',
|
||||
})
|
||||
await send({
|
||||
'type': 'http.response.trailers',
|
||||
'headers': [
|
||||
[b'x-app-status', b'ok'],
|
||||
],
|
||||
'more_trailers': False,
|
||||
})
|
||||
```
|
||||
|
||||
[rfc-7230-trailer]: https://www.rfc-editor.org/rfc/rfc7230#section-4.4
|
||||
|
||||
[http-trailers]: https://asgi.readthedocs.io/en/latest/extensions.html#http-trailers
|
||||
|
||||
---
|
||||
|
||||
## Why ASGI?
|
||||
|
||||
@ -41,6 +41,10 @@ WEBSOCKET_PROTOCOLS = WS_PROTOCOLS.keys()
|
||||
|
||||
SIMPLE_GET_REQUEST = b"\r\n".join([b"GET / HTTP/1.1", b"Host: example.org", b"", b""])
|
||||
|
||||
SIMPLE_GET_REQUEST_WITH_TRAILERS = b"\r\n".join(
|
||||
[b"GET / HTTP/1.1", b"Host: example.org", b"TE: trailers", b"", b""]
|
||||
)
|
||||
|
||||
SIMPLE_HEAD_REQUEST = b"\r\n".join([b"HEAD / HTTP/1.1", b"Host: example.org", b"", b""])
|
||||
|
||||
SIMPLE_POST_REQUEST = b"\r\n".join(
|
||||
@ -1224,3 +1228,181 @@ async def test_header_upgrade_is_websocket_depend_not_installed(
|
||||
assert msg in caplog.text
|
||||
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
|
||||
assert b"Hello, world" in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_trailers_extension_in_scope(http_protocol_cls: type[HTTPProtocol]):
|
||||
received_scope: dict[str, Any] = {}
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
received_scope.update(scope) # type: ignore[arg-type]
|
||||
await Response("Hello, world", media_type="text/plain")(scope, receive, send)
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_GET_REQUEST)
|
||||
await protocol.loop.run_one()
|
||||
assert "extensions" in received_scope
|
||||
assert "http.response.trailers" in received_scope["extensions"]
|
||||
|
||||
|
||||
async def test_trailers(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
"trailers": True,
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": b"Hi"})
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.trailers",
|
||||
"headers": [(b"x-trailer-1", b"value-1")],
|
||||
"more_trailers": True,
|
||||
}
|
||||
)
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.trailers",
|
||||
"headers": [(b"x-trailer-2", b"value-2")],
|
||||
"more_trailers": False,
|
||||
}
|
||||
)
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_GET_REQUEST_WITH_TRAILERS)
|
||||
await protocol.loop.run_one()
|
||||
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
|
||||
assert b"Hi" in protocol.transport.buffer
|
||||
assert b"x-trailer-1: value-1" in protocol.transport.buffer
|
||||
assert b"x-trailer-2: value-2" in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_trailers_without_te_header_are_dropped(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
"trailers": True,
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": b"Hi"})
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.trailers",
|
||||
"headers": [(b"x-trailer-1", b"value-1")],
|
||||
"more_trailers": False,
|
||||
}
|
||||
)
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_GET_REQUEST)
|
||||
await protocol.loop.run_one()
|
||||
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
|
||||
assert b"Hi" in protocol.transport.buffer
|
||||
assert b"x-trailer-1: value-1" not in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_trailers_for_head_request_are_skipped(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [(b"content-type", b"text/plain"), (b"content-length", b"0")],
|
||||
"trailers": True,
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": b""})
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_HEAD_REQUEST)
|
||||
await protocol.loop.run_one()
|
||||
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_body_after_trailers_raises(http_protocol_cls: type[HTTPProtocol]):
|
||||
with_body_after_trailers: dict[str, bool] = {"raised": False}
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
"trailers": True,
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": b"Hi"})
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.trailers",
|
||||
"headers": [(b"x-trailer-1", b"value-1")],
|
||||
"more_trailers": False,
|
||||
}
|
||||
)
|
||||
try:
|
||||
await send({"type": "http.response.body", "body": b"oops"})
|
||||
except RuntimeError:
|
||||
with_body_after_trailers["raised"] = True
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_GET_REQUEST_WITH_TRAILERS)
|
||||
await protocol.loop.run_one()
|
||||
assert with_body_after_trailers["raised"]
|
||||
|
||||
|
||||
async def test_body_during_trailers_phase_raises(http_protocol_cls: type[HTTPProtocol]):
|
||||
raised: dict[str, bool] = {"raised": False}
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
"trailers": True,
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": b"Hi"})
|
||||
try:
|
||||
await send({"type": "http.response.body", "body": b"more"})
|
||||
except RuntimeError:
|
||||
raised["raised"] = True
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_GET_REQUEST_WITH_TRAILERS)
|
||||
await protocol.loop.run_one()
|
||||
assert raised["raised"]
|
||||
|
||||
|
||||
async def test_trailers_with_close_connection(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
"trailers": True,
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": b"Hi"})
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.trailers",
|
||||
"headers": [(b"x-trailer-1", b"value-1")],
|
||||
"more_trailers": False,
|
||||
}
|
||||
)
|
||||
|
||||
request = b"\r\n".join(
|
||||
[b"GET / HTTP/1.1", b"Host: example.org", b"Connection: close", b"TE: trailers", b"", b""]
|
||||
)
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(request)
|
||||
await protocol.loop.run_one()
|
||||
assert b"x-trailer-1: value-1" in protocol.transport.buffer
|
||||
assert protocol.transport.closed
|
||||
|
||||
@ -215,7 +215,12 @@ class H11Protocol(asyncio.Protocol):
|
||||
"query_string": query_string,
|
||||
"headers": self.headers,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"http.response.trailers": {}},
|
||||
}
|
||||
expect_trailers = any(
|
||||
name == b"te" and b"trailers" in [v.strip() for v in value.lower().split(b",")]
|
||||
for name, value in self.headers
|
||||
)
|
||||
if self._should_upgrade():
|
||||
self.handle_websocket_upgrade(event)
|
||||
return
|
||||
@ -248,6 +253,7 @@ class H11Protocol(asyncio.Protocol):
|
||||
access_log=self.access_log,
|
||||
default_headers=self.server_state.default_headers,
|
||||
message_event=asyncio.Event(),
|
||||
expect_trailers=expect_trailers,
|
||||
on_response=self.on_response_complete,
|
||||
)
|
||||
# For the asyncio loop, we need to explicitly start with an empty context
|
||||
@ -382,6 +388,7 @@ class RequestResponseCycle:
|
||||
access_log: bool,
|
||||
default_headers: list[tuple[bytes, bytes]],
|
||||
message_event: asyncio.Event,
|
||||
expect_trailers: bool,
|
||||
on_response: Callable[..., None],
|
||||
) -> None:
|
||||
self.scope = scope
|
||||
@ -399,6 +406,7 @@ class RequestResponseCycle:
|
||||
self.disconnected = False
|
||||
self.keep_alive = True
|
||||
self.waiting_for_100_continue = conn.they_are_waiting_for_100_continue
|
||||
self.expect_trailers = expect_trailers
|
||||
self.shutting_down = False
|
||||
|
||||
# Request state
|
||||
@ -408,6 +416,8 @@ class RequestResponseCycle:
|
||||
# Response state
|
||||
self.response_started = False
|
||||
self.response_complete = False
|
||||
self.send_trailers = False
|
||||
self.trailers: list[tuple[bytes, bytes]] = []
|
||||
|
||||
# ASGI exception wrapper
|
||||
async def run_asgi(self, app: ASGI3Application) -> None:
|
||||
@ -474,6 +484,7 @@ class RequestResponseCycle:
|
||||
|
||||
status = message["status"]
|
||||
headers = self.default_headers + list(message.get("headers", []))
|
||||
self.send_trailers = message.get("trailers", False) and self.scope["method"] != "HEAD"
|
||||
|
||||
if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers:
|
||||
headers = headers + [CLOSE_HEADER]
|
||||
@ -511,14 +522,30 @@ class RequestResponseCycle:
|
||||
if not more_body:
|
||||
self.response_complete = True
|
||||
self.message_event.set()
|
||||
output = self.conn.send(event=h11.EndOfMessage())
|
||||
if not self.send_trailers:
|
||||
output = self.conn.send(event=h11.EndOfMessage())
|
||||
self.transport.write(output)
|
||||
|
||||
elif self.send_trailers:
|
||||
# Sending response trailers
|
||||
if message["type"] != "http.response.trailers":
|
||||
raise RuntimeError(f"Expected ASGI message 'http.response.trailers', but got '{message['type']}'.")
|
||||
|
||||
self.trailers.extend(message.get("headers", []))
|
||||
more_trailers = message.get("more_trailers", False)
|
||||
|
||||
if not more_trailers:
|
||||
self.send_trailers = False
|
||||
# h11 emits trailers only if the client advertised TE: trailers.
|
||||
trailers = self.trailers if self.expect_trailers else []
|
||||
output = self.conn.send(event=h11.EndOfMessage(headers=trailers))
|
||||
self.transport.write(output)
|
||||
|
||||
else:
|
||||
# Response already sent
|
||||
raise RuntimeError(f"Unexpected ASGI message '{message['type']}' sent, after response already completed.")
|
||||
|
||||
if self.response_complete:
|
||||
if self.response_complete and not self.send_trailers:
|
||||
if self.conn.our_state is h11.MUST_CLOSE or not self.keep_alive:
|
||||
self.conn.send(event=h11.ConnectionClosed())
|
||||
self.transport.close()
|
||||
|
||||
@ -94,6 +94,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
self.scope: HTTPScope = None # type: ignore[assignment]
|
||||
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
self.expect_100_continue = False
|
||||
self.expect_trailers = False
|
||||
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
|
||||
|
||||
# Protocol interface
|
||||
@ -221,6 +222,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
def on_message_begin(self) -> None:
|
||||
self.url = b""
|
||||
self.expect_100_continue = False
|
||||
self.expect_trailers = False
|
||||
self.headers = []
|
||||
self.scope = { # type: ignore[typeddict-item]
|
||||
"type": "http",
|
||||
@ -232,6 +234,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
"root_path": self.root_path,
|
||||
"headers": self.headers,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"http.response.trailers": {}},
|
||||
}
|
||||
|
||||
# Parser callbacks
|
||||
@ -242,6 +245,8 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
name = name.lower()
|
||||
if name == b"expect" and value.lower() == b"100-continue":
|
||||
self.expect_100_continue = True
|
||||
if name == b"te" and b"trailers" in [v.strip() for v in value.lower().split(b",")]:
|
||||
self.expect_trailers = True
|
||||
self.headers.append((name, value))
|
||||
|
||||
def on_headers_complete(self) -> None:
|
||||
@ -284,6 +289,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
default_headers=self.server_state.default_headers,
|
||||
message_event=asyncio.Event(),
|
||||
expect_100_continue=self.expect_100_continue,
|
||||
expect_trailers=self.expect_trailers,
|
||||
keep_alive=http_version != "1.0",
|
||||
on_response=self.on_response_complete,
|
||||
)
|
||||
@ -385,6 +391,7 @@ class RequestResponseCycle:
|
||||
default_headers: list[tuple[bytes, bytes]],
|
||||
message_event: asyncio.Event,
|
||||
expect_100_continue: bool,
|
||||
expect_trailers: bool,
|
||||
keep_alive: bool,
|
||||
on_response: Callable[..., None],
|
||||
):
|
||||
@ -402,6 +409,7 @@ class RequestResponseCycle:
|
||||
self.disconnected = False
|
||||
self.keep_alive = keep_alive
|
||||
self.waiting_for_100_continue = expect_100_continue
|
||||
self.expect_trailers = expect_trailers
|
||||
self.shutting_down = False
|
||||
|
||||
# Request state
|
||||
@ -411,6 +419,7 @@ class RequestResponseCycle:
|
||||
# Response state
|
||||
self.response_started = False
|
||||
self.response_complete = False
|
||||
self.send_trailers = False
|
||||
self.chunked_encoding: bool | None = None
|
||||
self.expected_content_length = 0
|
||||
|
||||
@ -476,6 +485,7 @@ class RequestResponseCycle:
|
||||
|
||||
status_code = message["status"]
|
||||
headers = self.default_headers + list(message.get("headers", []))
|
||||
self.send_trailers = message.get("trailers", False) and self.scope["method"] != "HEAD"
|
||||
|
||||
if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers:
|
||||
headers = headers + [CLOSE_HEADER]
|
||||
@ -535,7 +545,10 @@ class RequestResponseCycle:
|
||||
else:
|
||||
content = []
|
||||
if not more_body:
|
||||
content.append(b"0\r\n\r\n")
|
||||
if self.send_trailers:
|
||||
content.append(b"0\r\n")
|
||||
else:
|
||||
content.append(b"0\r\n\r\n")
|
||||
self.transport.write(b"".join(content))
|
||||
else:
|
||||
num_bytes = len(body)
|
||||
@ -551,6 +564,37 @@ class RequestResponseCycle:
|
||||
raise RuntimeError("Response content shorter than Content-Length")
|
||||
self.response_complete = True
|
||||
self.message_event.set()
|
||||
if not self.send_trailers:
|
||||
if not self.keep_alive:
|
||||
self.transport.close()
|
||||
self.on_response()
|
||||
|
||||
elif self.send_trailers:
|
||||
# Sending response trailers
|
||||
if message["type"] != "http.response.trailers":
|
||||
raise RuntimeError(f"Expected ASGI message 'http.response.trailers', but got '{message['type']}'.")
|
||||
|
||||
trailers = list(message.get("headers", []))
|
||||
more_trailers = message.get("more_trailers", False)
|
||||
content = []
|
||||
|
||||
for name, value in trailers:
|
||||
if HEADER_RE.search(name):
|
||||
raise RuntimeError("Invalid HTTP header name.") # pragma: no cover
|
||||
if HEADER_VALUE_RE.search(value):
|
||||
raise RuntimeError("Invalid HTTP header value.") # pragma: no cover
|
||||
name = name.lower()
|
||||
content.extend([name, b": ", value, b"\r\n"])
|
||||
|
||||
if not more_trailers:
|
||||
content.append(b"\r\n")
|
||||
|
||||
# Server should only send if the client sent a TE: trailers header.
|
||||
if self.expect_trailers:
|
||||
self.transport.write(b"".join(content))
|
||||
|
||||
if not more_trailers:
|
||||
self.send_trailers = False
|
||||
if not self.keep_alive:
|
||||
self.transport.close()
|
||||
self.on_response()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user