Implement websocket keepalive pings for websockets-sansio (#2888)
This commit is contained in:
parent
8d397c7319
commit
029be08867
@ -95,8 +95,8 @@ Using Uvicorn with watchfiles will enable the following options (which are other
|
||||
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. There are two versions of `websockets` supported: `websockets` and `websockets-sansio`. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'websockets-sansio', 'wsproto'.* **Default:** *'auto'*.
|
||||
* `--ws-max-size <int>` - Set the WebSockets max message size, in bytes. Only available with the `websockets` protocol. **Default:** *16777216* (16 MB).
|
||||
* `--ws-max-queue <int>` - Set the maximum length of the WebSocket incoming message queue. Only available with the `websockets` protocol. **Default:** *32*.
|
||||
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Only available with the `websockets` protocol. **Default:** *20.0*.
|
||||
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. Only available with the `websockets` protocol. **Default:** *20.0*.
|
||||
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.
|
||||
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.
|
||||
* `--ws-per-message-deflate <bool>` - Enable/disable WebSocket per-message-deflate compression. Only available with the `websockets` protocol. **Default:** *True*.
|
||||
* `--lifespan <str>` - Set the Lifespan protocol implementation. **Options:** *'auto', 'on', 'off'.* **Default:** *'auto'*.
|
||||
* `--h11-max-incomplete-event-size <int>` - Set the maximum number of bytes to buffer of an incomplete event. Only available for `h11` HTTP protocol implementation. **Default:** *16384* (16 KB).
|
||||
|
||||
@ -27,6 +27,7 @@ from uvicorn._types import (
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
|
||||
from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol
|
||||
|
||||
try:
|
||||
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol as _WSProtocol
|
||||
@ -1202,3 +1203,90 @@ async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HT
|
||||
assert is_open
|
||||
|
||||
assert expected_states == actual_states
|
||||
|
||||
|
||||
async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "websocket.connect":
|
||||
await send({"type": "websocket.accept"})
|
||||
elif message["type"] == "websocket.disconnect":
|
||||
break
|
||||
|
||||
config = Config(
|
||||
app=app,
|
||||
ws=WebSocketsSansIOProtocol,
|
||||
http=http_protocol_cls,
|
||||
lifespan="off",
|
||||
ws_ping_interval=0.1,
|
||||
ws_ping_timeout=5.0,
|
||||
port=unused_tcp_port,
|
||||
)
|
||||
async with run_server(config) as server:
|
||||
# The websockets client auto-responds to ping frames, keeping the connection alive.
|
||||
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
|
||||
protocol = list(server.server_state.connections)[0]
|
||||
assert isinstance(protocol, WebSocketsSansIOProtocol)
|
||||
|
||||
# Wait until at least one ping/pong roundtrip completes.
|
||||
async def ping_roundtrip() -> None:
|
||||
while protocol.last_ping_rtt == 0.0:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await asyncio.wait_for(ping_roundtrip(), timeout=5.0)
|
||||
assert protocol.last_ping_rtt > 0
|
||||
|
||||
|
||||
async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "websocket.connect":
|
||||
await send({"type": "websocket.accept"})
|
||||
elif message["type"] == "websocket.disconnect":
|
||||
break
|
||||
|
||||
config = Config(
|
||||
app=app,
|
||||
ws=WebSocketsSansIOProtocol,
|
||||
http=http_protocol_cls,
|
||||
lifespan="off",
|
||||
ws_ping_interval=0.1,
|
||||
ws_ping_timeout=0.1,
|
||||
log_level="trace",
|
||||
port=unused_tcp_port,
|
||||
)
|
||||
async with run_server(config):
|
||||
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None) as websocket:
|
||||
# Swallow outgoing pong frames so the server's ping never gets ack'd.
|
||||
websocket.transport.write = lambda data: None # type: ignore[method-assign]
|
||||
with pytest.raises(websockets.exceptions.ConnectionClosedError) as exc_info:
|
||||
await asyncio.wait_for(websocket.recv(), timeout=1)
|
||||
assert exc_info.value.rcvd is not None
|
||||
assert exc_info.value.rcvd.code == 1011
|
||||
assert exc_info.value.rcvd.reason == "keepalive ping timeout"
|
||||
|
||||
|
||||
async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "websocket.connect":
|
||||
await send({"type": "websocket.accept"})
|
||||
elif message["type"] == "websocket.disconnect":
|
||||
break
|
||||
|
||||
config = Config(
|
||||
app=app,
|
||||
ws=WebSocketsSansIOProtocol,
|
||||
http=http_protocol_cls,
|
||||
lifespan="off",
|
||||
ws_ping_interval=None,
|
||||
port=unused_tcp_port,
|
||||
)
|
||||
async with run_server(config) as server:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
|
||||
protocol = list(server.server_state.connections)[0]
|
||||
assert isinstance(protocol, WebSocketsSansIOProtocol)
|
||||
assert protocol.ping_timer is None
|
||||
|
||||
@ -2,7 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import struct
|
||||
import sys
|
||||
from asyncio import TimerHandle
|
||||
from asyncio.transports import BaseTransport, Transport
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Literal, cast
|
||||
@ -92,6 +95,15 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
self.writable = asyncio.Event()
|
||||
self.writable.set()
|
||||
|
||||
# Keepalive state
|
||||
self.ping_interval = config.ws_ping_interval
|
||||
self.ping_timeout = config.ws_ping_timeout
|
||||
self.ping_timer: TimerHandle | None = None
|
||||
self.pong_timer: TimerHandle | None = None
|
||||
self.pending_ping_payload: bytes | None = None
|
||||
self.ping_sent_at: float = 0.0
|
||||
self.last_ping_rtt: float = 0.0
|
||||
|
||||
# Buffers
|
||||
self.bytes = b""
|
||||
|
||||
@ -109,6 +121,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.stop_keepalive()
|
||||
code = 1005 if self.handshake_complete else 1006
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
|
||||
self.connections.remove(self)
|
||||
@ -125,6 +138,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
pass
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.stop_keepalive()
|
||||
if self.handshake_complete:
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
|
||||
self.conn.send_close(1012)
|
||||
@ -155,7 +169,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
elif event.opcode == Opcode.PING:
|
||||
self.handle_ping()
|
||||
elif event.opcode == Opcode.PONG:
|
||||
pass # pragma: no cover
|
||||
self.handle_pong(event)
|
||||
elif event.opcode == Opcode.CLOSE:
|
||||
self.handle_close(event)
|
||||
else:
|
||||
@ -238,6 +252,67 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
def handle_pong(self, event: Frame) -> None:
|
||||
# Ignore unsolicited pongs and stale pongs whose payload doesn't match the ping currently in flight
|
||||
if self.pending_ping_payload is None or bytes(event.data) != self.pending_ping_payload:
|
||||
return # pragma: no cover
|
||||
|
||||
self.last_ping_rtt = self.loop.time() - self.ping_sent_at
|
||||
self.pending_ping_payload = None
|
||||
# The peer answered in time; cancel the pong deadline and chain the next ping. This `schedule_ping()` call is
|
||||
# what keeps the keepalive loop running when ping_timeout is set. When ping_timeout is None the next ping is
|
||||
# already scheduled by `send_keepalive_ping`, so we must not schedule a duplicate here.
|
||||
if self.pong_timer is not None:
|
||||
self.pong_timer.cancel()
|
||||
self.pong_timer = None
|
||||
self.schedule_ping()
|
||||
|
||||
def start_keepalive(self) -> None:
|
||||
if self.ping_interval is not None and self.ping_interval > 0:
|
||||
self.schedule_ping()
|
||||
|
||||
def stop_keepalive(self) -> None:
|
||||
if self.ping_timer is not None:
|
||||
self.ping_timer.cancel()
|
||||
self.ping_timer = None
|
||||
if self.pong_timer is not None: # pragma: no cover
|
||||
self.pong_timer.cancel()
|
||||
self.pong_timer = None
|
||||
self.pending_ping_payload = None
|
||||
|
||||
def schedule_ping(self) -> None:
|
||||
assert self.ping_interval is not None
|
||||
delay = max(0.0, self.ping_interval - self.last_ping_rtt)
|
||||
self.ping_timer = self.loop.call_later(delay, self.send_keepalive_ping)
|
||||
|
||||
def send_keepalive_ping(self) -> None:
|
||||
self.ping_timer = None
|
||||
if self.close_sent or self.transport.is_closing(): # pragma: no cover
|
||||
return
|
||||
# Random 4-byte payload identifies this ping; `handle_pong` uses it to ignore stale or unsolicited pongs.
|
||||
# See https://github.com/python-websockets/websockets/blob/4d229bf9f583d593aa103287aee0a77c9fbc3a79/src/websockets/asyncio/connection.py#L624
|
||||
self.pending_ping_payload = struct.pack("!I", random.getrandbits(32))
|
||||
self.ping_sent_at = self.loop.time()
|
||||
self.conn.send_ping(self.pending_ping_payload)
|
||||
self.transport.write(b"".join(self.conn.data_to_send()))
|
||||
if self.ping_timeout is not None:
|
||||
self.pong_timer = self.loop.call_later(self.ping_timeout, self.keepalive_timeout)
|
||||
else: # pragma: no cover
|
||||
self.schedule_ping()
|
||||
|
||||
def keepalive_timeout(self) -> None:
|
||||
self.pong_timer = None
|
||||
self.pending_ping_payload = None
|
||||
if self.close_sent or self.transport.is_closing(): # pragma: no cover
|
||||
return
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket keepalive ping timeout", prefix)
|
||||
self.conn.fail(1011, "keepalive ping timeout")
|
||||
self.transport.write(b"".join(self.conn.data_to_send()))
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
|
||||
def handle_close(self, event: Frame) -> None:
|
||||
if not self.close_sent and not self.transport.is_closing():
|
||||
assert self.conn.close_rcvd is not None
|
||||
@ -311,6 +386,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
self.conn.send_response(self.response)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.start_keepalive()
|
||||
|
||||
elif message["type"] == "websocket.close":
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user