Implement websocket keepalive pings for websockets-sansio (#2888)

This commit is contained in:
Marcelo Trylesinski 2026-04-06 09:52:58 +02:00 committed by GitHub
parent 8d397c7319
commit 029be08867
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 167 additions and 3 deletions

View File

@ -95,8 +95,8 @@ Using Uvicorn with watchfiles will enable the following options (which are other
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. There are two versions of `websockets` supported: `websockets` and `websockets-sansio`. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'websockets-sansio', 'wsproto'.* **Default:** *'auto'*.
* `--ws-max-size <int>` - Set the WebSockets max message size, in bytes. Only available with the `websockets` protocol. **Default:** *16777216* (16 MB).
* `--ws-max-queue <int>` - Set the maximum length of the WebSocket incoming message queue. Only available with the `websockets` protocol. **Default:** *32*.
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Only available with the `websockets` protocol. **Default:** *20.0*.
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. Only available with the `websockets` protocol. **Default:** *20.0*.
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.
* `--ws-per-message-deflate <bool>` - Enable/disable WebSocket per-message-deflate compression. Only available with the `websockets` protocol. **Default:** *True*.
* `--lifespan <str>` - Set the Lifespan protocol implementation. **Options:** *'auto', 'on', 'off'.* **Default:** *'auto'*.
* `--h11-max-incomplete-event-size <int>` - Set the maximum number of bytes to buffer of an incomplete event. Only available for `h11` HTTP protocol implementation. **Default:** *16384* (16 KB).

View File

@ -27,6 +27,7 @@ from uvicorn._types import (
)
from uvicorn.config import Config
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol
try:
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol as _WSProtocol
@ -1202,3 +1203,90 @@ async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HT
assert is_open
assert expected_states == actual_states
async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break
config = Config(
app=app,
ws=WebSocketsSansIOProtocol,
http=http_protocol_cls,
lifespan="off",
ws_ping_interval=0.1,
ws_ping_timeout=5.0,
port=unused_tcp_port,
)
async with run_server(config) as server:
# The websockets client auto-responds to ping frames, keeping the connection alive.
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
protocol = list(server.server_state.connections)[0]
assert isinstance(protocol, WebSocketsSansIOProtocol)
# Wait until at least one ping/pong roundtrip completes.
async def ping_roundtrip() -> None:
while protocol.last_ping_rtt == 0.0:
await asyncio.sleep(0.1)
await asyncio.wait_for(ping_roundtrip(), timeout=5.0)
assert protocol.last_ping_rtt > 0
async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break
config = Config(
app=app,
ws=WebSocketsSansIOProtocol,
http=http_protocol_cls,
lifespan="off",
ws_ping_interval=0.1,
ws_ping_timeout=0.1,
log_level="trace",
port=unused_tcp_port,
)
async with run_server(config):
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None) as websocket:
# Swallow outgoing pong frames so the server's ping never gets ack'd.
websocket.transport.write = lambda data: None # type: ignore[method-assign]
with pytest.raises(websockets.exceptions.ConnectionClosedError) as exc_info:
await asyncio.wait_for(websocket.recv(), timeout=1)
assert exc_info.value.rcvd is not None
assert exc_info.value.rcvd.code == 1011
assert exc_info.value.rcvd.reason == "keepalive ping timeout"
async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break
config = Config(
app=app,
ws=WebSocketsSansIOProtocol,
http=http_protocol_cls,
lifespan="off",
ws_ping_interval=None,
port=unused_tcp_port,
)
async with run_server(config) as server:
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
protocol = list(server.server_state.connections)[0]
assert isinstance(protocol, WebSocketsSansIOProtocol)
assert protocol.ping_timer is None

View File

@ -2,7 +2,10 @@ from __future__ import annotations
import asyncio
import logging
import random
import struct
import sys
from asyncio import TimerHandle
from asyncio.transports import BaseTransport, Transport
from http import HTTPStatus
from typing import Any, Literal, cast
@ -92,6 +95,15 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
self.writable = asyncio.Event()
self.writable.set()
# Keepalive state
self.ping_interval = config.ws_ping_interval
self.ping_timeout = config.ws_ping_timeout
self.ping_timer: TimerHandle | None = None
self.pong_timer: TimerHandle | None = None
self.pending_ping_payload: bytes | None = None
self.ping_sent_at: float = 0.0
self.last_ping_rtt: float = 0.0
# Buffers
self.bytes = b""
@ -109,6 +121,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
def connection_lost(self, exc: Exception | None) -> None:
self.stop_keepalive()
code = 1005 if self.handshake_complete else 1006
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
self.connections.remove(self)
@ -125,6 +138,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
pass
def shutdown(self) -> None:
self.stop_keepalive()
if self.handshake_complete:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
self.conn.send_close(1012)
@ -155,7 +169,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
elif event.opcode == Opcode.PING:
self.handle_ping()
elif event.opcode == Opcode.PONG:
pass # pragma: no cover
self.handle_pong(event)
elif event.opcode == Opcode.CLOSE:
self.handle_close(event)
else:
@ -238,6 +252,67 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
def handle_pong(self, event: Frame) -> None:
# Ignore unsolicited pongs and stale pongs whose payload doesn't match the ping currently in flight
if self.pending_ping_payload is None or bytes(event.data) != self.pending_ping_payload:
return # pragma: no cover
self.last_ping_rtt = self.loop.time() - self.ping_sent_at
self.pending_ping_payload = None
# The peer answered in time; cancel the pong deadline and chain the next ping. This `schedule_ping()` call is
# what keeps the keepalive loop running when ping_timeout is set. When ping_timeout is None the next ping is
# already scheduled by `send_keepalive_ping`, so we must not schedule a duplicate here.
if self.pong_timer is not None:
self.pong_timer.cancel()
self.pong_timer = None
self.schedule_ping()
def start_keepalive(self) -> None:
if self.ping_interval is not None and self.ping_interval > 0:
self.schedule_ping()
def stop_keepalive(self) -> None:
if self.ping_timer is not None:
self.ping_timer.cancel()
self.ping_timer = None
if self.pong_timer is not None: # pragma: no cover
self.pong_timer.cancel()
self.pong_timer = None
self.pending_ping_payload = None
def schedule_ping(self) -> None:
assert self.ping_interval is not None
delay = max(0.0, self.ping_interval - self.last_ping_rtt)
self.ping_timer = self.loop.call_later(delay, self.send_keepalive_ping)
def send_keepalive_ping(self) -> None:
self.ping_timer = None
if self.close_sent or self.transport.is_closing(): # pragma: no cover
return
# Random 4-byte payload identifies this ping; `handle_pong` uses it to ignore stale or unsolicited pongs.
# See https://github.com/python-websockets/websockets/blob/4d229bf9f583d593aa103287aee0a77c9fbc3a79/src/websockets/asyncio/connection.py#L624
self.pending_ping_payload = struct.pack("!I", random.getrandbits(32))
self.ping_sent_at = self.loop.time()
self.conn.send_ping(self.pending_ping_payload)
self.transport.write(b"".join(self.conn.data_to_send()))
if self.ping_timeout is not None:
self.pong_timer = self.loop.call_later(self.ping_timeout, self.keepalive_timeout)
else: # pragma: no cover
self.schedule_ping()
def keepalive_timeout(self) -> None:
self.pong_timer = None
self.pending_ping_payload = None
if self.close_sent or self.transport.is_closing(): # pragma: no cover
return
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket keepalive ping timeout", prefix)
self.conn.fail(1011, "keepalive ping timeout")
self.transport.write(b"".join(self.conn.data_to_send()))
self.close_sent = True
self.transport.close()
def handle_close(self, event: Frame) -> None:
if not self.close_sent and not self.transport.is_closing():
assert self.conn.close_rcvd is not None
@ -311,6 +386,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
self.conn.send_response(self.response)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
self.start_keepalive()
elif message["type"] == "websocket.close":
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})