PYTHON-5078 Convert test.test_discovery_and_monitoring to async (#2093)

Co-authored-by: Noah Stapp <noah@noahstapp.com>
This commit is contained in:
Iris 2025-03-03 10:14:04 -08:00 committed by GitHub
parent 6da1fdbed9
commit 150a3ba756
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 631 additions and 40 deletions

View File

@ -0,0 +1,503 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the topology module."""
from __future__ import annotations
import asyncio
import os
import socketserver
import sys
import threading
from asyncio import StreamReader, StreamWriter
from pathlib import Path
from test.asynchronous.helpers import ConcurrentRunner
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, AsyncUnitTest, unittest
from test.asynchronous.pymongo_mocks import DummyMonitor
from test.asynchronous.unified_format import generate_test_classes
from test.utils import (
CMAPListener,
HeartbeatEventListener,
HeartbeatEventsListListener,
assertion_context,
async_barrier_wait,
async_client_context,
async_create_barrier,
async_get_pool,
async_wait_until,
server_name_to_type,
)
from unittest.mock import patch
from bson import Timestamp, json_util
from pymongo import AsyncMongoClient, common, monitoring
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
from pymongo.errors import (
AutoReconnect,
ConfigurationError,
NetworkTimeout,
NotPrimaryError,
OperationFailure,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _check_command_response, _check_write_command_response
from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent
from pymongo.server_description import SERVER_TYPE, ServerDescription
from pymongo.topology_description import TOPOLOGY_TYPE
from pymongo.uri_parser import parse_uri
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring")
else:
SDAM_PATH = os.path.join(
Path(__file__).resolve().parent.parent,
"discovery_and_monitoring",
)
async def create_mock_topology(uri, monitor_class=DummyMonitor):
parsed_uri = parse_uri(uri)
replica_set_name = None
direct_connection = None
load_balanced = None
if "replicaset" in parsed_uri["options"]:
replica_set_name = parsed_uri["options"]["replicaset"]
if "directConnection" in parsed_uri["options"]:
direct_connection = parsed_uri["options"]["directConnection"]
if "loadBalanced" in parsed_uri["options"]:
load_balanced = parsed_uri["options"]["loadBalanced"]
topology_settings = TopologySettings(
parsed_uri["nodelist"],
replica_set_name=replica_set_name,
monitor_class=monitor_class,
direct_connection=direct_connection,
load_balanced=load_balanced,
)
c = Topology(topology_settings)
await c.open()
return c
async def got_hello(topology, server_address, hello_response):
server_description = ServerDescription(server_address, Hello(hello_response), 0)
await topology.on_change(server_description)
async def got_app_error(topology, app_error):
server_address = common.partition_node(app_error["address"])
server = topology.get_server_by_address(server_address)
error_type = app_error["type"]
generation = app_error.get("generation", server.pool.gen.get_overall())
when = app_error["when"]
max_wire_version = app_error["maxWireVersion"]
# XXX: We could get better test coverage by mocking the errors on the
# Pool/AsyncConnection.
try:
if error_type == "command":
_check_command_response(app_error["response"], max_wire_version)
_check_write_command_response(app_error["response"])
elif error_type == "network":
raise AutoReconnect("mock non-timeout network error")
elif error_type == "timeout":
raise NetworkTimeout("mock network timeout error")
else:
raise AssertionError(f"unknown error type: {error_type}")
raise AssertionError
except (AutoReconnect, NotPrimaryError, OperationFailure) as e:
if when == "beforeHandshakeCompletes":
completed_handshake = False
elif when == "afterHandshakeCompletes":
completed_handshake = True
else:
raise AssertionError(f"Unknown when field {when}")
await topology.handle_error(
server_address,
_ErrorContext(e, max_wire_version, generation, completed_handshake, None),
)
def get_type(topology, hostname):
description = topology.get_server_by_address((hostname, 27017)).description
return description.server_type
class TestAllScenarios(AsyncUnitTest):
pass
def topology_type_name(topology_type):
return TOPOLOGY_TYPE._fields[topology_type]
def server_type_name(server_type):
return SERVER_TYPE._fields[server_type]
def check_outcome(self, topology, outcome):
expected_servers = outcome["servers"]
# Check weak equality before proceeding.
self.assertEqual(len(topology.description.server_descriptions()), len(expected_servers))
if outcome.get("compatible") is False:
with self.assertRaises(ConfigurationError):
topology.description.check_compatible()
else:
# No error.
topology.description.check_compatible()
# Since lengths are equal, every actual server must have a corresponding
# expected server.
for expected_server_address, expected_server in expected_servers.items():
node = common.partition_node(expected_server_address)
self.assertTrue(topology.has_server(node))
actual_server = topology.get_server_by_address(node)
actual_server_description = actual_server.description
expected_server_type = server_name_to_type(expected_server["type"])
self.assertEqual(
server_type_name(expected_server_type),
server_type_name(actual_server_description.server_type),
)
self.assertEqual(expected_server.get("setName"), actual_server_description.replica_set_name)
self.assertEqual(expected_server.get("setVersion"), actual_server_description.set_version)
self.assertEqual(expected_server.get("electionId"), actual_server_description.election_id)
self.assertEqual(
expected_server.get("topologyVersion"), actual_server_description.topology_version
)
expected_pool = expected_server.get("pool")
if expected_pool:
self.assertEqual(expected_pool.get("generation"), actual_server.pool.gen.get_overall())
self.assertEqual(outcome["setName"], topology.description.replica_set_name)
self.assertEqual(
outcome.get("logicalSessionTimeoutMinutes"),
topology.description.logical_session_timeout_minutes,
)
expected_topology_type = getattr(TOPOLOGY_TYPE, outcome["topologyType"])
self.assertEqual(
topology_type_name(expected_topology_type),
topology_type_name(topology.description.topology_type),
)
self.assertEqual(outcome.get("maxSetVersion"), topology.description.max_set_version)
self.assertEqual(outcome.get("maxElectionId"), topology.description.max_election_id)
def create_test(scenario_def):
async def run_scenario(self):
c = await create_mock_topology(scenario_def["uri"])
for i, phase in enumerate(scenario_def["phases"]):
# Including the phase description makes failures easier to debug.
description = phase.get("description", str(i))
with assertion_context(f"phase: {description}"):
for response in phase.get("responses", []):
await got_hello(c, common.partition_node(response[0]), response[1])
for app_error in phase.get("applicationErrors", []):
await got_app_error(c, app_error)
check_outcome(self, c, phase["outcome"])
return run_scenario
def create_tests():
for dirpath, _, filenames in os.walk(SDAM_PATH):
dirname = os.path.split(dirpath)[-1]
# SDAM unified tests are handled separately.
if dirname == "unified":
continue
for filename in filenames:
if os.path.splitext(filename)[1] != ".json":
continue
with open(os.path.join(dirpath, filename)) as scenario_stream:
scenario_def = json_util.loads(scenario_stream.read())
# Construct test from scenario.
new_test = create_test(scenario_def)
test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}"
new_test.__name__ = test_name
setattr(TestAllScenarios, new_test.__name__, new_test)
create_tests()
class TestClusterTimeComparison(AsyncPyMongoTestCase):
async def test_cluster_time_comparison(self):
t = await create_mock_topology("mongodb://host")
async def send_cluster_time(time, inc):
old = t.max_cluster_time()
new = {"clusterTime": Timestamp(time, inc)}
await got_hello(
t,
("host", 27017),
{
"ok": 1,
"minWireVersion": 0,
"maxWireVersion": common.MIN_SUPPORTED_WIRE_VERSION,
"$clusterTime": new,
},
)
actual = t.max_cluster_time()
# We never update $clusterTime from monitoring connections.
self.assertEqual(actual, old)
await send_cluster_time(0, 1)
await send_cluster_time(2, 2)
await send_cluster_time(2, 1)
await send_cluster_time(1, 3)
await send_cluster_time(2, 3)
class TestIgnoreStaleErrors(AsyncIntegrationTest):
async def test_ignore_stale_connection_errors(self):
if not _IS_SYNC and sys.version_info < (3, 11):
self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)")
N_TASKS = 5
barrier = async_create_barrier(N_TASKS, timeout=30)
client = await self.async_rs_or_single_client(minPoolSize=N_TASKS)
# Wait for initial discovery.
await client.admin.command("ping")
pool = await async_get_pool(client)
starting_generation = pool.gen.get_overall()
await async_wait_until(lambda: len(pool.conns) == N_TASKS, "created conns")
async def mock_command(*args, **kwargs):
# Synchronize all tasks to ensure they use the same generation.
await async_barrier_wait(barrier, timeout=30)
raise AutoReconnect("mock AsyncConnection.command error")
for conn in pool.conns:
conn.command = mock_command
async def insert_command(i):
try:
await client.test.command("insert", "test", documents=[{"i": i}])
except AutoReconnect:
pass
tasks = []
for i in range(N_TASKS):
tasks.append(ConcurrentRunner(target=insert_command, args=(i,)))
for t in tasks:
await t.start()
for t in tasks:
await t.join()
# Expect a single pool reset for the network error
self.assertEqual(starting_generation + 1, pool.gen.get_overall())
# Server should be selectable.
await client.admin.command("ping")
class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener):
pass
class TestPoolManagement(AsyncIntegrationTest):
@async_client_context.require_failCommand_appName
async def test_pool_unpause(self):
# This test implements the prose test "AsyncConnection Pool Management"
listener = CMAPHeartbeatListener()
_ = await self.async_single_client(
appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener]
)
# Assert that AsyncConnectionPoolReadyEvent occurs after the first
# ServerHeartbeatSucceededEvent.
await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1)
pool_ready = listener.events_by_type(monitoring.PoolReadyEvent)[0]
hb_succeeded = listener.events_by_type(monitoring.ServerHeartbeatSucceededEvent)[0]
self.assertGreater(listener.events.index(pool_ready), listener.events.index(hb_succeeded))
listener.reset()
fail_hello = {
"mode": {"times": 2},
"data": {
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
"errorCode": 1234,
"appName": "SDAMPoolManagementTest",
},
}
async with self.fail_point(fail_hello):
await listener.async_wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
await listener.async_wait_for_event(monitoring.PoolClearedEvent, 1)
await listener.async_wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1)
class TestServerMonitoringMode(AsyncIntegrationTest):
@async_client_context.require_no_serverless
@async_client_context.require_no_load_balancer
async def asyncSetUp(self):
await super().asyncSetUp()
async def test_rtt_connection_is_enabled_stream(self):
client = await self.async_rs_or_single_client(serverMonitoringMode="stream")
await client.admin.command("ping")
def predicate():
for _, server in client._topology._servers.items():
monitor = server._monitor
if not monitor._stream:
return False
if async_client_context.version >= (4, 4):
if _IS_SYNC:
if monitor._rtt_monitor._executor._thread is None:
return False
else:
if monitor._rtt_monitor._executor._task is None:
return False
else:
if _IS_SYNC:
if monitor._rtt_monitor._executor._thread is not None:
return False
else:
if monitor._rtt_monitor._executor._task is not None:
return False
return True
await async_wait_until(predicate, "find all RTT monitors")
async def test_rtt_connection_is_disabled_poll(self):
client = await self.async_rs_or_single_client(serverMonitoringMode="poll")
await self.assert_rtt_connection_is_disabled(client)
async def test_rtt_connection_is_disabled_auto(self):
envs = [
{"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9"},
{"FUNCTIONS_WORKER_RUNTIME": "python"},
{"K_SERVICE": "gcpservicename"},
{"FUNCTION_NAME": "gcpfunctionname"},
{"VERCEL": "1"},
]
for env in envs:
with patch.dict("os.environ", env):
client = await self.async_rs_or_single_client(serverMonitoringMode="auto")
await self.assert_rtt_connection_is_disabled(client)
async def assert_rtt_connection_is_disabled(self, client):
await client.admin.command("ping")
for _, server in client._topology._servers.items():
monitor = server._monitor
self.assertFalse(monitor._stream)
if _IS_SYNC:
self.assertIsNone(monitor._rtt_monitor._executor._thread)
else:
self.assertIsNone(monitor._rtt_monitor._executor._task)
class MockTCPHandler(socketserver.BaseRequestHandler):
def handle(self):
self.server.events.append("client connected")
if self.request.recv(1024).strip():
self.server.events.append("client hello received")
self.request.close()
class TCPServer(socketserver.TCPServer):
allow_reuse_address = True
def handle_request_and_shutdown(self):
self.handle_request()
self.server_close()
class TestHeartbeatStartOrdering(AsyncPyMongoTestCase):
async def test_heartbeat_start_ordering(self):
events = []
listener = HeartbeatEventsListListener(events)
if _IS_SYNC:
server = TCPServer(("localhost", 9999), MockTCPHandler)
server.events = events
server_thread = ConcurrentRunner(target=server.handle_request_and_shutdown)
await server_thread.start()
_c = await self.simple_client(
"mongodb://localhost:9999",
serverSelectionTimeoutMS=500,
event_listeners=(listener,),
)
await server_thread.join()
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
else:
async def handle_client(reader: StreamReader, writer: StreamWriter):
events.append("client connected")
if (await reader.read(1024)).strip():
events.append("client hello received")
writer.close()
await writer.wait_closed()
server = await asyncio.start_server(handle_client, "localhost", 9999)
server.events = events
await server.start_serving()
_c = self.simple_client(
"mongodb://localhost:9999",
serverSelectionTimeoutMS=500,
event_listeners=(listener,),
)
await _c.aconnect()
await listener.async_wait_for_event(ServerHeartbeatStartedEvent, 1)
await listener.async_wait_for_event(ServerHeartbeatFailedEvent, 1)
server.close()
await server.wait_closed()
await _c.close()
self.assertEqual(
events,
[
"serverHeartbeatStartedEvent",
"client connected",
"client hello received",
"serverHeartbeatFailedEvent",
],
)
# Generate unified tests.
globals().update(generate_test_classes(os.path.join(SDAM_PATH, "unified"), module=__name__))
if __name__ == "__main__":
unittest.main()

View File

@ -544,6 +544,14 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
self.skipTest("Implement PYTHON-1894")
if "timeoutMS applied to entire download" in spec["description"]:
self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime")
if (
"Error returned from connection pool clear with interruptInUseConnections=true is retryable"
in spec["description"]
and not _IS_SYNC
):
self.skipTest("PYTHON-5170 tests are flakey")
if "Driver extends timeout while streaming" in spec["description"] and not _IS_SYNC:
self.skipTest("PYTHON-5174 tests are flakey")
class_name = self.__class__.__name__.lower()
description = spec["description"].lower()
@ -1151,7 +1159,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
self.assertIsInstance(description, TopologyDescription)
self.assertEqual(description.topology_type_name, spec["topologyType"])
def _testOperation_waitForPrimaryChange(self, spec: dict) -> None:
async def _testOperation_waitForPrimaryChange(self, spec: dict) -> None:
"""Run the waitForPrimaryChange test operation."""
client = self.entity_map[spec["client"]]
old_description: TopologyDescription = self.entity_map[spec["priorTopologyDescription"]]
@ -1165,13 +1173,13 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
old_primary = get_primary(old_description)
def primary_changed() -> bool:
primary = client.primary
async def primary_changed() -> bool:
primary = await client.primary
if primary is None:
return False
return primary != old_primary
wait_until(primary_changed, "change primary", timeout=timeout)
await async_wait_until(primary_changed, "change primary", timeout=timeout)
async def _testOperation_runOnThread(self, spec):
"""Run the 'runOnThread' operation."""

View File

@ -15,14 +15,18 @@
"""Test the topology module."""
from __future__ import annotations
import asyncio
import os
import socketserver
import sys
import threading
from asyncio import StreamReader, StreamWriter
from pathlib import Path
from test.helpers import ConcurrentRunner
sys.path[0:0] = [""]
from test import IntegrationTest, PyMongoTestCase, unittest
from test import IntegrationTest, PyMongoTestCase, UnitTest, unittest
from test.pymongo_mocks import DummyMonitor
from test.unified_format import generate_test_classes
from test.utils import (
@ -30,7 +34,9 @@ from test.utils import (
HeartbeatEventListener,
HeartbeatEventsListListener,
assertion_context,
barrier_wait,
client_context,
create_barrier,
get_pool,
server_name_to_type,
wait_until,
@ -55,8 +61,16 @@ from pymongo.synchronous.topology import Topology, _ErrorContext
from pymongo.topology_description import TOPOLOGY_TYPE
from pymongo.uri_parser import parse_uri
_IS_SYNC = True
# Location of JSON test specifications.
SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring")
if _IS_SYNC:
SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring")
else:
SDAM_PATH = os.path.join(
Path(__file__).resolve().parent.parent,
"discovery_and_monitoring",
)
def create_mock_topology(uri, monitor_class=DummyMonitor):
@ -128,7 +142,7 @@ def get_type(topology, hostname):
return description.server_type
class TestAllScenarios(unittest.TestCase):
class TestAllScenarios(UnitTest):
pass
@ -240,7 +254,7 @@ def create_tests():
create_tests()
class TestClusterTimeComparison(unittest.TestCase):
class TestClusterTimeComparison(PyMongoTestCase):
def test_cluster_time_comparison(self):
t = create_mock_topology("mongodb://host")
@ -271,20 +285,21 @@ class TestClusterTimeComparison(unittest.TestCase):
class TestIgnoreStaleErrors(IntegrationTest):
def test_ignore_stale_connection_errors(self):
N_THREADS = 5
barrier = threading.Barrier(N_THREADS, timeout=30)
client = self.rs_or_single_client(minPoolSize=N_THREADS)
self.addCleanup(client.close)
if not _IS_SYNC and sys.version_info < (3, 11):
self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)")
N_TASKS = 5
barrier = create_barrier(N_TASKS, timeout=30)
client = self.rs_or_single_client(minPoolSize=N_TASKS)
# Wait for initial discovery.
client.admin.command("ping")
pool = get_pool(client)
starting_generation = pool.gen.get_overall()
wait_until(lambda: len(pool.conns) == N_THREADS, "created conns")
wait_until(lambda: len(pool.conns) == N_TASKS, "created conns")
def mock_command(*args, **kwargs):
# Synchronize all threads to ensure they use the same generation.
barrier.wait()
# Synchronize all tasks to ensure they use the same generation.
barrier_wait(barrier, timeout=30)
raise AutoReconnect("mock Connection.command error")
for conn in pool.conns:
@ -296,12 +311,12 @@ class TestIgnoreStaleErrors(IntegrationTest):
except AutoReconnect:
pass
threads = []
for i in range(N_THREADS):
threads.append(threading.Thread(target=insert_command, args=(i,)))
for t in threads:
tasks = []
for i in range(N_TASKS):
tasks.append(ConcurrentRunner(target=insert_command, args=(i,)))
for t in tasks:
t.start()
for t in threads:
for t in tasks:
t.join()
# Expect a single pool reset for the network error
@ -320,10 +335,9 @@ class TestPoolManagement(IntegrationTest):
def test_pool_unpause(self):
# This test implements the prose test "Connection Pool Management"
listener = CMAPHeartbeatListener()
client = self.single_client(
_ = self.single_client(
appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener]
)
self.addCleanup(client.close)
# Assert that ConnectionPoolReadyEvent occurs after the first
# ServerHeartbeatSucceededEvent.
listener.wait_for_event(monitoring.PoolReadyEvent, 1)
@ -355,7 +369,6 @@ class TestServerMonitoringMode(IntegrationTest):
def test_rtt_connection_is_enabled_stream(self):
client = self.rs_or_single_client(serverMonitoringMode="stream")
self.addCleanup(client.close)
client.admin.command("ping")
def predicate():
@ -364,18 +377,26 @@ class TestServerMonitoringMode(IntegrationTest):
if not monitor._stream:
return False
if client_context.version >= (4, 4):
if monitor._rtt_monitor._executor._thread is None:
return False
if _IS_SYNC:
if monitor._rtt_monitor._executor._thread is None:
return False
else:
if monitor._rtt_monitor._executor._task is None:
return False
else:
if monitor._rtt_monitor._executor._thread is not None:
return False
if _IS_SYNC:
if monitor._rtt_monitor._executor._thread is not None:
return False
else:
if monitor._rtt_monitor._executor._task is not None:
return False
return True
wait_until(predicate, "find all RTT monitors")
def test_rtt_connection_is_disabled_poll(self):
client = self.rs_or_single_client(serverMonitoringMode="poll")
self.addCleanup(client.close)
self.assert_rtt_connection_is_disabled(client)
def test_rtt_connection_is_disabled_auto(self):
@ -389,7 +410,6 @@ class TestServerMonitoringMode(IntegrationTest):
for env in envs:
with patch.dict("os.environ", env):
client = self.rs_or_single_client(serverMonitoringMode="auto")
self.addCleanup(client.close)
self.assert_rtt_connection_is_disabled(client)
def assert_rtt_connection_is_disabled(self, client):
@ -397,7 +417,10 @@ class TestServerMonitoringMode(IntegrationTest):
for _, server in client._topology._servers.items():
monitor = server._monitor
self.assertFalse(monitor._stream)
self.assertIsNone(monitor._rtt_monitor._executor._thread)
if _IS_SYNC:
self.assertIsNone(monitor._rtt_monitor._executor._thread)
else:
self.assertIsNone(monitor._rtt_monitor._executor._task)
class MockTCPHandler(socketserver.BaseRequestHandler):
@ -420,16 +443,46 @@ class TestHeartbeatStartOrdering(PyMongoTestCase):
def test_heartbeat_start_ordering(self):
events = []
listener = HeartbeatEventsListListener(events)
server = TCPServer(("localhost", 9999), MockTCPHandler)
server.events = events
server_thread = threading.Thread(target=server.handle_request_and_shutdown)
server_thread.start()
_c = self.simple_client(
"mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,)
)
server_thread.join()
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
if _IS_SYNC:
server = TCPServer(("localhost", 9999), MockTCPHandler)
server.events = events
server_thread = ConcurrentRunner(target=server.handle_request_and_shutdown)
server_thread.start()
_c = self.simple_client(
"mongodb://localhost:9999",
serverSelectionTimeoutMS=500,
event_listeners=(listener,),
)
server_thread.join()
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
else:
def handle_client(reader: StreamReader, writer: StreamWriter):
events.append("client connected")
if (reader.read(1024)).strip():
events.append("client hello received")
writer.close()
writer.wait_closed()
server = asyncio.start_server(handle_client, "localhost", 9999)
server.events = events
server.start_serving()
_c = self.simple_client(
"mongodb://localhost:9999",
serverSelectionTimeoutMS=500,
event_listeners=(listener,),
)
_c._connect()
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
server.close()
server.wait_closed()
_c.close()
self.assertEqual(
events,

View File

@ -543,6 +543,14 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
self.skipTest("Implement PYTHON-1894")
if "timeoutMS applied to entire download" in spec["description"]:
self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime")
if (
"Error returned from connection pool clear with interruptInUseConnections=true is retryable"
in spec["description"]
and not _IS_SYNC
):
self.skipTest("PYTHON-5170 tests are flakey")
if "Driver extends timeout while streaming" in spec["description"] and not _IS_SYNC:
self.skipTest("PYTHON-5174 tests are flakey")
class_name = self.__class__.__name__.lower()
description = spec["description"].lower()

View File

@ -1078,3 +1078,19 @@ def create_async_event():
def create_event():
return threading.Event()
def async_create_barrier(N_TASKS, timeout: float | None = None):
return asyncio.Barrier(N_TASKS)
def create_barrier(N_TASKS, timeout: float | None = None):
return threading.Barrier(N_TASKS, timeout=timeout)
async def async_barrier_wait(barrier, timeout: float | None = None):
await asyncio.wait_for(barrier.wait(), timeout=timeout)
def barrier_wait(barrier, timeout: float | None = None):
barrier.wait()

View File

@ -124,6 +124,8 @@ replacements = {
"AsyncMockPool": "MockPool",
"StopAsyncIteration": "StopIteration",
"create_async_event": "create_event",
"async_create_barrier": "create_barrier",
"async_barrier_wait": "barrier_wait",
"async_joinall": "joinall",
}
@ -213,6 +215,7 @@ converted_tests = [
"test_custom_types.py",
"test_database.py",
"test_data_lake.py",
"test_discovery_and_monitoring.py",
"test_dns.py",
"test_encryption.py",
"test_examples.py",