Noah + Copilot review

This commit is contained in:
Jeffrey 'Alex' Clark 2026-05-06 17:22:56 -04:00
parent cdc07e3ed9
commit 33c8039298
4 changed files with 225 additions and 554 deletions

View File

@ -0,0 +1,211 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Async-only unit tests for network_layer.py."""
from __future__ import annotations
import asyncio
import struct
import sys
from unittest.mock import AsyncMock, MagicMock, patch
sys.path[0:0] = [""]
from test.asynchronous import AsyncUnitTest, unittest
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.errors import ProtocolError
from pymongo.network_layer import PyMongoProtocol, _async_socket_receive
async def _make_protocol(timeout=None):
protocol = PyMongoProtocol(timeout=timeout)
mock_transport = MagicMock()
mock_transport.is_closing.return_value = False
protocol.transport = mock_transport
return protocol
def _make_header(length, request_id, response_to, op_code):
return struct.pack("<iiii", length, request_id, response_to, op_code)
class TestPyMongoProtocol(AsyncUnitTest):
async def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
protocol = await _make_protocol()
protocol._max_message_size = max_size
protocol._header = memoryview(bytearray(header_bytes))
return protocol
async def test_initial_timeout_from_constructor(self):
protocol = await _make_protocol(timeout=3.0)
self.assertEqual(protocol.gettimeout, 3.0)
async def test_settimeout_updates_value(self):
protocol = await _make_protocol()
protocol.settimeout(7.5)
self.assertEqual(protocol.gettimeout, 7.5)
async def test_default_timeout_is_none(self):
protocol = await _make_protocol()
self.assertIsNone(protocol.gettimeout)
async def test_normal_op_msg(self):
header = _make_header(length=32, request_id=1, response_to=99, op_code=2013)
protocol = await self._make_proto_with_header(header)
body_len, op_code, response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 16)
self.assertEqual(op_code, 2013)
self.assertEqual(response_to, 99)
self.assertFalse(expecting_compression)
async def test_op_compressed(self):
# OP_COMPRESSED=2012; process_header strips the 9-byte compression sub-header
# (op code + uncompressed size + compressor id), then the 16-byte standard header.
# length=35 → after compression sub-header: 26 → body: 10
header = _make_header(length=35, request_id=1, response_to=0, op_code=2012)
protocol = await self._make_proto_with_header(header)
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 10)
self.assertEqual(op_code, 2012)
self.assertTrue(expecting_compression)
async def test_op_compressed_length_too_small_raises(self):
header = _make_header(length=25, request_id=1, response_to=0, op_code=2012)
protocol = await self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()
async def test_non_compressed_length_too_small_raises(self):
header = _make_header(length=16, request_id=1, response_to=0, op_code=2013)
protocol = await self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()
async def test_length_exceeds_max_raises(self):
header = _make_header(
length=MAX_MESSAGE_SIZE + 1, request_id=1, response_to=0, op_code=2013
)
protocol = await self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()
async def test_op_reply_op_code(self):
header = _make_header(length=20, request_id=0, response_to=0, op_code=1)
protocol = await self._make_proto_with_header(header)
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 4)
self.assertEqual(op_code, 1)
self.assertFalse(expecting_compression)
async def test_compression_header_snappy_compressor_id(self):
protocol = await _make_protocol()
# <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
data = struct.pack("<iiB", 2013, 0, 1)
protocol._compression_header = memoryview(bytearray(data))
op_code, compressor_id = protocol.process_compression_header()
self.assertEqual(op_code, 2013)
self.assertEqual(compressor_id, 1)
async def test_compression_header_zlib_compressor_id(self):
protocol = await _make_protocol()
data = struct.pack("<iiB", 2013, 0, 2)
protocol._compression_header = memoryview(bytearray(data))
_, compressor_id = protocol.process_compression_header()
self.assertEqual(compressor_id, 2)
async def test_message_complete_resolves_pending_future(self):
protocol = await _make_protocol()
protocol._expecting_header = False
protocol._expecting_compression = False
protocol._message_size = 10
protocol._message = memoryview(bytearray(10))
protocol._message_index = 0
protocol._op_code = 2013
protocol._compressor_id = None
protocol._response_to = 42
future = asyncio.get_running_loop().create_future()
protocol._pending_messages.append(future)
protocol.buffer_updated(10)
self.assertTrue(future.done())
op_code, compressor_id, response_to, _ = future.result()
self.assertEqual(op_code, 2013)
self.assertIsNone(compressor_id)
self.assertEqual(response_to, 42)
async def test_close_aborts_transport(self):
protocol = await _make_protocol()
protocol.close()
self.assertTrue(protocol.transport.abort.called)
async def test_connection_lost_twice_does_not_raise(self):
protocol = await _make_protocol()
protocol.connection_lost(None)
protocol.connection_lost(None)
async def test_close_with_exception_propagates_to_pending(self):
protocol = await _make_protocol()
future = asyncio.get_running_loop().create_future()
protocol._pending_messages.append(future)
exc = OSError("connection reset")
protocol.close(exc)
with self.assertRaisesRegex(OSError, "connection reset"):
await future
class TestAsyncSocketReceive(AsyncUnitTest):
async def test_reads_data_in_multiple_chunks(self):
# Covers the loop in _async_socket_receive that accumulates short reads
# until the requested length has been received.
data = b"abcdefgh"
length = len(data)
chunk1, chunk2 = data[:4], data[4:]
mock_socket = MagicMock()
loop = asyncio.get_running_loop()
calls = 0
async def fake_recv_into(sock, buf):
nonlocal calls
if calls == 0:
buf[: len(chunk1)] = chunk1
calls += 1
return len(chunk1)
buf[: len(chunk2)] = chunk2
calls += 1
return len(chunk2)
with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
result = await _async_socket_receive(mock_socket, length, loop)
self.assertEqual(bytes(result), data)
self.assertEqual(calls, 2)
async def test_raises_on_connection_closed(self):
# Covers the explicit `raise OSError("connection closed")` branch when
# sock_recv_into returns 0.
mock_socket = MagicMock()
loop = asyncio.get_running_loop()
async def fake_recv_into(sock, buf):
return 0
with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
with self.assertRaisesRegex(OSError, "connection closed"):
await _async_socket_receive(mock_socket, 10, loop)
if __name__ == "__main__":
unittest.main()

View File

@ -12,54 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for network_layer.py."""
"""Unit tests for code in network_layer.py shared between sync and async APIs.
Async-only tests live in ``test_async_network_layer.py``.
"""
from __future__ import annotations
import asyncio
import struct
import sys
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock
sys.path[0:0] = [""]
from test.asynchronous import AsyncUnitTest, unittest
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.errors import ProtocolError
from pymongo.network_layer import (
AsyncNetworkingInterface,
NetworkingInterface,
NetworkingInterfaceBase,
PyMongoProtocol,
_async_socket_receive,
sendall,
)
from pymongo.network_layer import NetworkingInterfaceBase
_IS_SYNC = False
async def _make_protocol(timeout=None):
protocol = PyMongoProtocol(timeout=timeout)
mock_transport = MagicMock()
mock_transport.is_closing.return_value = False
protocol.transport = mock_transport
return protocol
def _make_header(length, request_id, response_to, op_code):
return struct.pack("<iiii", length, request_id, response_to, op_code)
class TestSendall(AsyncUnitTest):
def test_delegates_to_sock_sendall(self):
mock_socket = MagicMock()
sendall(mock_socket, b"hello")
mock_socket.sendall.assert_called_once_with(b"hello")
class TestNetworkingInterfaceBase(AsyncUnitTest):
def setUp(self):
async def asyncSetUp(self):
self.base = NetworkingInterfaceBase(MagicMock())
def test_gettimeout_raises(self):
@ -87,249 +60,5 @@ class TestNetworkingInterfaceBase(AsyncUnitTest):
_ = self.base.sock
class TestNetworkingInterface(AsyncUnitTest):
def setUp(self):
self.mock_socket = MagicMock()
self.network_interface = NetworkingInterface(self.mock_socket)
def test_gettimeout_delegates(self):
self.mock_socket.gettimeout.return_value = 5.0
self.assertEqual(self.network_interface.gettimeout(), 5.0)
def test_settimeout_delegates(self):
self.network_interface.settimeout(3.0)
self.mock_socket.settimeout.assert_called_once_with(3.0)
def test_close_delegates(self):
self.network_interface.close()
self.mock_socket.close.assert_called_once()
def test_is_closing_delegates(self):
self.mock_socket.is_closing.return_value = True
self.assertTrue(self.network_interface.is_closing())
def test_fileno_delegates(self):
self.mock_socket.fileno.return_value = 42
self.assertEqual(self.network_interface.fileno(), 42)
def test_recv_into_delegates(self):
buf = memoryview(bytearray(10))
self.mock_socket.recv_into.return_value = 7
result = self.network_interface.recv_into(buf)
self.assertEqual(result, 7)
self.mock_socket.recv_into.assert_called_once_with(buf)
def test_get_conn_returns_socket(self):
self.assertIs(self.network_interface.get_conn, self.mock_socket)
def test_sock_returns_socket(self):
self.assertIs(self.network_interface.sock, self.mock_socket)
if not _IS_SYNC:
class TestAsyncNetworkingInterface(AsyncUnitTest):
def _make_network_interface(self):
mock_transport = MagicMock()
mock_protocol = MagicMock()
mock_protocol.gettimeout = 10.0
return AsyncNetworkingInterface((mock_transport, mock_protocol))
def test_gettimeout_returns_protocol_timeout(self):
network_interface = self._make_network_interface()
self.assertEqual(network_interface.gettimeout, 10.0)
def test_settimeout_delegates_to_protocol(self):
network_interface = self._make_network_interface()
network_interface.settimeout(7.0)
network_interface.conn[1].settimeout.assert_called_once_with(7.0)
def test_is_closing_delegates_to_transport(self):
network_interface = self._make_network_interface()
network_interface.conn[0].is_closing.return_value = False
self.assertFalse(network_interface.is_closing())
def test_get_conn_returns_protocol(self):
network_interface = self._make_network_interface()
self.assertIs(network_interface.get_conn, network_interface.conn[1])
def test_sock_returns_transport_socket(self):
network_interface = self._make_network_interface()
sentinel = object()
network_interface.conn[0].get_extra_info.return_value = sentinel
self.assertIs(network_interface.sock, sentinel)
network_interface.conn[0].get_extra_info.assert_called_once_with("socket")
class TestPyMongoProtocol(AsyncUnitTest):
async def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
protocol = await _make_protocol()
protocol._max_message_size = max_size
protocol._header = memoryview(bytearray(header_bytes))
return protocol
async def test_initial_timeout_from_constructor(self):
protocol = await _make_protocol(timeout=3.0)
self.assertEqual(protocol.gettimeout, 3.0)
async def test_settimeout_updates_value(self):
protocol = await _make_protocol()
protocol.settimeout(7.5)
self.assertEqual(protocol.gettimeout, 7.5)
async def test_default_timeout_is_none(self):
protocol = await _make_protocol()
self.assertIsNone(protocol.gettimeout)
async def test_normal_op_msg(self):
header = _make_header(32, 1, 99, 2013)
protocol = await self._make_proto_with_header(header)
body_len, op_code, response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 16)
self.assertEqual(op_code, 2013)
self.assertEqual(response_to, 99)
self.assertFalse(expecting_compression)
async def test_op_compressed(self):
# OP_COMPRESSED=2012, length=35 → adjusted=35-9=26 → body=26-16=10
header = _make_header(35, 1, 0, 2012)
protocol = await self._make_proto_with_header(header)
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 10)
self.assertEqual(op_code, 2012)
self.assertTrue(expecting_compression)
async def test_op_compressed_length_too_small_raises(self):
header = _make_header(25, 1, 0, 2012)
protocol = await self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()
async def test_non_compressed_length_too_small_raises(self):
header = _make_header(16, 1, 0, 2013)
protocol = await self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()
async def test_length_exceeds_max_raises(self):
header = _make_header(MAX_MESSAGE_SIZE + 1, 1, 0, 2013)
protocol = await self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()
async def test_op_reply_op_code(self):
header = _make_header(20, 0, 0, 1)
protocol = await self._make_proto_with_header(header)
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 4)
self.assertEqual(op_code, 1)
self.assertFalse(expecting_compression)
async def test_compression_header_returns_op_code_and_compressor_id(self):
protocol = await _make_protocol()
# <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
data = struct.pack("<iiB", 2013, 0, 1)
protocol._compression_header = memoryview(bytearray(data))
op_code, compressor_id = protocol.process_compression_header()
self.assertEqual(op_code, 2013)
self.assertEqual(compressor_id, 1)
async def test_compression_header_zlib_compressor_id(self):
protocol = await _make_protocol()
data = struct.pack("<iiB", 2013, 0, 2)
protocol._compression_header = memoryview(bytearray(data))
_, compressor_id = protocol.process_compression_header()
self.assertEqual(compressor_id, 2)
async def test_message_complete_resolves_pending_future(self):
protocol = await _make_protocol()
protocol._expecting_header = False
protocol._expecting_compression = False
protocol._message_size = 10
protocol._message = memoryview(bytearray(10))
protocol._message_index = 0
protocol._op_code = 2013
protocol._compressor_id = None
protocol._response_to = 42
future = asyncio.get_running_loop().create_future()
protocol._pending_messages.append(future)
protocol.buffer_updated(10)
self.assertTrue(future.done())
op_code, compressor_id, response_to, _ = future.result()
self.assertEqual(op_code, 2013)
self.assertIsNone(compressor_id)
self.assertEqual(response_to, 42)
async def test_close_aborts_transport(self):
protocol = await _make_protocol()
protocol.close()
self.assertTrue(protocol.transport.abort.called)
async def test_connection_lost_twice_does_not_raise(self):
protocol = await _make_protocol()
protocol.connection_lost(None)
protocol.connection_lost(None)
async def test_close_with_exception_propagates_to_pending(self):
protocol = await _make_protocol()
future = asyncio.get_running_loop().create_future()
protocol._pending_messages.append(future)
exc = OSError("connection reset")
protocol.close(exc)
with self.assertRaises(OSError) as ctx:
await future
self.assertIn("connection reset", str(ctx.exception))
class TestAsyncSocketReceive(AsyncUnitTest):
async def test_reads_full_data_in_one_call(self):
data = b"hello world!"
length = len(data)
mock_socket = MagicMock()
loop = asyncio.get_running_loop()
async def fake_recv_into(sock, buf):
buf[:length] = data
return length
with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
result = await _async_socket_receive(mock_socket, length, loop)
self.assertEqual(bytes(result), data)
async def test_reads_data_in_multiple_chunks(self):
data = b"abcdefgh"
length = len(data)
chunk1, chunk2 = data[:4], data[4:]
mock_socket = MagicMock()
loop = asyncio.get_running_loop()
calls = 0
async def fake_recv_into(sock, buf):
nonlocal calls
if calls == 0:
buf[: len(chunk1)] = chunk1
calls += 1
return len(chunk1)
buf[: len(chunk2)] = chunk2
calls += 1
return len(chunk2)
with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
result = await _async_socket_receive(mock_socket, length, loop)
self.assertEqual(bytes(result), data)
async def test_raises_on_connection_closed(self):
mock_socket = MagicMock()
loop = asyncio.get_running_loop()
async def fake_recv_into(sock, buf):
return 0
with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
with self.assertRaises(OSError) as ctx:
await _async_socket_receive(mock_socket, 10, loop)
self.assertIn("connection closed", str(ctx.exception))
if __name__ == "__main__":
unittest.main()

View File

@ -12,51 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for network_layer.py."""
"""Unit tests for code in network_layer.py shared between sync and async APIs.
Async-only tests live in ``test_async_network_layer.py``.
"""
from __future__ import annotations
import asyncio
import struct
import sys
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
sys.path[0:0] = [""]
from test import UnitTest, unittest
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.errors import ProtocolError
from pymongo.network_layer import (
NetworkingInterface,
NetworkingInterfaceBase,
PyMongoProtocol,
_async_socket_receive,
sendall,
)
from pymongo.network_layer import NetworkingInterfaceBase
_IS_SYNC = True
def _make_protocol(timeout=None):
protocol = PyMongoProtocol(timeout=timeout)
mock_transport = MagicMock()
mock_transport.is_closing.return_value = False
protocol.transport = mock_transport
return protocol
def _make_header(length, request_id, response_to, op_code):
return struct.pack("<iiii", length, request_id, response_to, op_code)
class TestSendall(UnitTest):
def test_delegates_to_sock_sendall(self):
mock_socket = MagicMock()
sendall(mock_socket, b"hello")
mock_socket.sendall.assert_called_once_with(b"hello")
class TestNetworkingInterfaceBase(UnitTest):
def setUp(self):
self.base = NetworkingInterfaceBase(MagicMock())
@ -86,249 +60,5 @@ class TestNetworkingInterfaceBase(UnitTest):
_ = self.base.sock
class TestNetworkingInterface(UnitTest):
def setUp(self):
self.mock_socket = MagicMock()
self.network_interface = NetworkingInterface(self.mock_socket)
def test_gettimeout_delegates(self):
self.mock_socket.gettimeout.return_value = 5.0
self.assertEqual(self.network_interface.gettimeout(), 5.0)
def test_settimeout_delegates(self):
self.network_interface.settimeout(3.0)
self.mock_socket.settimeout.assert_called_once_with(3.0)
def test_close_delegates(self):
self.network_interface.close()
self.mock_socket.close.assert_called_once()
def test_is_closing_delegates(self):
self.mock_socket.is_closing.return_value = True
self.assertTrue(self.network_interface.is_closing())
def test_fileno_delegates(self):
self.mock_socket.fileno.return_value = 42
self.assertEqual(self.network_interface.fileno(), 42)
def test_recv_into_delegates(self):
buf = memoryview(bytearray(10))
self.mock_socket.recv_into.return_value = 7
result = self.network_interface.recv_into(buf)
self.assertEqual(result, 7)
self.mock_socket.recv_into.assert_called_once_with(buf)
def test_get_conn_returns_socket(self):
self.assertIs(self.network_interface.get_conn, self.mock_socket)
def test_sock_returns_socket(self):
self.assertIs(self.network_interface.sock, self.mock_socket)
if not _IS_SYNC:
class TestNetworkingInterface(UnitTest):
def _make_network_interface(self):
mock_transport = MagicMock()
mock_protocol = MagicMock()
mock_protocol.gettimeout = 10.0
return NetworkingInterface((mock_transport, mock_protocol))
def test_gettimeout_returns_protocol_timeout(self):
network_interface = self._make_network_interface()
self.assertEqual(network_interface.gettimeout, 10.0)
def test_settimeout_delegates_to_protocol(self):
network_interface = self._make_network_interface()
network_interface.settimeout(7.0)
network_interface.conn[1].settimeout.assert_called_once_with(7.0)
def test_is_closing_delegates_to_transport(self):
network_interface = self._make_network_interface()
network_interface.conn[0].is_closing.return_value = False
self.assertFalse(network_interface.is_closing())
def test_get_conn_returns_protocol(self):
network_interface = self._make_network_interface()
self.assertIs(network_interface.get_conn, network_interface.conn[1])
def test_sock_returns_transport_socket(self):
network_interface = self._make_network_interface()
sentinel = object()
network_interface.conn[0].get_extra_info.return_value = sentinel
self.assertIs(network_interface.sock, sentinel)
network_interface.conn[0].get_extra_info.assert_called_once_with("socket")
class TestPyMongoProtocol(UnitTest):
def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
protocol = _make_protocol()
protocol._max_message_size = max_size
protocol._header = memoryview(bytearray(header_bytes))
return protocol
def test_initial_timeout_from_constructor(self):
protocol = _make_protocol(timeout=3.0)
self.assertEqual(protocol.gettimeout, 3.0)
def test_settimeout_updates_value(self):
protocol = _make_protocol()
protocol.settimeout(7.5)
self.assertEqual(protocol.gettimeout, 7.5)
def test_default_timeout_is_none(self):
protocol = _make_protocol()
self.assertIsNone(protocol.gettimeout)
def test_normal_op_msg(self):
header = _make_header(32, 1, 99, 2013)
protocol = self._make_proto_with_header(header)
body_len, op_code, response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 16)
self.assertEqual(op_code, 2013)
self.assertEqual(response_to, 99)
self.assertFalse(expecting_compression)
def test_op_compressed(self):
# OP_COMPRESSED=2012, length=35 → adjusted=35-9=26 → body=26-16=10
header = _make_header(35, 1, 0, 2012)
protocol = self._make_proto_with_header(header)
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 10)
self.assertEqual(op_code, 2012)
self.assertTrue(expecting_compression)
def test_op_compressed_length_too_small_raises(self):
header = _make_header(25, 1, 0, 2012)
protocol = self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()
def test_non_compressed_length_too_small_raises(self):
header = _make_header(16, 1, 0, 2013)
protocol = self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()
def test_length_exceeds_max_raises(self):
header = _make_header(MAX_MESSAGE_SIZE + 1, 1, 0, 2013)
protocol = self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()
def test_op_reply_op_code(self):
header = _make_header(20, 0, 0, 1)
protocol = self._make_proto_with_header(header)
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 4)
self.assertEqual(op_code, 1)
self.assertFalse(expecting_compression)
def test_compression_header_returns_op_code_and_compressor_id(self):
protocol = _make_protocol()
# <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
data = struct.pack("<iiB", 2013, 0, 1)
protocol._compression_header = memoryview(bytearray(data))
op_code, compressor_id = protocol.process_compression_header()
self.assertEqual(op_code, 2013)
self.assertEqual(compressor_id, 1)
def test_compression_header_zlib_compressor_id(self):
protocol = _make_protocol()
data = struct.pack("<iiB", 2013, 0, 2)
protocol._compression_header = memoryview(bytearray(data))
_, compressor_id = protocol.process_compression_header()
self.assertEqual(compressor_id, 2)
def test_message_complete_resolves_pending_future(self):
protocol = _make_protocol()
protocol._expecting_header = False
protocol._expecting_compression = False
protocol._message_size = 10
protocol._message = memoryview(bytearray(10))
protocol._message_index = 0
protocol._op_code = 2013
protocol._compressor_id = None
protocol._response_to = 42
future = asyncio.get_running_loop().create_future()
protocol._pending_messages.append(future)
protocol.buffer_updated(10)
self.assertTrue(future.done())
op_code, compressor_id, response_to, _ = future.result()
self.assertEqual(op_code, 2013)
self.assertIsNone(compressor_id)
self.assertEqual(response_to, 42)
def test_close_aborts_transport(self):
protocol = _make_protocol()
protocol.close()
self.assertTrue(protocol.transport.abort.called)
def test_connection_lost_twice_does_not_raise(self):
protocol = _make_protocol()
protocol.connection_lost(None)
protocol.connection_lost(None)
def test_close_with_exception_propagates_to_pending(self):
protocol = _make_protocol()
future = asyncio.get_running_loop().create_future()
protocol._pending_messages.append(future)
exc = OSError("connection reset")
protocol.close(exc)
with self.assertRaises(OSError) as ctx:
future
self.assertIn("connection reset", str(ctx.exception))
class TestAsyncSocketReceive(UnitTest):
def test_reads_full_data_in_one_call(self):
data = b"hello world!"
length = len(data)
mock_socket = MagicMock()
loop = asyncio.get_running_loop()
def fake_recv_into(sock, buf):
buf[:length] = data
return length
with patch.object(loop, "sock_recv_into", new=MagicMock(side_effect=fake_recv_into)):
result = _async_socket_receive(mock_socket, length, loop)
self.assertEqual(bytes(result), data)
def test_reads_data_in_multiple_chunks(self):
data = b"abcdefgh"
length = len(data)
chunk1, chunk2 = data[:4], data[4:]
mock_socket = MagicMock()
loop = asyncio.get_running_loop()
calls = 0
def fake_recv_into(sock, buf):
nonlocal calls
if calls == 0:
buf[: len(chunk1)] = chunk1
calls += 1
return len(chunk1)
buf[: len(chunk2)] = chunk2
calls += 1
return len(chunk2)
with patch.object(loop, "sock_recv_into", new=MagicMock(side_effect=fake_recv_into)):
result = _async_socket_receive(mock_socket, length, loop)
self.assertEqual(bytes(result), data)
def test_raises_on_connection_closed(self):
mock_socket = MagicMock()
loop = asyncio.get_running_loop()
def fake_recv_into(sock, buf):
return 0
with patch.object(loop, "sock_recv_into", new=MagicMock(side_effect=fake_recv_into)):
with self.assertRaises(OSError) as ctx:
_async_socket_receive(mock_socket, 10, loop)
self.assertIn("connection closed", str(ctx.exception))
if __name__ == "__main__":
unittest.main()

View File

@ -191,6 +191,7 @@ def async_only_test(f: str) -> bool:
"test_async_loop_safety.py",
"test_async_contextvars_reset.py",
"test_async_loop_unblocked.py",
"test_async_network_layer.py",
]