diff --git a/docs/settings.md b/docs/settings.md index d2607e13..f738fc51 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -96,8 +96,8 @@ Using Uvicorn with watchfiles will enable the following options (which are other * `--ws ` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. There are two versions of `websockets` supported: `websockets` and `websockets-sansio`. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'websockets-sansio', 'wsproto'.* **Default:** *'auto'*. * `--ws-max-size ` - Set the WebSockets max message size, in bytes. **Default:** *16777216* (16 MB). * `--ws-max-queue ` - Set the maximum length of the WebSocket incoming message queue. Only available with the `websockets` protocol. **Default:** *32*. -* `--ws-ping-interval ` - Set the WebSockets ping interval, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*. -* `--ws-ping-timeout ` - Set the WebSockets ping timeout, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*. +* `--ws-ping-interval ` - Set the WebSockets ping interval, in seconds. **Default:** *20.0*. +* `--ws-ping-timeout ` - Set the WebSockets ping timeout, in seconds. **Default:** *20.0*. * `--ws-per-message-deflate ` - Enable/disable WebSocket per-message-deflate compression. Only available with the `websockets` protocol. **Default:** *True*. * `--lifespan ` - Set the Lifespan protocol implementation. **Options:** *'auto', 'on', 'off'.* **Default:** *'auto'*. * `--h11-max-incomplete-event-size ` - Set the maximum number of bytes to buffer of an incomplete event. Only available for `h11` HTTP protocol implementation. **Default:** *16384* (16 KB). diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 8a0c20d1..8ab61ca7 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: HTTPProtocol: TypeAlias = "type[H11Protocol | HttpToolsProtocol]" WSProtocol: TypeAlias = "type[_WSProtocol | WebSocketProtocol]" + KeepaliveWSProtocol: TypeAlias = "type[_WSProtocol | WebSocketsSansIOProtocol]" pytestmark = pytest.mark.anyio @@ -1230,7 +1231,27 @@ async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HT assert expected_states == actual_states -async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unused_tcp_port: int): +@pytest.fixture( + params=[ + pytest.param( + "uvicorn.protocols.websockets.wsproto_impl:WSProtocol", + marks=skip_if_no_wsproto, + id="wsproto", + ), + pytest.param( + "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", id="websockets-sansio" + ), + ] +) +def keepalive_ws_protocol_cls(request: pytest.FixtureRequest): + from uvicorn.importer import import_from_string + + return import_from_string(request.param) + + +async def test_server_keepalive_ping_pong( + keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int +): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): while True: message = await receive() @@ -1241,7 +1262,7 @@ async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unuse config = Config( app=app, - ws=WebSocketsSansIOProtocol, + ws=keepalive_ws_protocol_cls, http=http_protocol_cls, lifespan="off", ws_ping_interval=0.1, @@ -1252,7 +1273,7 @@ async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unuse # The websockets client auto-responds to ping frames, keeping the connection alive. async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None): protocol = list(server.server_state.connections)[0] - assert isinstance(protocol, WebSocketsSansIOProtocol) + assert isinstance(protocol, (_WSProtocol, WebSocketsSansIOProtocol)) # Wait until the server sends at least one keepalive ping, then # sleep past the timeout window and ensure the connection stays open. @@ -1267,7 +1288,9 @@ async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unuse assert not protocol.transport.is_closing() -async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, unused_tcp_port: int): +async def test_server_keepalive_ping_timeout( + keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int +): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): while True: message = await receive() @@ -1278,7 +1301,7 @@ async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, un config = Config( app=app, - ws=WebSocketsSansIOProtocol, + ws=keepalive_ws_protocol_cls, http=http_protocol_cls, lifespan="off", ws_ping_interval=0.1, @@ -1297,7 +1320,9 @@ async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, un assert exc_info.value.rcvd.reason == "keepalive ping timeout" -async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused_tcp_port: int): +async def test_server_keepalive_disabled( + keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int +): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): while True: message = await receive() @@ -1308,7 +1333,7 @@ async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused config = Config( app=app, - ws=WebSocketsSansIOProtocol, + ws=keepalive_ws_protocol_cls, http=http_protocol_cls, lifespan="off", ws_ping_interval=None, @@ -1317,5 +1342,5 @@ async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused async with run_server(config) as server: async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None): protocol = list(server.server_state.connections)[0] - assert isinstance(protocol, WebSocketsSansIOProtocol) + assert isinstance(protocol, (_WSProtocol, WebSocketsSansIOProtocol)) assert protocol.ping_timer is None diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 48d894eb..adca63f9 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -2,6 +2,9 @@ from __future__ import annotations import asyncio import logging +import random +import struct +from asyncio import TimerHandle from io import BytesIO, StringIO from typing import Any, Literal, cast from urllib.parse import unquote @@ -99,6 +102,15 @@ class WSProtocol(asyncio.Protocol): self.writable = asyncio.Event() self.writable.set() + # Keepalive state + self.ping_interval = config.ws_ping_interval + self.ping_timeout = config.ws_ping_timeout + self.ping_timer: TimerHandle | None = None + self.pong_timer: TimerHandle | None = None + self.pending_ping_payload: bytes | None = None + self.ping_sent_at: float = 0.0 + self.last_ping_rtt: float = 0.0 + # Buffer self.buffer = WebsocketBuffer(self.config.ws_max_size) @@ -116,6 +128,7 @@ class WSProtocol(asyncio.Protocol): self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) def connection_lost(self, exc: Exception | None) -> None: + self.stop_keepalive() code = 1005 if self.handshake_complete else 1006 self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) self.connections.remove(self) @@ -153,6 +166,8 @@ class WSProtocol(asyncio.Protocol): self.handle_close(event) elif isinstance(event, events.Ping): self.handle_ping(event) + elif isinstance(event, events.Pong): + self.handle_pong(event) def pause_writing(self) -> None: """ @@ -167,6 +182,7 @@ class WSProtocol(asyncio.Protocol): self.writable.set() # pragma: full coverage def shutdown(self) -> None: + self.stop_keepalive() if self.handshake_complete: self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) output = self.conn.send(wsproto.events.CloseConnection(code=1012)) @@ -235,6 +251,65 @@ class WSProtocol(asyncio.Protocol): def handle_ping(self, event: events.Ping) -> None: self.transport.write(self.conn.send(event.response())) + def handle_pong(self, event: events.Pong) -> None: + # Ignore unsolicited pongs and stale pongs whose payload doesn't match the ping currently in flight. + if self.pending_ping_payload is None or bytes(event.payload) != self.pending_ping_payload: + return # pragma: no cover + + self.last_ping_rtt = self.loop.time() - self.ping_sent_at + self.pending_ping_payload = None + # The peer answered in time; cancel the pong deadline and chain the next ping. This `schedule_ping()` call is + # what keeps the keepalive loop running when ping_timeout is set. When ping_timeout is None the next ping is + # already scheduled by `send_keepalive_ping`, so we must not schedule a duplicate here. + if self.pong_timer is not None: + self.pong_timer.cancel() + self.pong_timer = None + self.schedule_ping() + + def start_keepalive(self) -> None: + if self.ping_interval is not None and self.ping_interval > 0: + self.schedule_ping() + + def stop_keepalive(self) -> None: + if self.ping_timer is not None: + self.ping_timer.cancel() + self.ping_timer = None + if self.pong_timer is not None: # pragma: no cover + self.pong_timer.cancel() + self.pong_timer = None + self.pending_ping_payload = None + + def schedule_ping(self) -> None: + assert self.ping_interval is not None + delay = max(0.0, self.ping_interval - self.last_ping_rtt) + self.ping_timer = self.loop.call_later(delay, self.send_keepalive_ping) + + def send_keepalive_ping(self) -> None: + self.ping_timer = None + if self.close_sent or self.transport.is_closing(): # pragma: no cover + return + # Random 4-byte payload identifies this ping; `handle_pong` uses it to ignore stale or unsolicited pongs. + self.pending_ping_payload = struct.pack("!I", random.getrandbits(32)) + self.ping_sent_at = self.loop.time() + self.transport.write(self.conn.send(wsproto.events.Ping(payload=self.pending_ping_payload))) + if self.ping_timeout is not None: + self.pong_timer = self.loop.call_later(self.ping_timeout, self.keepalive_timeout) + else: # pragma: no cover + self.schedule_ping() + + def keepalive_timeout(self) -> None: + self.pong_timer = None + self.pending_ping_payload = None + if self.close_sent or self.transport.is_closing(): # pragma: no cover + return + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket keepalive ping timeout", prefix) + reason = "keepalive ping timeout" + self.transport.write(self.conn.send(wsproto.events.CloseConnection(code=1011, reason=reason))) + self.close_sent = True + self.transport.close() + def send_500_response(self) -> None: if self.response_started or self.handshake_complete: return # we cannot send responses anymore @@ -288,6 +363,7 @@ class WSProtocol(asyncio.Protocol): ) ) self.transport.write(output) + self.start_keepalive() elif message["type"] == "websocket.close": self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})