Compare commits

...

6 Commits

Author SHA1 Message Date
Marcelo Trylesinski
e8b99bcb4a Restore SIGINT force-exit and rename shutdown flag 2026-04-02 15:28:38 -04:00
Marcelo Trylesinski
c86c701254 Add abort phase to graceful shutdown 2026-04-02 15:16:12 -04:00
Sebastián Ramírez
0a152cd767 👷 Trigger CI 2026-02-28 17:31:23 +01:00
Sebastián Ramírez
ea71de62b8 ♻️ Make _shutting_down public (shutting_down) 2026-02-28 14:22:42 +01:00
Sebastián Ramírez
7856fcae4d 🐛 Add implementation in protocols for sending ASGI close event for streaming responses on server shutdown 2026-02-28 13:06:19 +01:00
Sebastián Ramírez
e68af2919a Add tests for closing streaming response on server termination 2026-02-28 13:05:54 +01:00
8 changed files with 202 additions and 16 deletions

View File

@ -739,6 +739,79 @@ async def test_shutdown_during_idle(http_protocol_cls: type[HTTPProtocol]):
assert protocol.transport.is_closing()
async def test_shutdown_during_streaming_sends_disconnect(http_protocol_cls: type[HTTPProtocol]):
"""When the server shuts down during an SSE/streaming response,
receive() should return http.disconnect so the ASGI app can stop."""
got_disconnect_event = False
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal got_disconnect_event
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [(b"content-type", b"text/event-stream")],
}
)
await send({"type": "http.response.body", "body": b"data: hello\n\n", "more_body": True})
# This simulates an SSE app waiting for disconnect
message = await receive()
if message["type"] == "http.disconnect":
got_disconnect_event = True
protocol = get_connected_protocol(app, http_protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
# Let the app start and send initial response
await asyncio.sleep(0)
# Trigger server shutdown while the app is streaming
protocol.shutdown() # type: ignore[attr-defined]
await protocol.loop.run_one()
assert got_disconnect_event
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"data: hello" in protocol.transport.buffer
assert protocol.transport.is_closing()
async def test_shutdown_during_streaming_allows_send_before_exit(http_protocol_cls: type[HTTPProtocol]):
"""During server shutdown, the app should still be able to send() data
(e.g., a farewell SSE event) before returning."""
farewell_sent = False
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal farewell_sent
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/event-stream"),
(b"transfer-encoding", b"chunked"),
],
}
)
await send({"type": "http.response.body", "body": b"data: hello\n\n", "more_body": True})
# Wait for disconnect
message = await receive()
assert message["type"] == "http.disconnect"
# Send a farewell event — this should still work since the transport is open
await send({"type": "http.response.body", "body": b"data: goodbye\n\n", "more_body": True})
farewell_sent = True
protocol = get_connected_protocol(app, http_protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
await asyncio.sleep(0)
protocol.shutdown() # type: ignore[attr-defined]
await protocol.loop.run_one()
assert farewell_sent
assert b"data: hello" in protocol.transport.buffer
assert b"data: goodbye" in protocol.transport.buffer
async def test_100_continue_sent_when_body_consumed(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
body = b""

View File

@ -232,3 +232,84 @@ async def test_no_contextvars_pollution_asyncio(
async with server(app=app, http_protocol_cls=http_protocol_cls, port=unused_tcp_port) as extract_json_body:
assert await extract_json_body(large_request) == {}
assert await extract_json_body(SIMPLE_GET_REQUEST) == {}
async def test_shutdown_aborts_connections_after_grace_timeout(unused_tcp_port: int):
calls: list[str] = []
class Connection:
def shutdown(self) -> None:
calls.append("shutdown")
def abort(self) -> None:
calls.append("abort")
class Lifespan:
state = {}
should_exit = False
async def shutdown(self) -> None:
calls.append("lifespan")
config = Config(app=app, port=unused_tcp_port, timeout_graceful_shutdown=0.01)
server = Server(config=config)
server.servers = []
server.lifespan = Lifespan()
server.server_state.connections.add(Connection()) # type: ignore[arg-type]
async def never_finishes() -> None:
await asyncio.sleep(1)
server._wait_tasks_to_complete = never_finishes # type: ignore[method-assign]
await server.shutdown()
assert calls == ["shutdown", "abort", "lifespan"]
async def test_shutdown_does_not_abort_connections_that_finish_in_time(unused_tcp_port: int):
calls: list[str] = []
class Connection:
def shutdown(self) -> None:
calls.append("shutdown")
def abort(self) -> None:
calls.append("abort")
class Lifespan:
state = {}
should_exit = False
async def shutdown(self) -> None:
calls.append("lifespan")
config = Config(app=app, port=unused_tcp_port, timeout_graceful_shutdown=1)
server = Server(config=config)
server.servers = []
server.lifespan = Lifespan()
connection = Connection()
server.server_state.connections.add(connection) # type: ignore[arg-type]
async def completes_in_time() -> None:
server.server_state.connections.discard(connection) # type: ignore[arg-type]
server._wait_tasks_to_complete = completes_in_time # type: ignore[method-assign]
await server.shutdown()
assert calls == ["shutdown", "lifespan"]
def test_handle_exit_sets_force_exit_on_second_signal(unused_tcp_port: int):
server = Server(Config(app=app, port=unused_tcp_port))
server.handle_exit(sig=signal.SIGTERM, frame=None)
assert server.should_exit is True
assert server.force_exit is False
server.handle_exit(sig=signal.SIGTERM, frame=None)
assert server.force_exit is False
server.handle_exit(sig=signal.SIGINT, frame=None)
assert server.force_exit is True

View File

@ -342,6 +342,11 @@ class H11Protocol(asyncio.Protocol):
self.transport.close()
else:
self.cycle.keep_alive = False
self.cycle.shutting_down = True
self.cycle.message_event.set()
def abort(self) -> None:
self.transport.close()
def pause_writing(self) -> None:
"""
@ -393,6 +398,7 @@ class RequestResponseCycle:
# Connection state
self.disconnected = False
self.shutting_down = False
self.keep_alive = True
self.waiting_for_100_continue = conn.they_are_waiting_for_100_continue
@ -427,8 +433,9 @@ class RequestResponseCycle:
self.logger.error(msg)
await self.send_500_response()
elif not self.response_complete and not self.disconnected:
msg = "ASGI callable returned without completing response."
self.logger.error(msg)
if not self.shutting_down:
msg = "ASGI callable returned without completing response."
self.logger.error(msg)
self.transport.close()
finally:
self.on_response = lambda: None
@ -533,12 +540,12 @@ class RequestResponseCycle:
self.transport.write(output)
self.waiting_for_100_continue = False
if not self.disconnected and not self.response_complete:
if not self.disconnected and not self.response_complete and not self.shutting_down:
self.flow.resume_reading()
await self.message_event.wait()
self.message_event.clear()
if self.disconnected or self.response_complete:
if self.disconnected or self.response_complete or self.shutting_down:
return {"type": "http.disconnect"}
message: HTTPRequestEvent = {

View File

@ -348,6 +348,11 @@ class HttpToolsProtocol(asyncio.Protocol):
self.transport.close()
else:
self.cycle.keep_alive = False
self.cycle.shutting_down = True
self.cycle.message_event.set()
def abort(self) -> None:
self.transport.close()
def pause_writing(self) -> None:
"""
@ -397,6 +402,7 @@ class RequestResponseCycle:
# Connection state
self.disconnected = False
self.shutting_down = False
self.keep_alive = keep_alive
self.waiting_for_100_continue = expect_100_continue
@ -433,8 +439,9 @@ class RequestResponseCycle:
self.logger.error(msg)
await self.send_500_response()
elif not self.response_complete and not self.disconnected:
msg = "ASGI callable returned without completing response."
self.logger.error(msg)
if not self.shutting_down:
msg = "ASGI callable returned without completing response."
self.logger.error(msg)
self.transport.close()
finally:
self.on_response = lambda: None
@ -565,12 +572,12 @@ class RequestResponseCycle:
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
self.waiting_for_100_continue = False
if not self.disconnected and not self.response_complete:
if not self.disconnected and not self.response_complete and not self.shutting_down:
self.flow.resume_reading()
await self.message_event.wait()
self.message_event.clear()
if self.disconnected or self.response_complete:
if self.disconnected or self.response_complete or self.shutting_down:
return {"type": "http.disconnect"}
message: HTTPRequestEvent = {"type": "http.request", "body": self.body, "more_body": self.more_body}
self.body = b""

View File

@ -14,6 +14,7 @@ from websockets.exceptions import ConnectionClosed
from websockets.extensions.base import ServerExtensionFactory
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
from websockets.legacy.server import HTTPResponse
from websockets.protocol import State
from websockets.server import WebSocketServerProtocol
from websockets.typing import Subprotocol
@ -148,9 +149,12 @@ class WebSocketProtocol(WebSocketServerProtocol):
def shutdown(self) -> None:
self.ws_server.closing = True
if self.handshake_completed_event.is_set():
self.fail_connection(1012)
if self.state is State.OPEN:
self.loop.create_task(self.close(code=1012))
else:
self.send_500_response()
def abort(self) -> None:
self.transport.close()
def on_task_complete(self, task: asyncio.Task[None]) -> None:

View File

@ -130,12 +130,15 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
def shutdown(self) -> None:
if self.handshake_complete:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
self.conn.send_close(1012)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
if self.conn.close_rcvd is None and self.conn.close_sent is None:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
self.conn.send_close(1012)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
else:
self.send_500_response()
def abort(self) -> None:
self.transport.close()
def data_received(self, data: bytes) -> None:

View File

@ -150,11 +150,14 @@ class WSProtocol(asyncio.Protocol):
def shutdown(self) -> None:
if self.handshake_complete:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
output = self.conn.send(wsproto.events.CloseConnection(code=1012))
self.transport.write(output)
if self.conn.state is ConnectionState.OPEN:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
output = self.conn.send(wsproto.events.CloseConnection(code=1012))
self.transport.write(output)
else:
self.send_500_response()
def abort(self) -> None:
self.transport.close()
def on_task_complete(self, task: asyncio.Task[None]) -> None:

View File

@ -293,8 +293,16 @@ class Server:
"Cancel %s running task(s), timeout graceful shutdown exceeded",
len(self.server_state.tasks),
)
for connection in list(self.server_state.connections):
connection.abort()
for t in self.server_state.tasks:
t.cancel(msg="Task cancelled, timeout graceful shutdown exceeded")
else:
if self.force_exit:
for connection in list(self.server_state.connections):
connection.abort()
for t in self.server_state.tasks:
t.cancel(msg="Task cancelled, shutdown aborted by signal")
# Send the lifespan shutdown event, and wait for application shutdown.
if not self.force_exit: