Avoid double-copy of single-frame binary payload
This commit is contained in:
parent
c1615ebb21
commit
0014792263
@ -777,12 +777,17 @@ async def test_fragmented_message_exceeding_max_size(
|
||||
assert exc_info.value.rcvd.code == 1009
|
||||
|
||||
|
||||
@pytest.mark.parametrize("opcode,key", [(Opcode.BINARY, "bytes"), (Opcode.TEXT, "text")])
|
||||
async def test_fragmented_message_reassembly(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
ws_protocol_cls: WSProtocol,
|
||||
http_protocol_cls: HTTPProtocol,
|
||||
unused_tcp_port: int,
|
||||
opcode: Opcode,
|
||||
key: str,
|
||||
):
|
||||
"""Server reassembles a fragmented message and delivers it to the app intact."""
|
||||
|
||||
received: list[bytes] = []
|
||||
received: list[str | bytes] = []
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
assert scope["type"] == "websocket"
|
||||
@ -791,8 +796,8 @@ async def test_fragmented_message_reassembly(
|
||||
await send({"type": "websocket.accept"})
|
||||
message = await receive()
|
||||
assert message["type"] == "websocket.receive"
|
||||
payload = message.get("bytes")
|
||||
assert payload is not None
|
||||
payload = message.get(key)
|
||||
assert isinstance(payload, (str, bytes))
|
||||
received.append(payload)
|
||||
await send({"type": "websocket.close"})
|
||||
|
||||
@ -800,12 +805,13 @@ async def test_fragmented_message_reassembly(
|
||||
async with run_server(config):
|
||||
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}") as ws:
|
||||
payload = b"A" * 512
|
||||
await ws.write_frame(False, Opcode.BINARY, payload)
|
||||
await ws.write_frame(False, opcode, payload)
|
||||
for _ in range(4):
|
||||
await ws.write_frame(False, Opcode.CONT, payload)
|
||||
await ws.write_frame(True, Opcode.CONT, payload)
|
||||
|
||||
assert received == [b"A" * 512 * 6]
|
||||
expected: str | bytes = "A" * 512 * 6 if opcode is Opcode.TEXT else b"A" * 512 * 6
|
||||
assert received == [expected]
|
||||
|
||||
|
||||
async def test_server_reject_connection(
|
||||
|
||||
@ -105,7 +105,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
self.last_ping_rtt: float = 0.0
|
||||
|
||||
# Buffers
|
||||
self.bytes = bytearray()
|
||||
self.bytes: bytes | bytearray = b""
|
||||
|
||||
def connection_made(self, transport: BaseTransport) -> None:
|
||||
"""Called when a connection is made."""
|
||||
@ -217,21 +217,28 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
self.tasks.add(task)
|
||||
|
||||
def handle_cont(self, event: Frame) -> None:
|
||||
assert isinstance(self.bytes, bytearray)
|
||||
self.bytes.extend(event.data)
|
||||
if event.fin:
|
||||
self.send_receive_event_to_app()
|
||||
|
||||
def handle_text(self, event: Frame) -> None:
|
||||
self.bytes = bytearray(event.data)
|
||||
self.curr_msg_data_type: Literal["text", "bytes"] = "text"
|
||||
if event.fin:
|
||||
# Single-frame message: pass through as bytes (zero copies).
|
||||
self.bytes = event.data
|
||||
self.send_receive_event_to_app()
|
||||
else:
|
||||
# Fragmented: promote to bytearray so continuations extend in place.
|
||||
self.bytes = bytearray(event.data)
|
||||
|
||||
def handle_bytes(self, event: Frame) -> None:
|
||||
self.bytes = bytearray(event.data)
|
||||
self.curr_msg_data_type = "bytes"
|
||||
if event.fin:
|
||||
self.bytes = event.data
|
||||
self.send_receive_event_to_app()
|
||||
else:
|
||||
self.bytes = bytearray(event.data)
|
||||
|
||||
def send_receive_event_to_app(self) -> None:
|
||||
if self.curr_msg_data_type == "text":
|
||||
@ -243,7 +250,9 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
self.handle_parser_exception()
|
||||
return
|
||||
else:
|
||||
self.queue.put_nowait({"type": "websocket.receive", "bytes": bytes(self.bytes)})
|
||||
# Freeze the buffer to bytes only when it was grown across fragments.
|
||||
payload = self.bytes if isinstance(self.bytes, bytes) else bytes(self.bytes)
|
||||
self.queue.put_nowait({"type": "websocket.receive", "bytes": payload})
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self.transport.pause_reading()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user