Use bytearray for incoming WebSocket message buffer in websockets-sansio (#2917)
This commit is contained in:
parent
d438fb16fe
commit
7375b5bf66
@ -777,6 +777,37 @@ async def test_fragmented_message_exceeding_max_size(
|
||||
assert exc_info.value.rcvd.code == 1009
|
||||
|
||||
|
||||
async def test_fragmented_message_reassembly(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
"""Server reassembles a fragmented message and delivers it to the app intact."""
|
||||
|
||||
received: list[bytes] = []
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
assert scope["type"] == "websocket"
|
||||
connect = await receive()
|
||||
assert connect["type"] == "websocket.connect"
|
||||
await send({"type": "websocket.accept"})
|
||||
message = await receive()
|
||||
assert message["type"] == "websocket.receive"
|
||||
payload = message.get("bytes")
|
||||
assert payload is not None
|
||||
received.append(payload)
|
||||
await send({"type": "websocket.close"})
|
||||
|
||||
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
||||
async with run_server(config):
|
||||
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}") as ws:
|
||||
payload = b"A" * 512
|
||||
await ws.write_frame(False, Opcode.BINARY, payload)
|
||||
for _ in range(4):
|
||||
await ws.write_frame(False, Opcode.CONT, payload)
|
||||
await ws.write_frame(True, Opcode.CONT, payload)
|
||||
|
||||
assert received == [b"A" * 512 * 6]
|
||||
|
||||
|
||||
async def test_server_reject_connection(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
|
||||
@ -105,7 +105,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
self.last_ping_rtt: float = 0.0
|
||||
|
||||
# Buffers
|
||||
self.bytes = b""
|
||||
self.bytes = bytearray()
|
||||
|
||||
def connection_made(self, transport: BaseTransport) -> None:
|
||||
"""Called when a connection is made."""
|
||||
@ -216,19 +216,19 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
task.add_done_callback(self.on_task_complete)
|
||||
self.tasks.add(task)
|
||||
|
||||
def handle_cont(self, event: Frame) -> None: # pragma: no cover
|
||||
self.bytes += event.data
|
||||
def handle_cont(self, event: Frame) -> None:
|
||||
self.bytes.extend(event.data)
|
||||
if event.fin:
|
||||
self.send_receive_event_to_app()
|
||||
|
||||
def handle_text(self, event: Frame) -> None:
|
||||
self.bytes = event.data
|
||||
self.bytes = bytearray(event.data)
|
||||
self.curr_msg_data_type: Literal["text", "bytes"] = "text"
|
||||
if event.fin:
|
||||
self.send_receive_event_to_app()
|
||||
|
||||
def handle_bytes(self, event: Frame) -> None:
|
||||
self.bytes = event.data
|
||||
self.bytes = bytearray(event.data)
|
||||
self.curr_msg_data_type = "bytes"
|
||||
if event.fin:
|
||||
self.send_receive_event_to_app()
|
||||
@ -243,7 +243,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
self.handle_parser_exception()
|
||||
return
|
||||
else:
|
||||
self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes})
|
||||
self.queue.put_nowait({"type": "websocket.receive", "bytes": bytes(self.bytes)})
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self.transport.pause_reading()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user