Support ws_ping_interval and ws_ping_timeout in wsproto implementation (#2916)
This commit is contained in:
parent
3e6b964466
commit
d438fb16fe
@ -96,8 +96,8 @@ Using Uvicorn with watchfiles will enable the following options (which are other
|
||||
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. There are two versions of `websockets` supported: `websockets` and `websockets-sansio`. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'websockets-sansio', 'wsproto'.* **Default:** *'auto'*.
|
||||
* `--ws-max-size <int>` - Set the WebSockets max message size, in bytes. **Default:** *16777216* (16 MB).
|
||||
* `--ws-max-queue <int>` - Set the maximum length of the WebSocket incoming message queue. Only available with the `websockets` protocol. **Default:** *32*.
|
||||
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.
|
||||
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.
|
||||
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. **Default:** *20.0*.
|
||||
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. **Default:** *20.0*.
|
||||
* `--ws-per-message-deflate <bool>` - Enable/disable WebSocket per-message-deflate compression. Only available with the `websockets` protocol. **Default:** *True*.
|
||||
* `--lifespan <str>` - Set the Lifespan protocol implementation. **Options:** *'auto', 'on', 'off'.* **Default:** *'auto'*.
|
||||
* `--h11-max-incomplete-event-size <int>` - Set the maximum number of bytes to buffer of an incomplete event. Only available for `h11` HTTP protocol implementation. **Default:** *16384* (16 KB).
|
||||
|
||||
@ -44,6 +44,7 @@ if TYPE_CHECKING:
|
||||
|
||||
HTTPProtocol: TypeAlias = "type[H11Protocol | HttpToolsProtocol]"
|
||||
WSProtocol: TypeAlias = "type[_WSProtocol | WebSocketProtocol]"
|
||||
KeepaliveWSProtocol: TypeAlias = "type[_WSProtocol | WebSocketsSansIOProtocol]"
|
||||
|
||||
pytestmark = pytest.mark.anyio
|
||||
|
||||
@ -1230,7 +1231,27 @@ async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HT
|
||||
assert expected_states == actual_states
|
||||
|
||||
|
||||
async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
pytest.param(
|
||||
"uvicorn.protocols.websockets.wsproto_impl:WSProtocol",
|
||||
marks=skip_if_no_wsproto,
|
||||
id="wsproto",
|
||||
),
|
||||
pytest.param(
|
||||
"uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", id="websockets-sansio"
|
||||
),
|
||||
]
|
||||
)
|
||||
def keepalive_ws_protocol_cls(request: pytest.FixtureRequest):
|
||||
from uvicorn.importer import import_from_string
|
||||
|
||||
return import_from_string(request.param)
|
||||
|
||||
|
||||
async def test_server_keepalive_ping_pong(
|
||||
keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
while True:
|
||||
message = await receive()
|
||||
@ -1241,7 +1262,7 @@ async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unuse
|
||||
|
||||
config = Config(
|
||||
app=app,
|
||||
ws=WebSocketsSansIOProtocol,
|
||||
ws=keepalive_ws_protocol_cls,
|
||||
http=http_protocol_cls,
|
||||
lifespan="off",
|
||||
ws_ping_interval=0.1,
|
||||
@ -1252,7 +1273,7 @@ async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unuse
|
||||
# The websockets client auto-responds to ping frames, keeping the connection alive.
|
||||
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
|
||||
protocol = list(server.server_state.connections)[0]
|
||||
assert isinstance(protocol, WebSocketsSansIOProtocol)
|
||||
assert isinstance(protocol, (_WSProtocol, WebSocketsSansIOProtocol))
|
||||
|
||||
# Wait until the server sends at least one keepalive ping, then
|
||||
# sleep past the timeout window and ensure the connection stays open.
|
||||
@ -1267,7 +1288,9 @@ async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unuse
|
||||
assert not protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
async def test_server_keepalive_ping_timeout(
|
||||
keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
while True:
|
||||
message = await receive()
|
||||
@ -1278,7 +1301,7 @@ async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, un
|
||||
|
||||
config = Config(
|
||||
app=app,
|
||||
ws=WebSocketsSansIOProtocol,
|
||||
ws=keepalive_ws_protocol_cls,
|
||||
http=http_protocol_cls,
|
||||
lifespan="off",
|
||||
ws_ping_interval=0.1,
|
||||
@ -1297,7 +1320,9 @@ async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, un
|
||||
assert exc_info.value.rcvd.reason == "keepalive ping timeout"
|
||||
|
||||
|
||||
async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
async def test_server_keepalive_disabled(
|
||||
keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
while True:
|
||||
message = await receive()
|
||||
@ -1308,7 +1333,7 @@ async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused
|
||||
|
||||
config = Config(
|
||||
app=app,
|
||||
ws=WebSocketsSansIOProtocol,
|
||||
ws=keepalive_ws_protocol_cls,
|
||||
http=http_protocol_cls,
|
||||
lifespan="off",
|
||||
ws_ping_interval=None,
|
||||
@ -1317,5 +1342,5 @@ async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused
|
||||
async with run_server(config) as server:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
|
||||
protocol = list(server.server_state.connections)[0]
|
||||
assert isinstance(protocol, WebSocketsSansIOProtocol)
|
||||
assert isinstance(protocol, (_WSProtocol, WebSocketsSansIOProtocol))
|
||||
assert protocol.ping_timer is None
|
||||
|
||||
@ -2,6 +2,9 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import struct
|
||||
from asyncio import TimerHandle
|
||||
from io import BytesIO, StringIO
|
||||
from typing import Any, Literal, cast
|
||||
from urllib.parse import unquote
|
||||
@ -99,6 +102,15 @@ class WSProtocol(asyncio.Protocol):
|
||||
self.writable = asyncio.Event()
|
||||
self.writable.set()
|
||||
|
||||
# Keepalive state
|
||||
self.ping_interval = config.ws_ping_interval
|
||||
self.ping_timeout = config.ws_ping_timeout
|
||||
self.ping_timer: TimerHandle | None = None
|
||||
self.pong_timer: TimerHandle | None = None
|
||||
self.pending_ping_payload: bytes | None = None
|
||||
self.ping_sent_at: float = 0.0
|
||||
self.last_ping_rtt: float = 0.0
|
||||
|
||||
# Buffer
|
||||
self.buffer = WebsocketBuffer(self.config.ws_max_size)
|
||||
|
||||
@ -116,6 +128,7 @@ class WSProtocol(asyncio.Protocol):
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.stop_keepalive()
|
||||
code = 1005 if self.handshake_complete else 1006
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
|
||||
self.connections.remove(self)
|
||||
@ -153,6 +166,8 @@ class WSProtocol(asyncio.Protocol):
|
||||
self.handle_close(event)
|
||||
elif isinstance(event, events.Ping):
|
||||
self.handle_ping(event)
|
||||
elif isinstance(event, events.Pong):
|
||||
self.handle_pong(event)
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
"""
|
||||
@ -167,6 +182,7 @@ class WSProtocol(asyncio.Protocol):
|
||||
self.writable.set() # pragma: full coverage
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.stop_keepalive()
|
||||
if self.handshake_complete:
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
|
||||
output = self.conn.send(wsproto.events.CloseConnection(code=1012))
|
||||
@ -235,6 +251,65 @@ class WSProtocol(asyncio.Protocol):
|
||||
def handle_ping(self, event: events.Ping) -> None:
|
||||
self.transport.write(self.conn.send(event.response()))
|
||||
|
||||
def handle_pong(self, event: events.Pong) -> None:
|
||||
# Ignore unsolicited pongs and stale pongs whose payload doesn't match the ping currently in flight.
|
||||
if self.pending_ping_payload is None or bytes(event.payload) != self.pending_ping_payload:
|
||||
return # pragma: no cover
|
||||
|
||||
self.last_ping_rtt = self.loop.time() - self.ping_sent_at
|
||||
self.pending_ping_payload = None
|
||||
# The peer answered in time; cancel the pong deadline and chain the next ping. This `schedule_ping()` call is
|
||||
# what keeps the keepalive loop running when ping_timeout is set. When ping_timeout is None the next ping is
|
||||
# already scheduled by `send_keepalive_ping`, so we must not schedule a duplicate here.
|
||||
if self.pong_timer is not None:
|
||||
self.pong_timer.cancel()
|
||||
self.pong_timer = None
|
||||
self.schedule_ping()
|
||||
|
||||
def start_keepalive(self) -> None:
|
||||
if self.ping_interval is not None and self.ping_interval > 0:
|
||||
self.schedule_ping()
|
||||
|
||||
def stop_keepalive(self) -> None:
|
||||
if self.ping_timer is not None:
|
||||
self.ping_timer.cancel()
|
||||
self.ping_timer = None
|
||||
if self.pong_timer is not None: # pragma: no cover
|
||||
self.pong_timer.cancel()
|
||||
self.pong_timer = None
|
||||
self.pending_ping_payload = None
|
||||
|
||||
def schedule_ping(self) -> None:
|
||||
assert self.ping_interval is not None
|
||||
delay = max(0.0, self.ping_interval - self.last_ping_rtt)
|
||||
self.ping_timer = self.loop.call_later(delay, self.send_keepalive_ping)
|
||||
|
||||
def send_keepalive_ping(self) -> None:
|
||||
self.ping_timer = None
|
||||
if self.close_sent or self.transport.is_closing(): # pragma: no cover
|
||||
return
|
||||
# Random 4-byte payload identifies this ping; `handle_pong` uses it to ignore stale or unsolicited pongs.
|
||||
self.pending_ping_payload = struct.pack("!I", random.getrandbits(32))
|
||||
self.ping_sent_at = self.loop.time()
|
||||
self.transport.write(self.conn.send(wsproto.events.Ping(payload=self.pending_ping_payload)))
|
||||
if self.ping_timeout is not None:
|
||||
self.pong_timer = self.loop.call_later(self.ping_timeout, self.keepalive_timeout)
|
||||
else: # pragma: no cover
|
||||
self.schedule_ping()
|
||||
|
||||
def keepalive_timeout(self) -> None:
|
||||
self.pong_timer = None
|
||||
self.pending_ping_payload = None
|
||||
if self.close_sent or self.transport.is_closing(): # pragma: no cover
|
||||
return
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket keepalive ping timeout", prefix)
|
||||
reason = "keepalive ping timeout"
|
||||
self.transport.write(self.conn.send(wsproto.events.CloseConnection(code=1011, reason=reason)))
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
|
||||
def send_500_response(self) -> None:
|
||||
if self.response_started or self.handshake_complete:
|
||||
return # we cannot send responses anymore
|
||||
@ -288,6 +363,7 @@ class WSProtocol(asyncio.Protocol):
|
||||
)
|
||||
)
|
||||
self.transport.write(output)
|
||||
self.start_keepalive()
|
||||
|
||||
elif message["type"] == "websocket.close":
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user