Use bytearray for incoming WebSocket message buffer in websockets-sansio (#2917)

This commit is contained in:
Marcelo Trylesinski 2026-04-22 22:11:28 +02:00 committed by GitHub
parent d438fb16fe
commit 7375b5bf66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 6 deletions

View File

@ -777,6 +777,37 @@ async def test_fragmented_message_exceeding_max_size(
assert exc_info.value.rcvd.code == 1009
async def test_fragmented_message_reassembly(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
"""Server reassembles a fragmented message and delivers it to the app intact."""
received: list[bytes] = []
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "websocket"
connect = await receive()
assert connect["type"] == "websocket.connect"
await send({"type": "websocket.accept"})
message = await receive()
assert message["type"] == "websocket.receive"
payload = message.get("bytes")
assert payload is not None
received.append(payload)
await send({"type": "websocket.close"})
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}") as ws:
payload = b"A" * 512
await ws.write_frame(False, Opcode.BINARY, payload)
for _ in range(4):
await ws.write_frame(False, Opcode.CONT, payload)
await ws.write_frame(True, Opcode.CONT, payload)
assert received == [b"A" * 512 * 6]
async def test_server_reject_connection(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):

View File

@ -105,7 +105,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
self.last_ping_rtt: float = 0.0
# Buffers
self.bytes = b""
self.bytes = bytearray()
def connection_made(self, transport: BaseTransport) -> None:
"""Called when a connection is made."""
@ -216,19 +216,19 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
task.add_done_callback(self.on_task_complete)
self.tasks.add(task)
def handle_cont(self, event: Frame) -> None: # pragma: no cover
self.bytes += event.data
def handle_cont(self, event: Frame) -> None:
self.bytes.extend(event.data)
if event.fin:
self.send_receive_event_to_app()
def handle_text(self, event: Frame) -> None:
self.bytes = event.data
self.bytes = bytearray(event.data)
self.curr_msg_data_type: Literal["text", "bytes"] = "text"
if event.fin:
self.send_receive_event_to_app()
def handle_bytes(self, event: Frame) -> None:
self.bytes = event.data
self.bytes = bytearray(event.data)
self.curr_msg_data_type = "bytes"
if event.fin:
self.send_receive_event_to_app()
@ -243,7 +243,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
self.handle_parser_exception()
return
else:
self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes})
self.queue.put_nowait({"type": "websocket.receive", "bytes": bytes(self.bytes)})
if not self.read_paused:
self.read_paused = True
self.transport.pause_reading()