🐛 Emit http.disconnect ASGI receive() event on server shutting down for streaming responses (#2829)

This commit is contained in:
Sebastián Ramírez 2026-04-03 16:23:03 +02:00 committed by GitHub
parent c9a75fb67b
commit 587042d68f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 86 additions and 8 deletions

View File

@ -775,6 +775,76 @@ async def test_shutdown_during_idle(http_protocol_cls: type[HTTPProtocol]):
assert protocol.transport.is_closing()
async def test_shutdown_during_streaming_sends_disconnect(http_protocol_cls: type[HTTPProtocol]):
"""When the server shuts down during an SSE/streaming response,
receive() should return http.disconnect so the ASGI app can stop."""
got_disconnect_event = False
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal got_disconnect_event
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [(b"content-type", b"text/event-stream")],
}
)
await send({"type": "http.response.body", "body": b"data: hello\n\n", "more_body": True})
# This simulates an SSE app waiting for disconnect
message = await receive()
if message["type"] == "http.disconnect":
got_disconnect_event = True
protocol = get_connected_protocol(app, http_protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
# Trigger server shutdown while the app is streaming
protocol.shutdown() # type: ignore[attr-defined]
await protocol.loop.run_one()
assert got_disconnect_event
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"data: hello" in protocol.transport.buffer
assert protocol.transport.is_closing()
async def test_shutdown_during_streaming_allows_send_before_exit(http_protocol_cls: type[HTTPProtocol]):
"""During server shutdown, the app should still be able to send() data
(e.g., a farewell SSE event) before returning."""
farewell_sent = False
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal farewell_sent
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/event-stream"),
(b"transfer-encoding", b"chunked"),
],
}
)
await send({"type": "http.response.body", "body": b"data: hello\n\n", "more_body": True})
# Wait for disconnect
message = await receive()
assert message["type"] == "http.disconnect"
# Send a farewell event — this should still work since the transport is open
await send({"type": "http.response.body", "body": b"data: goodbye\n\n", "more_body": True})
farewell_sent = True
protocol = get_connected_protocol(app, http_protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.shutdown() # type: ignore[attr-defined]
await protocol.loop.run_one()
assert farewell_sent
assert b"data: hello" in protocol.transport.buffer
assert b"data: goodbye" in protocol.transport.buffer
async def test_100_continue_sent_when_body_consumed(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
body = b""

View File

@ -344,6 +344,8 @@ class H11Protocol(asyncio.Protocol):
self.transport.close()
else:
self.cycle.keep_alive = False
self.cycle.shutting_down = True
self.cycle.message_event.set()
def pause_writing(self) -> None:
"""
@ -397,6 +399,7 @@ class RequestResponseCycle:
self.disconnected = False
self.keep_alive = True
self.waiting_for_100_continue = conn.they_are_waiting_for_100_continue
self.shutting_down = False
# Request state
self.body = bytearray()
@ -429,8 +432,9 @@ class RequestResponseCycle:
self.logger.error(msg)
await self.send_500_response()
elif not self.response_complete and not self.disconnected:
msg = "ASGI callable returned without completing response."
self.logger.error(msg)
if not self.shutting_down:
msg = "ASGI callable returned without completing response."
self.logger.error(msg)
self.transport.close()
finally:
self.on_response = lambda: None
@ -528,12 +532,12 @@ class RequestResponseCycle:
self.transport.write(output)
self.waiting_for_100_continue = False
if not self.disconnected and not self.response_complete:
if not self.disconnected and not self.response_complete and not self.shutting_down:
self.flow.resume_reading()
await self.message_event.wait()
self.message_event.clear()
if self.disconnected or self.response_complete:
if self.disconnected or self.response_complete or self.shutting_down:
return {"type": "http.disconnect"}
message: HTTPRequestEvent = {"type": "http.request", "body": bytes(self.body), "more_body": self.more_body}

View File

@ -349,6 +349,8 @@ class HttpToolsProtocol(asyncio.Protocol):
self.transport.close()
else:
self.cycle.keep_alive = False
self.cycle.shutting_down = True
self.cycle.message_event.set()
def pause_writing(self) -> None:
"""
@ -400,6 +402,7 @@ class RequestResponseCycle:
self.disconnected = False
self.keep_alive = keep_alive
self.waiting_for_100_continue = expect_100_continue
self.shutting_down = False
# Request state
self.body = bytearray()
@ -434,8 +437,9 @@ class RequestResponseCycle:
self.logger.error(msg)
await self.send_500_response()
elif not self.response_complete and not self.disconnected:
msg = "ASGI callable returned without completing response."
self.logger.error(msg)
if not self.shutting_down:
msg = "ASGI callable returned without completing response."
self.logger.error(msg)
self.transport.close()
finally:
self.on_response = lambda: None
@ -560,12 +564,12 @@ class RequestResponseCycle:
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
self.waiting_for_100_continue = False
if not self.disconnected and not self.response_complete:
if not self.disconnected and not self.response_complete and not self.shutting_down:
self.flow.resume_reading()
await self.message_event.wait()
self.message_event.clear()
if self.disconnected or self.response_complete:
if self.disconnected or self.response_complete or self.shutting_down:
return {"type": "http.disconnect"}
message: HTTPRequestEvent = {"type": "http.request", "body": bytes(self.body), "more_body": self.more_body}
self.body = bytearray()