diff --git a/docs/settings.md b/docs/settings.md index 99ee455a..8c9ab8e8 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -95,8 +95,8 @@ Using Uvicorn with watchfiles will enable the following options (which are other * `--ws ` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. There are two versions of `websockets` supported: `websockets` and `websockets-sansio`. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'websockets-sansio', 'wsproto'.* **Default:** *'auto'*. * `--ws-max-size ` - Set the WebSockets max message size, in bytes. Only available with the `websockets` protocol. **Default:** *16777216* (16 MB). * `--ws-max-queue ` - Set the maximum length of the WebSocket incoming message queue. Only available with the `websockets` protocol. **Default:** *32*. -* `--ws-ping-interval ` - Set the WebSockets ping interval, in seconds. Only available with the `websockets` protocol. **Default:** *20.0*. -* `--ws-ping-timeout ` - Set the WebSockets ping timeout, in seconds. Only available with the `websockets` protocol. **Default:** *20.0*. +* `--ws-ping-interval ` - Set the WebSockets ping interval, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*. +* `--ws-ping-timeout ` - Set the WebSockets ping timeout, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*. * `--ws-per-message-deflate ` - Enable/disable WebSocket per-message-deflate compression. Only available with the `websockets` protocol. **Default:** *True*. * `--lifespan ` - Set the Lifespan protocol implementation. **Options:** *'auto', 'on', 'off'.* **Default:** *'auto'*. * `--h11-max-incomplete-event-size ` - Set the maximum number of bytes to buffer of an incomplete event. Only available for `h11` HTTP protocol implementation. **Default:** *16384* (16 KB). diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index ed257a1e..1e873f64 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -27,6 +27,7 @@ from uvicorn._types import ( ) from uvicorn.config import Config from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol +from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol try: from uvicorn.protocols.websockets.wsproto_impl import WSProtocol as _WSProtocol @@ -1202,3 +1203,90 @@ async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HT assert is_open assert expected_states == actual_states + + +async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unused_tcp_port: int): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + while True: + message = await receive() + if message["type"] == "websocket.connect": + await send({"type": "websocket.accept"}) + elif message["type"] == "websocket.disconnect": + break + + config = Config( + app=app, + ws=WebSocketsSansIOProtocol, + http=http_protocol_cls, + lifespan="off", + ws_ping_interval=0.1, + ws_ping_timeout=5.0, + port=unused_tcp_port, + ) + async with run_server(config) as server: + # The websockets client auto-responds to ping frames, keeping the connection alive. + async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None): + protocol = list(server.server_state.connections)[0] + assert isinstance(protocol, WebSocketsSansIOProtocol) + + # Wait until at least one ping/pong roundtrip completes. + async def ping_roundtrip() -> None: + while protocol.last_ping_rtt == 0.0: + await asyncio.sleep(0.1) + + await asyncio.wait_for(ping_roundtrip(), timeout=5.0) + assert protocol.last_ping_rtt > 0 + + +async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, unused_tcp_port: int): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + while True: + message = await receive() + if message["type"] == "websocket.connect": + await send({"type": "websocket.accept"}) + elif message["type"] == "websocket.disconnect": + break + + config = Config( + app=app, + ws=WebSocketsSansIOProtocol, + http=http_protocol_cls, + lifespan="off", + ws_ping_interval=0.1, + ws_ping_timeout=0.1, + log_level="trace", + port=unused_tcp_port, + ) + async with run_server(config): + async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None) as websocket: + # Swallow outgoing pong frames so the server's ping never gets ack'd. + websocket.transport.write = lambda data: None # type: ignore[method-assign] + with pytest.raises(websockets.exceptions.ConnectionClosedError) as exc_info: + await asyncio.wait_for(websocket.recv(), timeout=1) + assert exc_info.value.rcvd is not None + assert exc_info.value.rcvd.code == 1011 + assert exc_info.value.rcvd.reason == "keepalive ping timeout" + + +async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused_tcp_port: int): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + while True: + message = await receive() + if message["type"] == "websocket.connect": + await send({"type": "websocket.accept"}) + elif message["type"] == "websocket.disconnect": + break + + config = Config( + app=app, + ws=WebSocketsSansIOProtocol, + http=http_protocol_cls, + lifespan="off", + ws_ping_interval=None, + port=unused_tcp_port, + ) + async with run_server(config) as server: + async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None): + protocol = list(server.server_state.connections)[0] + assert isinstance(protocol, WebSocketsSansIOProtocol) + assert protocol.ping_timer is None diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index a02baff3..bb0afe9d 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -2,7 +2,10 @@ from __future__ import annotations import asyncio import logging +import random +import struct import sys +from asyncio import TimerHandle from asyncio.transports import BaseTransport, Transport from http import HTTPStatus from typing import Any, Literal, cast @@ -92,6 +95,15 @@ class WebSocketsSansIOProtocol(asyncio.Protocol): self.writable = asyncio.Event() self.writable.set() + # Keepalive state + self.ping_interval = config.ws_ping_interval + self.ping_timeout = config.ws_ping_timeout + self.ping_timer: TimerHandle | None = None + self.pong_timer: TimerHandle | None = None + self.pending_ping_payload: bytes | None = None + self.ping_sent_at: float = 0.0 + self.last_ping_rtt: float = 0.0 + # Buffers self.bytes = b"" @@ -109,6 +121,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol): self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) def connection_lost(self, exc: Exception | None) -> None: + self.stop_keepalive() code = 1005 if self.handshake_complete else 1006 self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) self.connections.remove(self) @@ -125,6 +138,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol): pass def shutdown(self) -> None: + self.stop_keepalive() if self.handshake_complete: self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) self.conn.send_close(1012) @@ -155,7 +169,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol): elif event.opcode == Opcode.PING: self.handle_ping() elif event.opcode == Opcode.PONG: - pass # pragma: no cover + self.handle_pong(event) elif event.opcode == Opcode.CLOSE: self.handle_close(event) else: @@ -238,6 +252,67 @@ class WebSocketsSansIOProtocol(asyncio.Protocol): output = self.conn.data_to_send() self.transport.write(b"".join(output)) + def handle_pong(self, event: Frame) -> None: + # Ignore unsolicited pongs and stale pongs whose payload doesn't match the ping currently in flight + if self.pending_ping_payload is None or bytes(event.data) != self.pending_ping_payload: + return # pragma: no cover + + self.last_ping_rtt = self.loop.time() - self.ping_sent_at + self.pending_ping_payload = None + # The peer answered in time; cancel the pong deadline and chain the next ping. This `schedule_ping()` call is + # what keeps the keepalive loop running when ping_timeout is set. When ping_timeout is None the next ping is + # already scheduled by `send_keepalive_ping`, so we must not schedule a duplicate here. + if self.pong_timer is not None: + self.pong_timer.cancel() + self.pong_timer = None + self.schedule_ping() + + def start_keepalive(self) -> None: + if self.ping_interval is not None and self.ping_interval > 0: + self.schedule_ping() + + def stop_keepalive(self) -> None: + if self.ping_timer is not None: + self.ping_timer.cancel() + self.ping_timer = None + if self.pong_timer is not None: # pragma: no cover + self.pong_timer.cancel() + self.pong_timer = None + self.pending_ping_payload = None + + def schedule_ping(self) -> None: + assert self.ping_interval is not None + delay = max(0.0, self.ping_interval - self.last_ping_rtt) + self.ping_timer = self.loop.call_later(delay, self.send_keepalive_ping) + + def send_keepalive_ping(self) -> None: + self.ping_timer = None + if self.close_sent or self.transport.is_closing(): # pragma: no cover + return + # Random 4-byte payload identifies this ping; `handle_pong` uses it to ignore stale or unsolicited pongs. + # See https://github.com/python-websockets/websockets/blob/4d229bf9f583d593aa103287aee0a77c9fbc3a79/src/websockets/asyncio/connection.py#L624 + self.pending_ping_payload = struct.pack("!I", random.getrandbits(32)) + self.ping_sent_at = self.loop.time() + self.conn.send_ping(self.pending_ping_payload) + self.transport.write(b"".join(self.conn.data_to_send())) + if self.ping_timeout is not None: + self.pong_timer = self.loop.call_later(self.ping_timeout, self.keepalive_timeout) + else: # pragma: no cover + self.schedule_ping() + + def keepalive_timeout(self) -> None: + self.pong_timer = None + self.pending_ping_payload = None + if self.close_sent or self.transport.is_closing(): # pragma: no cover + return + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket keepalive ping timeout", prefix) + self.conn.fail(1011, "keepalive ping timeout") + self.transport.write(b"".join(self.conn.data_to_send())) + self.close_sent = True + self.transport.close() + def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): assert self.conn.close_rcvd is not None @@ -311,6 +386,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol): self.conn.send_response(self.response) output = self.conn.data_to_send() self.transport.write(b"".join(output)) + self.start_keepalive() elif message["type"] == "websocket.close": self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})