Support ws_max_size in wsproto implementation (#2915)

This commit is contained in:
Marcelo Trylesinski 2026-04-22 19:17:00 +02:00 committed by GitHub
parent 2c423bd82b
commit 3e6b964466
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 78 additions and 31 deletions

View File

@ -94,7 +94,7 @@ Using Uvicorn with watchfiles will enable the following options (which are other
* `--loop <str>` - Set the event loop implementation. The uvloop implementation provides greater performance, but is not compatible with Windows or PyPy. **Options:** *'auto', 'asyncio', 'uvloop'.* **Default:** *'auto'*.
* `--http <str>` - Set the HTTP protocol implementation. The httptools implementation provides greater performance, but it not compatible with PyPy. **Options:** *'auto', 'h11', 'httptools'.* **Default:** *'auto'*.
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. There are two versions of `websockets` supported: `websockets` and `websockets-sansio`. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'websockets-sansio', 'wsproto'.* **Default:** *'auto'*.
* `--ws-max-size <int>` - Set the WebSockets max message size, in bytes. Only available with the `websockets` protocol. **Default:** *16777216* (16 MB).
* `--ws-max-size <int>` - Set the WebSockets max message size, in bytes. **Default:** *16777216* (16 MB).
* `--ws-max-queue <int>` - Set the maximum length of the WebSocket incoming message queue. Only available with the `websockets` protocol. **Default:** *32*.
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.

View File

@ -10,6 +10,7 @@ import websockets
import websockets.client
import websockets.exceptions
from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory
from websockets.frames import Opcode
from websockets.typing import Subprotocol
from tests.response import Response
@ -751,6 +752,30 @@ async def test_send_binary_data_to_server_bigger_than_default_on_websockets(
assert ws.close_code == expected_result
async def test_fragmented_message_exceeding_max_size(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
"""Stream non-FIN fragments past `ws_max_size` - the server must close with 1009."""
class App(WebSocketResponse):
async def websocket_connect(self, message: WebSocketConnectEvent):
await self.send({"type": "websocket.accept"})
config = Config(
app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", ws_max_size=2048, port=unused_tcp_port
)
async with run_server(config):
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}") as ws:
payload = b"A" * 1024
with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info:
await ws.write_frame(False, Opcode.BINARY, payload)
for _ in range(63): # 64 KiB total, well past 2 KiB budget
await ws.write_frame(False, Opcode.CONT, payload)
await ws.recv()
assert exc_info.value.rcvd is not None
assert exc_info.value.rcvd.code == 1009
async def test_server_reject_connection(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import logging
from io import BytesIO, StringIO
from typing import Any, Literal, cast
from urllib.parse import unquote
@ -11,12 +12,7 @@ from wsproto.connection import ConnectionState
from wsproto.extensions import Extension, PerMessageDeflate
from wsproto.utilities import LocalProtocolError, RemoteProtocolError
from uvicorn._types import (
ASGI3Application,
ASGISendEvent,
WebSocketEvent,
WebSocketScope,
)
from uvicorn._types import ASGI3Application, ASGISendEvent, WebSocketEvent, WebSocketReceiveEvent, WebSocketScope
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import (
@ -30,6 +26,36 @@ from uvicorn.protocols.utils import (
from uvicorn.server import ServerState
class FrameTooLargeError(Exception):
"""Raised when accumulated websocket message bytes exceed `ws_max_size`."""
class WebsocketBuffer:
def __init__(self, max_length: int) -> None:
self.value: BytesIO | StringIO | None = None
self.length = 0
self.max_length = max_length
def extend(self, event: events.TextMessage | events.BytesMessage) -> None:
if self.value is None:
self.value = StringIO() if isinstance(event, events.TextMessage) else BytesIO()
self.value.write(event.data) # type: ignore[arg-type]
# `ws_max_size` is a byte budget, so count UTF-8 bytes for text.
self.length += len(event.data.encode()) if isinstance(event, events.TextMessage) else len(event.data)
if self.length > self.max_length:
raise FrameTooLargeError
def clear(self) -> None:
self.value = None
self.length = 0
def to_message(self) -> WebSocketReceiveEvent:
if isinstance(self.value, StringIO):
return {"type": "websocket.receive", "text": self.value.getvalue()}
assert isinstance(self.value, BytesIO)
return {"type": "websocket.receive", "bytes": self.value.getvalue()}
class WSProtocol(asyncio.Protocol):
def __init__(
self,
@ -73,15 +99,12 @@ class WSProtocol(asyncio.Protocol):
self.writable = asyncio.Event()
self.writable.set()
# Buffers
self.bytes = b""
self.text = ""
# Buffer
self.buffer = WebsocketBuffer(self.config.ws_max_size)
# Protocol interface
def connection_made( # type: ignore[override]
self, transport: asyncio.Transport
) -> None:
def connection_made(self, transport: asyncio.Transport) -> None: # type: ignore[override]
self.connections.add(self)
self.transport = transport
self.server = get_local_addr(transport)
@ -120,12 +143,12 @@ class WSProtocol(asyncio.Protocol):
def handle_events(self) -> None:
for event in self.conn.events():
if self.close_sent:
return
if isinstance(event, events.Request):
self.handle_connect(event)
elif isinstance(event, events.TextMessage):
self.handle_text(event)
elif isinstance(event, events.BytesMessage):
self.handle_bytes(event)
elif isinstance(event, (events.TextMessage, events.BytesMessage)):
self.handle_message(event)
elif isinstance(event, events.CloseConnection):
self.handle_close(event)
elif isinstance(event, events.Ping):
@ -185,21 +208,20 @@ class WSProtocol(asyncio.Protocol):
task.add_done_callback(self.on_task_complete)
self.tasks.add(task)
def handle_text(self, event: events.TextMessage) -> None:
self.text += event.data
def handle_message(self, event: events.TextMessage | events.BytesMessage) -> None:
try:
self.buffer.extend(event)
except FrameTooLargeError:
self.close_sent = True
reason = f"Message exceeds the maximum size ({self.config.ws_max_size} bytes)"
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1009, "reason": reason})
if not self.transport.is_closing():
self.transport.write(self.conn.send(wsproto.events.CloseConnection(code=1009, reason=reason)))
self.transport.close()
return
if event.message_finished:
self.queue.put_nowait({"type": "websocket.receive", "text": self.text})
self.text = ""
if not self.read_paused:
self.read_paused = True
self.transport.pause_reading()
def handle_bytes(self, event: events.BytesMessage) -> None:
self.bytes += event.data
# todo: we may want to guard the size of self.bytes and self.text
if event.message_finished:
self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes})
self.bytes = b""
self.queue.put_nowait(self.buffer.to_message())
self.buffer.clear()
if not self.read_paused:
self.read_paused = True
self.transport.pause_reading()