diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index 6068d715..1ae45f40 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -775,6 +775,76 @@ async def test_shutdown_during_idle(http_protocol_cls: type[HTTPProtocol]): assert protocol.transport.is_closing() +async def test_shutdown_during_streaming_sends_disconnect(http_protocol_cls: type[HTTPProtocol]): + """When the server shuts down during an SSE/streaming response, + receive() should return http.disconnect so the ASGI app can stop.""" + got_disconnect_event = False + + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + nonlocal got_disconnect_event + + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [(b"content-type", b"text/event-stream")], + } + ) + await send({"type": "http.response.body", "body": b"data: hello\n\n", "more_body": True}) + + # This simulates an SSE app waiting for disconnect + message = await receive() + if message["type"] == "http.disconnect": + got_disconnect_event = True + + protocol = get_connected_protocol(app, http_protocol_cls) + protocol.data_received(SIMPLE_GET_REQUEST) + # Trigger server shutdown while the app is streaming + protocol.shutdown() # type: ignore[attr-defined] + await protocol.loop.run_one() + assert got_disconnect_event + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b"data: hello" in protocol.transport.buffer + assert protocol.transport.is_closing() + + +async def test_shutdown_during_streaming_allows_send_before_exit(http_protocol_cls: type[HTTPProtocol]): + """During server shutdown, the app should still be able to send() data + (e.g., a farewell SSE event) before returning.""" + farewell_sent = False + + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + nonlocal farewell_sent + + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"text/event-stream"), + (b"transfer-encoding", b"chunked"), + ], + } + ) + await send({"type": "http.response.body", "body": b"data: hello\n\n", "more_body": True}) + + # Wait for disconnect + message = await receive() + assert message["type"] == "http.disconnect" + + # Send a farewell event — this should still work since the transport is open + await send({"type": "http.response.body", "body": b"data: goodbye\n\n", "more_body": True}) + farewell_sent = True + + protocol = get_connected_protocol(app, http_protocol_cls) + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.shutdown() # type: ignore[attr-defined] + await protocol.loop.run_one() + assert farewell_sent + assert b"data: hello" in protocol.transport.buffer + assert b"data: goodbye" in protocol.transport.buffer + + async def test_100_continue_sent_when_body_consumed(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): body = b"" diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index dd98269e..2ad140c8 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -344,6 +344,8 @@ class H11Protocol(asyncio.Protocol): self.transport.close() else: self.cycle.keep_alive = False + self.cycle.shutting_down = True + self.cycle.message_event.set() def pause_writing(self) -> None: """ @@ -397,6 +399,7 @@ class RequestResponseCycle: self.disconnected = False self.keep_alive = True self.waiting_for_100_continue = conn.they_are_waiting_for_100_continue + self.shutting_down = False # Request state self.body = bytearray() @@ -429,8 +432,9 @@ class RequestResponseCycle: self.logger.error(msg) await self.send_500_response() elif not self.response_complete and not self.disconnected: - msg = "ASGI callable returned without completing response." - self.logger.error(msg) + if not self.shutting_down: + msg = "ASGI callable returned without completing response." + self.logger.error(msg) self.transport.close() finally: self.on_response = lambda: None @@ -528,12 +532,12 @@ class RequestResponseCycle: self.transport.write(output) self.waiting_for_100_continue = False - if not self.disconnected and not self.response_complete: + if not self.disconnected and not self.response_complete and not self.shutting_down: self.flow.resume_reading() await self.message_event.wait() self.message_event.clear() - if self.disconnected or self.response_complete: + if self.disconnected or self.response_complete or self.shutting_down: return {"type": "http.disconnect"} message: HTTPRequestEvent = {"type": "http.request", "body": bytes(self.body), "more_body": self.more_body} diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index 3aa354cc..a23c2fb6 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -349,6 +349,8 @@ class HttpToolsProtocol(asyncio.Protocol): self.transport.close() else: self.cycle.keep_alive = False + self.cycle.shutting_down = True + self.cycle.message_event.set() def pause_writing(self) -> None: """ @@ -400,6 +402,7 @@ class RequestResponseCycle: self.disconnected = False self.keep_alive = keep_alive self.waiting_for_100_continue = expect_100_continue + self.shutting_down = False # Request state self.body = bytearray() @@ -434,8 +437,9 @@ class RequestResponseCycle: self.logger.error(msg) await self.send_500_response() elif not self.response_complete and not self.disconnected: - msg = "ASGI callable returned without completing response." - self.logger.error(msg) + if not self.shutting_down: + msg = "ASGI callable returned without completing response." + self.logger.error(msg) self.transport.close() finally: self.on_response = lambda: None @@ -560,12 +564,12 @@ class RequestResponseCycle: self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") self.waiting_for_100_continue = False - if not self.disconnected and not self.response_complete: + if not self.disconnected and not self.response_complete and not self.shutting_down: self.flow.resume_reading() await self.message_event.wait() self.message_event.clear() - if self.disconnected or self.response_complete: + if self.disconnected or self.response_complete or self.shutting_down: return {"type": "http.disconnect"} message: HTTPRequestEvent = {"type": "http.request", "body": bytes(self.body), "more_body": self.more_body} self.body = bytearray()