Compare commits

...

2 Commits

Author SHA1 Message Date
Marcelo Trylesinski
0014792263 Avoid double-copy of single-frame binary payload 2026-04-22 22:21:31 +02:00
Marcelo Trylesinski
c1615ebb21 Use bytearray for incoming WebSocket message buffer in websockets-sansio 2026-04-22 22:08:27 +02:00
2 changed files with 52 additions and 6 deletions

View File

@ -777,6 +777,43 @@ async def test_fragmented_message_exceeding_max_size(
assert exc_info.value.rcvd.code == 1009
@pytest.mark.parametrize("opcode,key", [(Opcode.BINARY, "bytes"), (Opcode.TEXT, "text")])
async def test_fragmented_message_reassembly(
ws_protocol_cls: WSProtocol,
http_protocol_cls: HTTPProtocol,
unused_tcp_port: int,
opcode: Opcode,
key: str,
):
"""Server reassembles a fragmented message and delivers it to the app intact."""
received: list[str | bytes] = []
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "websocket"
connect = await receive()
assert connect["type"] == "websocket.connect"
await send({"type": "websocket.accept"})
message = await receive()
assert message["type"] == "websocket.receive"
payload = message.get(key)
assert isinstance(payload, (str, bytes))
received.append(payload)
await send({"type": "websocket.close"})
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}") as ws:
payload = b"A" * 512
await ws.write_frame(False, opcode, payload)
for _ in range(4):
await ws.write_frame(False, Opcode.CONT, payload)
await ws.write_frame(True, Opcode.CONT, payload)
expected: str | bytes = "A" * 512 * 6 if opcode is Opcode.TEXT else b"A" * 512 * 6
assert received == [expected]
async def test_server_reject_connection(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):

View File

@ -105,7 +105,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
self.last_ping_rtt: float = 0.0
# Buffers
self.bytes = b""
self.bytes: bytes | bytearray = b""
def connection_made(self, transport: BaseTransport) -> None:
"""Called when a connection is made."""
@ -216,22 +216,29 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
task.add_done_callback(self.on_task_complete)
self.tasks.add(task)
def handle_cont(self, event: Frame) -> None: # pragma: no cover
self.bytes += event.data
def handle_cont(self, event: Frame) -> None:
assert isinstance(self.bytes, bytearray)
self.bytes.extend(event.data)
if event.fin:
self.send_receive_event_to_app()
def handle_text(self, event: Frame) -> None:
self.bytes = event.data
self.curr_msg_data_type: Literal["text", "bytes"] = "text"
if event.fin:
# Single-frame message: pass through as bytes (zero copies).
self.bytes = event.data
self.send_receive_event_to_app()
else:
# Fragmented: promote to bytearray so continuations extend in place.
self.bytes = bytearray(event.data)
def handle_bytes(self, event: Frame) -> None:
self.bytes = event.data
self.curr_msg_data_type = "bytes"
if event.fin:
self.bytes = event.data
self.send_receive_event_to_app()
else:
self.bytes = bytearray(event.data)
def send_receive_event_to_app(self) -> None:
if self.curr_msg_data_type == "text":
@ -243,7 +250,9 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
self.handle_parser_exception()
return
else:
self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes})
# Freeze the buffer to bytes only when it was grown across fragments.
payload = self.bytes if isinstance(self.bytes, bytes) else bytes(self.bytes)
self.queue.put_nowait({"type": "websocket.receive", "bytes": payload})
if not self.read_paused:
self.read_paused = True
self.transport.pause_reading()