Support ws_ping_interval and ws_ping_timeout in wsproto implementation (#2916)

This commit is contained in:
Marcelo Trylesinski 2026-04-22 20:33:16 +02:00 committed by GitHub
parent 3e6b964466
commit d438fb16fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 111 additions and 10 deletions

View File

@ -96,8 +96,8 @@ Using Uvicorn with watchfiles will enable the following options (which are other
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. There are two versions of `websockets` supported: `websockets` and `websockets-sansio`. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'websockets-sansio', 'wsproto'.* **Default:** *'auto'*.
* `--ws-max-size <int>` - Set the WebSockets max message size, in bytes. **Default:** *16777216* (16 MB).
* `--ws-max-queue <int>` - Set the maximum length of the WebSocket incoming message queue. Only available with the `websockets` protocol. **Default:** *32*.
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. **Default:** *20.0*.
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. **Default:** *20.0*.
* `--ws-per-message-deflate <bool>` - Enable/disable WebSocket per-message-deflate compression. Only available with the `websockets` protocol. **Default:** *True*.
* `--lifespan <str>` - Set the Lifespan protocol implementation. **Options:** *'auto', 'on', 'off'.* **Default:** *'auto'*.
* `--h11-max-incomplete-event-size <int>` - Set the maximum number of bytes to buffer of an incomplete event. Only available for `h11` HTTP protocol implementation. **Default:** *16384* (16 KB).

View File

@ -44,6 +44,7 @@ if TYPE_CHECKING:
HTTPProtocol: TypeAlias = "type[H11Protocol | HttpToolsProtocol]"
WSProtocol: TypeAlias = "type[_WSProtocol | WebSocketProtocol]"
KeepaliveWSProtocol: TypeAlias = "type[_WSProtocol | WebSocketsSansIOProtocol]"
pytestmark = pytest.mark.anyio
@ -1230,7 +1231,27 @@ async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HT
assert expected_states == actual_states
async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
@pytest.fixture(
params=[
pytest.param(
"uvicorn.protocols.websockets.wsproto_impl:WSProtocol",
marks=skip_if_no_wsproto,
id="wsproto",
),
pytest.param(
"uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", id="websockets-sansio"
),
]
)
def keepalive_ws_protocol_cls(request: pytest.FixtureRequest):
from uvicorn.importer import import_from_string
return import_from_string(request.param)
async def test_server_keepalive_ping_pong(
keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
while True:
message = await receive()
@ -1241,7 +1262,7 @@ async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unuse
config = Config(
app=app,
ws=WebSocketsSansIOProtocol,
ws=keepalive_ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
ws_ping_interval=0.1,
@ -1252,7 +1273,7 @@ async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unuse
# The websockets client auto-responds to ping frames, keeping the connection alive.
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
protocol = list(server.server_state.connections)[0]
assert isinstance(protocol, WebSocketsSansIOProtocol)
assert isinstance(protocol, (_WSProtocol, WebSocketsSansIOProtocol))
# Wait until the server sends at least one keepalive ping, then
# sleep past the timeout window and ensure the connection stays open.
@ -1267,7 +1288,9 @@ async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unuse
assert not protocol.transport.is_closing()
async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
async def test_server_keepalive_ping_timeout(
keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
while True:
message = await receive()
@ -1278,7 +1301,7 @@ async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, un
config = Config(
app=app,
ws=WebSocketsSansIOProtocol,
ws=keepalive_ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
ws_ping_interval=0.1,
@ -1297,7 +1320,9 @@ async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, un
assert exc_info.value.rcvd.reason == "keepalive ping timeout"
async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
async def test_server_keepalive_disabled(
keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
while True:
message = await receive()
@ -1308,7 +1333,7 @@ async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused
config = Config(
app=app,
ws=WebSocketsSansIOProtocol,
ws=keepalive_ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
ws_ping_interval=None,
@ -1317,5 +1342,5 @@ async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused
async with run_server(config) as server:
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
protocol = list(server.server_state.connections)[0]
assert isinstance(protocol, WebSocketsSansIOProtocol)
assert isinstance(protocol, (_WSProtocol, WebSocketsSansIOProtocol))
assert protocol.ping_timer is None

View File

@ -2,6 +2,9 @@ from __future__ import annotations
import asyncio
import logging
import random
import struct
from asyncio import TimerHandle
from io import BytesIO, StringIO
from typing import Any, Literal, cast
from urllib.parse import unquote
@ -99,6 +102,15 @@ class WSProtocol(asyncio.Protocol):
self.writable = asyncio.Event()
self.writable.set()
# Keepalive state
self.ping_interval = config.ws_ping_interval
self.ping_timeout = config.ws_ping_timeout
self.ping_timer: TimerHandle | None = None
self.pong_timer: TimerHandle | None = None
self.pending_ping_payload: bytes | None = None
self.ping_sent_at: float = 0.0
self.last_ping_rtt: float = 0.0
# Buffer
self.buffer = WebsocketBuffer(self.config.ws_max_size)
@ -116,6 +128,7 @@ class WSProtocol(asyncio.Protocol):
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
def connection_lost(self, exc: Exception | None) -> None:
self.stop_keepalive()
code = 1005 if self.handshake_complete else 1006
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
self.connections.remove(self)
@ -153,6 +166,8 @@ class WSProtocol(asyncio.Protocol):
self.handle_close(event)
elif isinstance(event, events.Ping):
self.handle_ping(event)
elif isinstance(event, events.Pong):
self.handle_pong(event)
def pause_writing(self) -> None:
"""
@ -167,6 +182,7 @@ class WSProtocol(asyncio.Protocol):
self.writable.set() # pragma: full coverage
def shutdown(self) -> None:
self.stop_keepalive()
if self.handshake_complete:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
output = self.conn.send(wsproto.events.CloseConnection(code=1012))
@ -235,6 +251,65 @@ class WSProtocol(asyncio.Protocol):
def handle_ping(self, event: events.Ping) -> None:
self.transport.write(self.conn.send(event.response()))
def handle_pong(self, event: events.Pong) -> None:
# Ignore unsolicited pongs and stale pongs whose payload doesn't match the ping currently in flight.
if self.pending_ping_payload is None or bytes(event.payload) != self.pending_ping_payload:
return # pragma: no cover
self.last_ping_rtt = self.loop.time() - self.ping_sent_at
self.pending_ping_payload = None
# The peer answered in time; cancel the pong deadline and chain the next ping. This `schedule_ping()` call is
# what keeps the keepalive loop running when ping_timeout is set. When ping_timeout is None the next ping is
# already scheduled by `send_keepalive_ping`, so we must not schedule a duplicate here.
if self.pong_timer is not None:
self.pong_timer.cancel()
self.pong_timer = None
self.schedule_ping()
def start_keepalive(self) -> None:
if self.ping_interval is not None and self.ping_interval > 0:
self.schedule_ping()
def stop_keepalive(self) -> None:
if self.ping_timer is not None:
self.ping_timer.cancel()
self.ping_timer = None
if self.pong_timer is not None: # pragma: no cover
self.pong_timer.cancel()
self.pong_timer = None
self.pending_ping_payload = None
def schedule_ping(self) -> None:
assert self.ping_interval is not None
delay = max(0.0, self.ping_interval - self.last_ping_rtt)
self.ping_timer = self.loop.call_later(delay, self.send_keepalive_ping)
def send_keepalive_ping(self) -> None:
self.ping_timer = None
if self.close_sent or self.transport.is_closing(): # pragma: no cover
return
# Random 4-byte payload identifies this ping; `handle_pong` uses it to ignore stale or unsolicited pongs.
self.pending_ping_payload = struct.pack("!I", random.getrandbits(32))
self.ping_sent_at = self.loop.time()
self.transport.write(self.conn.send(wsproto.events.Ping(payload=self.pending_ping_payload)))
if self.ping_timeout is not None:
self.pong_timer = self.loop.call_later(self.ping_timeout, self.keepalive_timeout)
else: # pragma: no cover
self.schedule_ping()
def keepalive_timeout(self) -> None:
self.pong_timer = None
self.pending_ping_payload = None
if self.close_sent or self.transport.is_closing(): # pragma: no cover
return
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket keepalive ping timeout", prefix)
reason = "keepalive ping timeout"
self.transport.write(self.conn.send(wsproto.events.CloseConnection(code=1011, reason=reason)))
self.close_sent = True
self.transport.close()
def send_500_response(self) -> None:
if self.response_started or self.handshake_complete:
return # we cannot send responses anymore
@ -288,6 +363,7 @@ class WSProtocol(asyncio.Protocol):
)
)
self.transport.write(output)
self.start_keepalive()
elif message["type"] == "websocket.close":
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})