Compare commits
5 Commits
main
...
shutdown-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f7a1c78f0a | ||
|
|
0a152cd767 | ||
|
|
ea71de62b8 | ||
|
|
7856fcae4d | ||
|
|
e68af2919a |
@ -739,6 +739,76 @@ async def test_shutdown_during_idle(http_protocol_cls: type[HTTPProtocol]):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_shutdown_during_streaming_sends_disconnect(http_protocol_cls: type[HTTPProtocol]):
|
||||
"""When the server shuts down during an SSE/streaming response,
|
||||
receive() should return http.disconnect so the ASGI app can stop."""
|
||||
got_disconnect_event = False
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
nonlocal got_disconnect_event
|
||||
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [(b"content-type", b"text/event-stream")],
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": b"data: hello\n\n", "more_body": True})
|
||||
|
||||
# This simulates an SSE app waiting for disconnect
|
||||
message = await receive()
|
||||
if message["type"] == "http.disconnect":
|
||||
got_disconnect_event = True
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_GET_REQUEST)
|
||||
# Trigger server shutdown while the app is streaming
|
||||
protocol.shutdown() # type: ignore[attr-defined]
|
||||
await protocol.loop.run_one()
|
||||
assert got_disconnect_event
|
||||
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
|
||||
assert b"data: hello" in protocol.transport.buffer
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_shutdown_during_streaming_allows_send_before_exit(http_protocol_cls: type[HTTPProtocol]):
|
||||
"""During server shutdown, the app should still be able to send() data
|
||||
(e.g., a farewell SSE event) before returning."""
|
||||
farewell_sent = False
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
nonlocal farewell_sent
|
||||
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [
|
||||
(b"content-type", b"text/event-stream"),
|
||||
(b"transfer-encoding", b"chunked"),
|
||||
],
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": b"data: hello\n\n", "more_body": True})
|
||||
|
||||
# Wait for disconnect
|
||||
message = await receive()
|
||||
assert message["type"] == "http.disconnect"
|
||||
|
||||
# Send a farewell event — this should still work since the transport is open
|
||||
await send({"type": "http.response.body", "body": b"data: goodbye\n\n", "more_body": True})
|
||||
farewell_sent = True
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_GET_REQUEST)
|
||||
protocol.shutdown() # type: ignore[attr-defined]
|
||||
await protocol.loop.run_one()
|
||||
assert farewell_sent
|
||||
assert b"data: hello" in protocol.transport.buffer
|
||||
assert b"data: goodbye" in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_100_continue_sent_when_body_consumed(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
body = b""
|
||||
|
||||
@ -342,6 +342,8 @@ class H11Protocol(asyncio.Protocol):
|
||||
self.transport.close()
|
||||
else:
|
||||
self.cycle.keep_alive = False
|
||||
self.cycle.shutting_down = True
|
||||
self.cycle.message_event.set()
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
"""
|
||||
@ -395,6 +397,7 @@ class RequestResponseCycle:
|
||||
self.disconnected = False
|
||||
self.keep_alive = True
|
||||
self.waiting_for_100_continue = conn.they_are_waiting_for_100_continue
|
||||
self.shutting_down = False
|
||||
|
||||
# Request state
|
||||
self.body = b""
|
||||
@ -427,8 +430,9 @@ class RequestResponseCycle:
|
||||
self.logger.error(msg)
|
||||
await self.send_500_response()
|
||||
elif not self.response_complete and not self.disconnected:
|
||||
msg = "ASGI callable returned without completing response."
|
||||
self.logger.error(msg)
|
||||
if not self.shutting_down:
|
||||
msg = "ASGI callable returned without completing response."
|
||||
self.logger.error(msg)
|
||||
self.transport.close()
|
||||
finally:
|
||||
self.on_response = lambda: None
|
||||
@ -533,12 +537,12 @@ class RequestResponseCycle:
|
||||
self.transport.write(output)
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
if not self.disconnected and not self.response_complete:
|
||||
if not self.disconnected and not self.response_complete and not self.shutting_down:
|
||||
self.flow.resume_reading()
|
||||
await self.message_event.wait()
|
||||
self.message_event.clear()
|
||||
|
||||
if self.disconnected or self.response_complete:
|
||||
if self.disconnected or self.response_complete or self.shutting_down:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
message: HTTPRequestEvent = {
|
||||
|
||||
@ -348,6 +348,8 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
self.transport.close()
|
||||
else:
|
||||
self.cycle.keep_alive = False
|
||||
self.cycle.shutting_down = True
|
||||
self.cycle.message_event.set()
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
"""
|
||||
@ -399,6 +401,7 @@ class RequestResponseCycle:
|
||||
self.disconnected = False
|
||||
self.keep_alive = keep_alive
|
||||
self.waiting_for_100_continue = expect_100_continue
|
||||
self.shutting_down = False
|
||||
|
||||
# Request state
|
||||
self.body = b""
|
||||
@ -433,8 +436,9 @@ class RequestResponseCycle:
|
||||
self.logger.error(msg)
|
||||
await self.send_500_response()
|
||||
elif not self.response_complete and not self.disconnected:
|
||||
msg = "ASGI callable returned without completing response."
|
||||
self.logger.error(msg)
|
||||
if not self.shutting_down:
|
||||
msg = "ASGI callable returned without completing response."
|
||||
self.logger.error(msg)
|
||||
self.transport.close()
|
||||
finally:
|
||||
self.on_response = lambda: None
|
||||
@ -565,12 +569,12 @@ class RequestResponseCycle:
|
||||
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
if not self.disconnected and not self.response_complete:
|
||||
if not self.disconnected and not self.response_complete and not self.shutting_down:
|
||||
self.flow.resume_reading()
|
||||
await self.message_event.wait()
|
||||
self.message_event.clear()
|
||||
|
||||
if self.disconnected or self.response_complete:
|
||||
if self.disconnected or self.response_complete or self.shutting_down:
|
||||
return {"type": "http.disconnect"}
|
||||
message: HTTPRequestEvent = {"type": "http.request", "body": self.body, "more_body": self.more_body}
|
||||
self.body = b""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user