Merge 33c8039298 into 8dc7efade2
This commit is contained in:
commit
32bb4926a6
211
test/asynchronous/test_async_network_layer.py
Normal file
211
test/asynchronous/test_async_network_layer.py
Normal file
@ -0,0 +1,211 @@
|
||||
# Copyright 2026-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Async-only unit tests for network_layer.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncUnitTest, unittest
|
||||
|
||||
from pymongo.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.errors import ProtocolError
|
||||
from pymongo.network_layer import PyMongoProtocol, _async_socket_receive
|
||||
|
||||
|
||||
async def _make_protocol(timeout=None):
|
||||
protocol = PyMongoProtocol(timeout=timeout)
|
||||
mock_transport = MagicMock()
|
||||
mock_transport.is_closing.return_value = False
|
||||
protocol.transport = mock_transport
|
||||
return protocol
|
||||
|
||||
|
||||
def _make_header(length, request_id, response_to, op_code):
|
||||
return struct.pack("<iiii", length, request_id, response_to, op_code)
|
||||
|
||||
|
||||
class TestPyMongoProtocol(AsyncUnitTest):
|
||||
async def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
|
||||
protocol = await _make_protocol()
|
||||
protocol._max_message_size = max_size
|
||||
protocol._header = memoryview(bytearray(header_bytes))
|
||||
return protocol
|
||||
|
||||
async def test_initial_timeout_from_constructor(self):
|
||||
protocol = await _make_protocol(timeout=3.0)
|
||||
self.assertEqual(protocol.gettimeout, 3.0)
|
||||
|
||||
async def test_settimeout_updates_value(self):
|
||||
protocol = await _make_protocol()
|
||||
protocol.settimeout(7.5)
|
||||
self.assertEqual(protocol.gettimeout, 7.5)
|
||||
|
||||
async def test_default_timeout_is_none(self):
|
||||
protocol = await _make_protocol()
|
||||
self.assertIsNone(protocol.gettimeout)
|
||||
|
||||
async def test_normal_op_msg(self):
|
||||
header = _make_header(length=32, request_id=1, response_to=99, op_code=2013)
|
||||
protocol = await self._make_proto_with_header(header)
|
||||
body_len, op_code, response_to, expecting_compression = protocol.process_header()
|
||||
self.assertEqual(body_len, 16)
|
||||
self.assertEqual(op_code, 2013)
|
||||
self.assertEqual(response_to, 99)
|
||||
self.assertFalse(expecting_compression)
|
||||
|
||||
async def test_op_compressed(self):
|
||||
# OP_COMPRESSED=2012; process_header strips the 9-byte compression sub-header
|
||||
# (op code + uncompressed size + compressor id), then the 16-byte standard header.
|
||||
# length=35 → after compression sub-header: 26 → body: 10
|
||||
header = _make_header(length=35, request_id=1, response_to=0, op_code=2012)
|
||||
protocol = await self._make_proto_with_header(header)
|
||||
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
|
||||
self.assertEqual(body_len, 10)
|
||||
self.assertEqual(op_code, 2012)
|
||||
self.assertTrue(expecting_compression)
|
||||
|
||||
async def test_op_compressed_length_too_small_raises(self):
|
||||
header = _make_header(length=25, request_id=1, response_to=0, op_code=2012)
|
||||
protocol = await self._make_proto_with_header(header)
|
||||
with self.assertRaises(ProtocolError):
|
||||
protocol.process_header()
|
||||
|
||||
async def test_non_compressed_length_too_small_raises(self):
|
||||
header = _make_header(length=16, request_id=1, response_to=0, op_code=2013)
|
||||
protocol = await self._make_proto_with_header(header)
|
||||
with self.assertRaises(ProtocolError):
|
||||
protocol.process_header()
|
||||
|
||||
async def test_length_exceeds_max_raises(self):
|
||||
header = _make_header(
|
||||
length=MAX_MESSAGE_SIZE + 1, request_id=1, response_to=0, op_code=2013
|
||||
)
|
||||
protocol = await self._make_proto_with_header(header)
|
||||
with self.assertRaises(ProtocolError):
|
||||
protocol.process_header()
|
||||
|
||||
async def test_op_reply_op_code(self):
|
||||
header = _make_header(length=20, request_id=0, response_to=0, op_code=1)
|
||||
protocol = await self._make_proto_with_header(header)
|
||||
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
|
||||
self.assertEqual(body_len, 4)
|
||||
self.assertEqual(op_code, 1)
|
||||
self.assertFalse(expecting_compression)
|
||||
|
||||
async def test_compression_header_snappy_compressor_id(self):
|
||||
protocol = await _make_protocol()
|
||||
# <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
|
||||
data = struct.pack("<iiB", 2013, 0, 1)
|
||||
protocol._compression_header = memoryview(bytearray(data))
|
||||
op_code, compressor_id = protocol.process_compression_header()
|
||||
self.assertEqual(op_code, 2013)
|
||||
self.assertEqual(compressor_id, 1)
|
||||
|
||||
async def test_compression_header_zlib_compressor_id(self):
|
||||
protocol = await _make_protocol()
|
||||
data = struct.pack("<iiB", 2013, 0, 2)
|
||||
protocol._compression_header = memoryview(bytearray(data))
|
||||
_, compressor_id = protocol.process_compression_header()
|
||||
self.assertEqual(compressor_id, 2)
|
||||
|
||||
async def test_message_complete_resolves_pending_future(self):
|
||||
protocol = await _make_protocol()
|
||||
protocol._expecting_header = False
|
||||
protocol._expecting_compression = False
|
||||
protocol._message_size = 10
|
||||
protocol._message = memoryview(bytearray(10))
|
||||
protocol._message_index = 0
|
||||
protocol._op_code = 2013
|
||||
protocol._compressor_id = None
|
||||
protocol._response_to = 42
|
||||
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
protocol._pending_messages.append(future)
|
||||
|
||||
protocol.buffer_updated(10)
|
||||
self.assertTrue(future.done())
|
||||
op_code, compressor_id, response_to, _ = future.result()
|
||||
self.assertEqual(op_code, 2013)
|
||||
self.assertIsNone(compressor_id)
|
||||
self.assertEqual(response_to, 42)
|
||||
|
||||
async def test_close_aborts_transport(self):
|
||||
protocol = await _make_protocol()
|
||||
protocol.close()
|
||||
self.assertTrue(protocol.transport.abort.called)
|
||||
|
||||
async def test_connection_lost_twice_does_not_raise(self):
|
||||
protocol = await _make_protocol()
|
||||
protocol.connection_lost(None)
|
||||
protocol.connection_lost(None)
|
||||
|
||||
async def test_close_with_exception_propagates_to_pending(self):
|
||||
protocol = await _make_protocol()
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
protocol._pending_messages.append(future)
|
||||
exc = OSError("connection reset")
|
||||
protocol.close(exc)
|
||||
with self.assertRaisesRegex(OSError, "connection reset"):
|
||||
await future
|
||||
|
||||
|
||||
class TestAsyncSocketReceive(AsyncUnitTest):
|
||||
async def test_reads_data_in_multiple_chunks(self):
|
||||
# Covers the loop in _async_socket_receive that accumulates short reads
|
||||
# until the requested length has been received.
|
||||
data = b"abcdefgh"
|
||||
length = len(data)
|
||||
chunk1, chunk2 = data[:4], data[4:]
|
||||
mock_socket = MagicMock()
|
||||
loop = asyncio.get_running_loop()
|
||||
calls = 0
|
||||
|
||||
async def fake_recv_into(sock, buf):
|
||||
nonlocal calls
|
||||
if calls == 0:
|
||||
buf[: len(chunk1)] = chunk1
|
||||
calls += 1
|
||||
return len(chunk1)
|
||||
buf[: len(chunk2)] = chunk2
|
||||
calls += 1
|
||||
return len(chunk2)
|
||||
|
||||
with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
|
||||
result = await _async_socket_receive(mock_socket, length, loop)
|
||||
self.assertEqual(bytes(result), data)
|
||||
self.assertEqual(calls, 2)
|
||||
|
||||
async def test_raises_on_connection_closed(self):
|
||||
# Covers the explicit `raise OSError("connection closed")` branch when
|
||||
# sock_recv_into returns 0.
|
||||
mock_socket = MagicMock()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
async def fake_recv_into(sock, buf):
|
||||
return 0
|
||||
|
||||
with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
|
||||
with self.assertRaisesRegex(OSError, "connection closed"):
|
||||
await _async_socket_receive(mock_socket, 10, loop)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
64
test/asynchronous/test_network_layer.py
Normal file
64
test/asynchronous/test_network_layer.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright 2026-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for code in network_layer.py shared between sync and async APIs.
|
||||
|
||||
Async-only tests live in ``test_async_network_layer.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncUnitTest, unittest
|
||||
|
||||
from pymongo.network_layer import NetworkingInterfaceBase
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TestNetworkingInterfaceBase(AsyncUnitTest):
|
||||
async def asyncSetUp(self):
|
||||
self.base = NetworkingInterfaceBase(MagicMock())
|
||||
|
||||
def test_gettimeout_raises(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
_ = self.base.gettimeout
|
||||
|
||||
def test_settimeout_raises(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.base.settimeout(1.0)
|
||||
|
||||
def test_close_raises(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.base.close()
|
||||
|
||||
def test_is_closing_raises(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.base.is_closing()
|
||||
|
||||
def test_get_conn_raises(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
_ = self.base.get_conn
|
||||
|
||||
def test_sock_raises(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
_ = self.base.sock
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
64
test/test_network_layer.py
Normal file
64
test/test_network_layer.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright 2026-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for code in network_layer.py shared between sync and async APIs.
|
||||
|
||||
Async-only tests live in ``test_async_network_layer.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import UnitTest, unittest
|
||||
|
||||
from pymongo.network_layer import NetworkingInterfaceBase
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestNetworkingInterfaceBase(UnitTest):
|
||||
def setUp(self):
|
||||
self.base = NetworkingInterfaceBase(MagicMock())
|
||||
|
||||
def test_gettimeout_raises(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
_ = self.base.gettimeout
|
||||
|
||||
def test_settimeout_raises(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.base.settimeout(1.0)
|
||||
|
||||
def test_close_raises(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.base.close()
|
||||
|
||||
def test_is_closing_raises(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.base.is_closing()
|
||||
|
||||
def test_get_conn_raises(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
_ = self.base.get_conn
|
||||
|
||||
def test_sock_raises(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
_ = self.base.sock
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -128,6 +128,7 @@ replacements = {
|
||||
"SpecRunnerTask": "SpecRunnerThread",
|
||||
"AsyncMockConnection": "MockConnection",
|
||||
"AsyncMockPool": "MockPool",
|
||||
"AsyncMock": "MagicMock",
|
||||
"StopAsyncIteration": "StopIteration",
|
||||
"create_async_event": "create_event",
|
||||
"async_create_barrier": "create_barrier",
|
||||
@ -190,6 +191,7 @@ def async_only_test(f: str) -> bool:
|
||||
"test_async_loop_safety.py",
|
||||
"test_async_contextvars_reset.py",
|
||||
"test_async_loop_unblocked.py",
|
||||
"test_async_network_layer.py",
|
||||
]
|
||||
|
||||
|
||||
@ -251,6 +253,7 @@ converted_tests = [
|
||||
"test_monitor.py",
|
||||
"test_monitoring.py",
|
||||
"test_mongos_load_balancing.py",
|
||||
"test_network_layer.py",
|
||||
"test_on_demand_csfle.py",
|
||||
"test_pooling.py",
|
||||
"test_raw_bson.py",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user