PYTHON-5078 Convert test.test_discovery_and_monitoring to async (#2093)
Co-authored-by: Noah Stapp <noah@noahstapp.com>
This commit is contained in:
parent
6da1fdbed9
commit
150a3ba756
503
test/asynchronous/test_discovery_and_monitoring.py
Normal file
503
test/asynchronous/test_discovery_and_monitoring.py
Normal file
@ -0,0 +1,503 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the topology module."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import socketserver
|
||||
import sys
|
||||
import threading
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
from pathlib import Path
|
||||
from test.asynchronous.helpers import ConcurrentRunner
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, AsyncUnitTest, unittest
|
||||
from test.asynchronous.pymongo_mocks import DummyMonitor
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
HeartbeatEventListener,
|
||||
HeartbeatEventsListListener,
|
||||
assertion_context,
|
||||
async_barrier_wait,
|
||||
async_client_context,
|
||||
async_create_barrier,
|
||||
async_get_pool,
|
||||
async_wait_until,
|
||||
server_name_to_type,
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
from bson import Timestamp, json_util
|
||||
from pymongo import AsyncMongoClient, common, monitoring
|
||||
from pymongo.asynchronous.settings import TopologySettings
|
||||
from pymongo.asynchronous.topology import Topology, _ErrorContext
|
||||
from pymongo.errors import (
|
||||
AutoReconnect,
|
||||
ConfigurationError,
|
||||
NetworkTimeout,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.helpers_shared import _check_command_response, _check_write_command_response
|
||||
from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent
|
||||
from pymongo.server_description import SERVER_TYPE, ServerDescription
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring")
|
||||
else:
|
||||
SDAM_PATH = os.path.join(
|
||||
Path(__file__).resolve().parent.parent,
|
||||
"discovery_and_monitoring",
|
||||
)
|
||||
|
||||
|
||||
async def create_mock_topology(uri, monitor_class=DummyMonitor):
|
||||
parsed_uri = parse_uri(uri)
|
||||
replica_set_name = None
|
||||
direct_connection = None
|
||||
load_balanced = None
|
||||
if "replicaset" in parsed_uri["options"]:
|
||||
replica_set_name = parsed_uri["options"]["replicaset"]
|
||||
if "directConnection" in parsed_uri["options"]:
|
||||
direct_connection = parsed_uri["options"]["directConnection"]
|
||||
if "loadBalanced" in parsed_uri["options"]:
|
||||
load_balanced = parsed_uri["options"]["loadBalanced"]
|
||||
|
||||
topology_settings = TopologySettings(
|
||||
parsed_uri["nodelist"],
|
||||
replica_set_name=replica_set_name,
|
||||
monitor_class=monitor_class,
|
||||
direct_connection=direct_connection,
|
||||
load_balanced=load_balanced,
|
||||
)
|
||||
|
||||
c = Topology(topology_settings)
|
||||
await c.open()
|
||||
return c
|
||||
|
||||
|
||||
async def got_hello(topology, server_address, hello_response):
|
||||
server_description = ServerDescription(server_address, Hello(hello_response), 0)
|
||||
await topology.on_change(server_description)
|
||||
|
||||
|
||||
async def got_app_error(topology, app_error):
|
||||
server_address = common.partition_node(app_error["address"])
|
||||
server = topology.get_server_by_address(server_address)
|
||||
error_type = app_error["type"]
|
||||
generation = app_error.get("generation", server.pool.gen.get_overall())
|
||||
when = app_error["when"]
|
||||
max_wire_version = app_error["maxWireVersion"]
|
||||
# XXX: We could get better test coverage by mocking the errors on the
|
||||
# Pool/AsyncConnection.
|
||||
try:
|
||||
if error_type == "command":
|
||||
_check_command_response(app_error["response"], max_wire_version)
|
||||
_check_write_command_response(app_error["response"])
|
||||
elif error_type == "network":
|
||||
raise AutoReconnect("mock non-timeout network error")
|
||||
elif error_type == "timeout":
|
||||
raise NetworkTimeout("mock network timeout error")
|
||||
else:
|
||||
raise AssertionError(f"unknown error type: {error_type}")
|
||||
raise AssertionError
|
||||
except (AutoReconnect, NotPrimaryError, OperationFailure) as e:
|
||||
if when == "beforeHandshakeCompletes":
|
||||
completed_handshake = False
|
||||
elif when == "afterHandshakeCompletes":
|
||||
completed_handshake = True
|
||||
else:
|
||||
raise AssertionError(f"Unknown when field {when}")
|
||||
|
||||
await topology.handle_error(
|
||||
server_address,
|
||||
_ErrorContext(e, max_wire_version, generation, completed_handshake, None),
|
||||
)
|
||||
|
||||
|
||||
def get_type(topology, hostname):
|
||||
description = topology.get_server_by_address((hostname, 27017)).description
|
||||
return description.server_type
|
||||
|
||||
|
||||
class TestAllScenarios(AsyncUnitTest):
|
||||
pass
|
||||
|
||||
|
||||
def topology_type_name(topology_type):
|
||||
return TOPOLOGY_TYPE._fields[topology_type]
|
||||
|
||||
|
||||
def server_type_name(server_type):
|
||||
return SERVER_TYPE._fields[server_type]
|
||||
|
||||
|
||||
def check_outcome(self, topology, outcome):
|
||||
expected_servers = outcome["servers"]
|
||||
|
||||
# Check weak equality before proceeding.
|
||||
self.assertEqual(len(topology.description.server_descriptions()), len(expected_servers))
|
||||
|
||||
if outcome.get("compatible") is False:
|
||||
with self.assertRaises(ConfigurationError):
|
||||
topology.description.check_compatible()
|
||||
else:
|
||||
# No error.
|
||||
topology.description.check_compatible()
|
||||
|
||||
# Since lengths are equal, every actual server must have a corresponding
|
||||
# expected server.
|
||||
for expected_server_address, expected_server in expected_servers.items():
|
||||
node = common.partition_node(expected_server_address)
|
||||
self.assertTrue(topology.has_server(node))
|
||||
actual_server = topology.get_server_by_address(node)
|
||||
actual_server_description = actual_server.description
|
||||
expected_server_type = server_name_to_type(expected_server["type"])
|
||||
|
||||
self.assertEqual(
|
||||
server_type_name(expected_server_type),
|
||||
server_type_name(actual_server_description.server_type),
|
||||
)
|
||||
|
||||
self.assertEqual(expected_server.get("setName"), actual_server_description.replica_set_name)
|
||||
|
||||
self.assertEqual(expected_server.get("setVersion"), actual_server_description.set_version)
|
||||
|
||||
self.assertEqual(expected_server.get("electionId"), actual_server_description.election_id)
|
||||
|
||||
self.assertEqual(
|
||||
expected_server.get("topologyVersion"), actual_server_description.topology_version
|
||||
)
|
||||
|
||||
expected_pool = expected_server.get("pool")
|
||||
if expected_pool:
|
||||
self.assertEqual(expected_pool.get("generation"), actual_server.pool.gen.get_overall())
|
||||
|
||||
self.assertEqual(outcome["setName"], topology.description.replica_set_name)
|
||||
self.assertEqual(
|
||||
outcome.get("logicalSessionTimeoutMinutes"),
|
||||
topology.description.logical_session_timeout_minutes,
|
||||
)
|
||||
|
||||
expected_topology_type = getattr(TOPOLOGY_TYPE, outcome["topologyType"])
|
||||
self.assertEqual(
|
||||
topology_type_name(expected_topology_type),
|
||||
topology_type_name(topology.description.topology_type),
|
||||
)
|
||||
|
||||
self.assertEqual(outcome.get("maxSetVersion"), topology.description.max_set_version)
|
||||
self.assertEqual(outcome.get("maxElectionId"), topology.description.max_election_id)
|
||||
|
||||
|
||||
def create_test(scenario_def):
|
||||
async def run_scenario(self):
|
||||
c = await create_mock_topology(scenario_def["uri"])
|
||||
|
||||
for i, phase in enumerate(scenario_def["phases"]):
|
||||
# Including the phase description makes failures easier to debug.
|
||||
description = phase.get("description", str(i))
|
||||
with assertion_context(f"phase: {description}"):
|
||||
for response in phase.get("responses", []):
|
||||
await got_hello(c, common.partition_node(response[0]), response[1])
|
||||
|
||||
for app_error in phase.get("applicationErrors", []):
|
||||
await got_app_error(c, app_error)
|
||||
|
||||
check_outcome(self, c, phase["outcome"])
|
||||
|
||||
return run_scenario
|
||||
|
||||
|
||||
def create_tests():
|
||||
for dirpath, _, filenames in os.walk(SDAM_PATH):
|
||||
dirname = os.path.split(dirpath)[-1]
|
||||
# SDAM unified tests are handled separately.
|
||||
if dirname == "unified":
|
||||
continue
|
||||
|
||||
for filename in filenames:
|
||||
if os.path.splitext(filename)[1] != ".json":
|
||||
continue
|
||||
with open(os.path.join(dirpath, filename)) as scenario_stream:
|
||||
scenario_def = json_util.loads(scenario_stream.read())
|
||||
|
||||
# Construct test from scenario.
|
||||
new_test = create_test(scenario_def)
|
||||
test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}"
|
||||
|
||||
new_test.__name__ = test_name
|
||||
setattr(TestAllScenarios, new_test.__name__, new_test)
|
||||
|
||||
|
||||
create_tests()
|
||||
|
||||
|
||||
class TestClusterTimeComparison(AsyncPyMongoTestCase):
|
||||
async def test_cluster_time_comparison(self):
|
||||
t = await create_mock_topology("mongodb://host")
|
||||
|
||||
async def send_cluster_time(time, inc):
|
||||
old = t.max_cluster_time()
|
||||
new = {"clusterTime": Timestamp(time, inc)}
|
||||
await got_hello(
|
||||
t,
|
||||
("host", 27017),
|
||||
{
|
||||
"ok": 1,
|
||||
"minWireVersion": 0,
|
||||
"maxWireVersion": common.MIN_SUPPORTED_WIRE_VERSION,
|
||||
"$clusterTime": new,
|
||||
},
|
||||
)
|
||||
|
||||
actual = t.max_cluster_time()
|
||||
# We never update $clusterTime from monitoring connections.
|
||||
self.assertEqual(actual, old)
|
||||
|
||||
await send_cluster_time(0, 1)
|
||||
await send_cluster_time(2, 2)
|
||||
await send_cluster_time(2, 1)
|
||||
await send_cluster_time(1, 3)
|
||||
await send_cluster_time(2, 3)
|
||||
|
||||
|
||||
class TestIgnoreStaleErrors(AsyncIntegrationTest):
|
||||
async def test_ignore_stale_connection_errors(self):
|
||||
if not _IS_SYNC and sys.version_info < (3, 11):
|
||||
self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)")
|
||||
N_TASKS = 5
|
||||
barrier = async_create_barrier(N_TASKS, timeout=30)
|
||||
client = await self.async_rs_or_single_client(minPoolSize=N_TASKS)
|
||||
|
||||
# Wait for initial discovery.
|
||||
await client.admin.command("ping")
|
||||
pool = await async_get_pool(client)
|
||||
starting_generation = pool.gen.get_overall()
|
||||
await async_wait_until(lambda: len(pool.conns) == N_TASKS, "created conns")
|
||||
|
||||
async def mock_command(*args, **kwargs):
|
||||
# Synchronize all tasks to ensure they use the same generation.
|
||||
await async_barrier_wait(barrier, timeout=30)
|
||||
raise AutoReconnect("mock AsyncConnection.command error")
|
||||
|
||||
for conn in pool.conns:
|
||||
conn.command = mock_command
|
||||
|
||||
async def insert_command(i):
|
||||
try:
|
||||
await client.test.command("insert", "test", documents=[{"i": i}])
|
||||
except AutoReconnect:
|
||||
pass
|
||||
|
||||
tasks = []
|
||||
for i in range(N_TASKS):
|
||||
tasks.append(ConcurrentRunner(target=insert_command, args=(i,)))
|
||||
for t in tasks:
|
||||
await t.start()
|
||||
for t in tasks:
|
||||
await t.join()
|
||||
|
||||
# Expect a single pool reset for the network error
|
||||
self.assertEqual(starting_generation + 1, pool.gen.get_overall())
|
||||
|
||||
# Server should be selectable.
|
||||
await client.admin.command("ping")
|
||||
|
||||
|
||||
class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener):
|
||||
pass
|
||||
|
||||
|
||||
class TestPoolManagement(AsyncIntegrationTest):
|
||||
@async_client_context.require_failCommand_appName
|
||||
async def test_pool_unpause(self):
|
||||
# This test implements the prose test "AsyncConnection Pool Management"
|
||||
listener = CMAPHeartbeatListener()
|
||||
_ = await self.async_single_client(
|
||||
appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener]
|
||||
)
|
||||
# Assert that AsyncConnectionPoolReadyEvent occurs after the first
|
||||
# ServerHeartbeatSucceededEvent.
|
||||
await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1)
|
||||
pool_ready = listener.events_by_type(monitoring.PoolReadyEvent)[0]
|
||||
hb_succeeded = listener.events_by_type(monitoring.ServerHeartbeatSucceededEvent)[0]
|
||||
self.assertGreater(listener.events.index(pool_ready), listener.events.index(hb_succeeded))
|
||||
|
||||
listener.reset()
|
||||
fail_hello = {
|
||||
"mode": {"times": 2},
|
||||
"data": {
|
||||
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
|
||||
"errorCode": 1234,
|
||||
"appName": "SDAMPoolManagementTest",
|
||||
},
|
||||
}
|
||||
async with self.fail_point(fail_hello):
|
||||
await listener.async_wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
|
||||
await listener.async_wait_for_event(monitoring.PoolClearedEvent, 1)
|
||||
await listener.async_wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
|
||||
await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1)
|
||||
|
||||
|
||||
class TestServerMonitoringMode(AsyncIntegrationTest):
|
||||
@async_client_context.require_no_serverless
|
||||
@async_client_context.require_no_load_balancer
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
|
||||
async def test_rtt_connection_is_enabled_stream(self):
|
||||
client = await self.async_rs_or_single_client(serverMonitoringMode="stream")
|
||||
await client.admin.command("ping")
|
||||
|
||||
def predicate():
|
||||
for _, server in client._topology._servers.items():
|
||||
monitor = server._monitor
|
||||
if not monitor._stream:
|
||||
return False
|
||||
if async_client_context.version >= (4, 4):
|
||||
if _IS_SYNC:
|
||||
if monitor._rtt_monitor._executor._thread is None:
|
||||
return False
|
||||
else:
|
||||
if monitor._rtt_monitor._executor._task is None:
|
||||
return False
|
||||
else:
|
||||
if _IS_SYNC:
|
||||
if monitor._rtt_monitor._executor._thread is not None:
|
||||
return False
|
||||
else:
|
||||
if monitor._rtt_monitor._executor._task is not None:
|
||||
return False
|
||||
return True
|
||||
|
||||
await async_wait_until(predicate, "find all RTT monitors")
|
||||
|
||||
async def test_rtt_connection_is_disabled_poll(self):
|
||||
client = await self.async_rs_or_single_client(serverMonitoringMode="poll")
|
||||
|
||||
await self.assert_rtt_connection_is_disabled(client)
|
||||
|
||||
async def test_rtt_connection_is_disabled_auto(self):
|
||||
envs = [
|
||||
{"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9"},
|
||||
{"FUNCTIONS_WORKER_RUNTIME": "python"},
|
||||
{"K_SERVICE": "gcpservicename"},
|
||||
{"FUNCTION_NAME": "gcpfunctionname"},
|
||||
{"VERCEL": "1"},
|
||||
]
|
||||
for env in envs:
|
||||
with patch.dict("os.environ", env):
|
||||
client = await self.async_rs_or_single_client(serverMonitoringMode="auto")
|
||||
await self.assert_rtt_connection_is_disabled(client)
|
||||
|
||||
async def assert_rtt_connection_is_disabled(self, client):
|
||||
await client.admin.command("ping")
|
||||
for _, server in client._topology._servers.items():
|
||||
monitor = server._monitor
|
||||
self.assertFalse(monitor._stream)
|
||||
if _IS_SYNC:
|
||||
self.assertIsNone(monitor._rtt_monitor._executor._thread)
|
||||
else:
|
||||
self.assertIsNone(monitor._rtt_monitor._executor._task)
|
||||
|
||||
|
||||
class MockTCPHandler(socketserver.BaseRequestHandler):
|
||||
def handle(self):
|
||||
self.server.events.append("client connected")
|
||||
if self.request.recv(1024).strip():
|
||||
self.server.events.append("client hello received")
|
||||
self.request.close()
|
||||
|
||||
|
||||
class TCPServer(socketserver.TCPServer):
|
||||
allow_reuse_address = True
|
||||
|
||||
def handle_request_and_shutdown(self):
|
||||
self.handle_request()
|
||||
self.server_close()
|
||||
|
||||
|
||||
class TestHeartbeatStartOrdering(AsyncPyMongoTestCase):
|
||||
async def test_heartbeat_start_ordering(self):
|
||||
events = []
|
||||
listener = HeartbeatEventsListListener(events)
|
||||
|
||||
if _IS_SYNC:
|
||||
server = TCPServer(("localhost", 9999), MockTCPHandler)
|
||||
server.events = events
|
||||
server_thread = ConcurrentRunner(target=server.handle_request_and_shutdown)
|
||||
await server_thread.start()
|
||||
_c = await self.simple_client(
|
||||
"mongodb://localhost:9999",
|
||||
serverSelectionTimeoutMS=500,
|
||||
event_listeners=(listener,),
|
||||
)
|
||||
await server_thread.join()
|
||||
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
|
||||
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
|
||||
|
||||
else:
|
||||
|
||||
async def handle_client(reader: StreamReader, writer: StreamWriter):
|
||||
events.append("client connected")
|
||||
if (await reader.read(1024)).strip():
|
||||
events.append("client hello received")
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
server = await asyncio.start_server(handle_client, "localhost", 9999)
|
||||
server.events = events
|
||||
await server.start_serving()
|
||||
_c = self.simple_client(
|
||||
"mongodb://localhost:9999",
|
||||
serverSelectionTimeoutMS=500,
|
||||
event_listeners=(listener,),
|
||||
)
|
||||
await _c.aconnect()
|
||||
|
||||
await listener.async_wait_for_event(ServerHeartbeatStartedEvent, 1)
|
||||
await listener.async_wait_for_event(ServerHeartbeatFailedEvent, 1)
|
||||
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
await _c.close()
|
||||
|
||||
self.assertEqual(
|
||||
events,
|
||||
[
|
||||
"serverHeartbeatStartedEvent",
|
||||
"client connected",
|
||||
"client hello received",
|
||||
"serverHeartbeatFailedEvent",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(os.path.join(SDAM_PATH, "unified"), module=__name__))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -544,6 +544,14 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
self.skipTest("Implement PYTHON-1894")
|
||||
if "timeoutMS applied to entire download" in spec["description"]:
|
||||
self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime")
|
||||
if (
|
||||
"Error returned from connection pool clear with interruptInUseConnections=true is retryable"
|
||||
in spec["description"]
|
||||
and not _IS_SYNC
|
||||
):
|
||||
self.skipTest("PYTHON-5170 tests are flakey")
|
||||
if "Driver extends timeout while streaming" in spec["description"] and not _IS_SYNC:
|
||||
self.skipTest("PYTHON-5174 tests are flakey")
|
||||
|
||||
class_name = self.__class__.__name__.lower()
|
||||
description = spec["description"].lower()
|
||||
@ -1151,7 +1159,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
self.assertIsInstance(description, TopologyDescription)
|
||||
self.assertEqual(description.topology_type_name, spec["topologyType"])
|
||||
|
||||
def _testOperation_waitForPrimaryChange(self, spec: dict) -> None:
|
||||
async def _testOperation_waitForPrimaryChange(self, spec: dict) -> None:
|
||||
"""Run the waitForPrimaryChange test operation."""
|
||||
client = self.entity_map[spec["client"]]
|
||||
old_description: TopologyDescription = self.entity_map[spec["priorTopologyDescription"]]
|
||||
@ -1165,13 +1173,13 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
|
||||
old_primary = get_primary(old_description)
|
||||
|
||||
def primary_changed() -> bool:
|
||||
primary = client.primary
|
||||
async def primary_changed() -> bool:
|
||||
primary = await client.primary
|
||||
if primary is None:
|
||||
return False
|
||||
return primary != old_primary
|
||||
|
||||
wait_until(primary_changed, "change primary", timeout=timeout)
|
||||
await async_wait_until(primary_changed, "change primary", timeout=timeout)
|
||||
|
||||
async def _testOperation_runOnThread(self, spec):
|
||||
"""Run the 'runOnThread' operation."""
|
||||
|
||||
@ -15,14 +15,18 @@
|
||||
"""Test the topology module."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import socketserver
|
||||
import sys
|
||||
import threading
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
from pathlib import Path
|
||||
from test.helpers import ConcurrentRunner
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, PyMongoTestCase, unittest
|
||||
from test import IntegrationTest, PyMongoTestCase, UnitTest, unittest
|
||||
from test.pymongo_mocks import DummyMonitor
|
||||
from test.unified_format import generate_test_classes
|
||||
from test.utils import (
|
||||
@ -30,7 +34,9 @@ from test.utils import (
|
||||
HeartbeatEventListener,
|
||||
HeartbeatEventsListListener,
|
||||
assertion_context,
|
||||
barrier_wait,
|
||||
client_context,
|
||||
create_barrier,
|
||||
get_pool,
|
||||
server_name_to_type,
|
||||
wait_until,
|
||||
@ -55,8 +61,16 @@ from pymongo.synchronous.topology import Topology, _ErrorContext
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring")
|
||||
if _IS_SYNC:
|
||||
SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring")
|
||||
else:
|
||||
SDAM_PATH = os.path.join(
|
||||
Path(__file__).resolve().parent.parent,
|
||||
"discovery_and_monitoring",
|
||||
)
|
||||
|
||||
|
||||
def create_mock_topology(uri, monitor_class=DummyMonitor):
|
||||
@ -128,7 +142,7 @@ def get_type(topology, hostname):
|
||||
return description.server_type
|
||||
|
||||
|
||||
class TestAllScenarios(unittest.TestCase):
|
||||
class TestAllScenarios(UnitTest):
|
||||
pass
|
||||
|
||||
|
||||
@ -240,7 +254,7 @@ def create_tests():
|
||||
create_tests()
|
||||
|
||||
|
||||
class TestClusterTimeComparison(unittest.TestCase):
|
||||
class TestClusterTimeComparison(PyMongoTestCase):
|
||||
def test_cluster_time_comparison(self):
|
||||
t = create_mock_topology("mongodb://host")
|
||||
|
||||
@ -271,20 +285,21 @@ class TestClusterTimeComparison(unittest.TestCase):
|
||||
|
||||
class TestIgnoreStaleErrors(IntegrationTest):
|
||||
def test_ignore_stale_connection_errors(self):
|
||||
N_THREADS = 5
|
||||
barrier = threading.Barrier(N_THREADS, timeout=30)
|
||||
client = self.rs_or_single_client(minPoolSize=N_THREADS)
|
||||
self.addCleanup(client.close)
|
||||
if not _IS_SYNC and sys.version_info < (3, 11):
|
||||
self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)")
|
||||
N_TASKS = 5
|
||||
barrier = create_barrier(N_TASKS, timeout=30)
|
||||
client = self.rs_or_single_client(minPoolSize=N_TASKS)
|
||||
|
||||
# Wait for initial discovery.
|
||||
client.admin.command("ping")
|
||||
pool = get_pool(client)
|
||||
starting_generation = pool.gen.get_overall()
|
||||
wait_until(lambda: len(pool.conns) == N_THREADS, "created conns")
|
||||
wait_until(lambda: len(pool.conns) == N_TASKS, "created conns")
|
||||
|
||||
def mock_command(*args, **kwargs):
|
||||
# Synchronize all threads to ensure they use the same generation.
|
||||
barrier.wait()
|
||||
# Synchronize all tasks to ensure they use the same generation.
|
||||
barrier_wait(barrier, timeout=30)
|
||||
raise AutoReconnect("mock Connection.command error")
|
||||
|
||||
for conn in pool.conns:
|
||||
@ -296,12 +311,12 @@ class TestIgnoreStaleErrors(IntegrationTest):
|
||||
except AutoReconnect:
|
||||
pass
|
||||
|
||||
threads = []
|
||||
for i in range(N_THREADS):
|
||||
threads.append(threading.Thread(target=insert_command, args=(i,)))
|
||||
for t in threads:
|
||||
tasks = []
|
||||
for i in range(N_TASKS):
|
||||
tasks.append(ConcurrentRunner(target=insert_command, args=(i,)))
|
||||
for t in tasks:
|
||||
t.start()
|
||||
for t in threads:
|
||||
for t in tasks:
|
||||
t.join()
|
||||
|
||||
# Expect a single pool reset for the network error
|
||||
@ -320,10 +335,9 @@ class TestPoolManagement(IntegrationTest):
|
||||
def test_pool_unpause(self):
|
||||
# This test implements the prose test "Connection Pool Management"
|
||||
listener = CMAPHeartbeatListener()
|
||||
client = self.single_client(
|
||||
_ = self.single_client(
|
||||
appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener]
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
# Assert that ConnectionPoolReadyEvent occurs after the first
|
||||
# ServerHeartbeatSucceededEvent.
|
||||
listener.wait_for_event(monitoring.PoolReadyEvent, 1)
|
||||
@ -355,7 +369,6 @@ class TestServerMonitoringMode(IntegrationTest):
|
||||
|
||||
def test_rtt_connection_is_enabled_stream(self):
|
||||
client = self.rs_or_single_client(serverMonitoringMode="stream")
|
||||
self.addCleanup(client.close)
|
||||
client.admin.command("ping")
|
||||
|
||||
def predicate():
|
||||
@ -364,18 +377,26 @@ class TestServerMonitoringMode(IntegrationTest):
|
||||
if not monitor._stream:
|
||||
return False
|
||||
if client_context.version >= (4, 4):
|
||||
if monitor._rtt_monitor._executor._thread is None:
|
||||
return False
|
||||
if _IS_SYNC:
|
||||
if monitor._rtt_monitor._executor._thread is None:
|
||||
return False
|
||||
else:
|
||||
if monitor._rtt_monitor._executor._task is None:
|
||||
return False
|
||||
else:
|
||||
if monitor._rtt_monitor._executor._thread is not None:
|
||||
return False
|
||||
if _IS_SYNC:
|
||||
if monitor._rtt_monitor._executor._thread is not None:
|
||||
return False
|
||||
else:
|
||||
if monitor._rtt_monitor._executor._task is not None:
|
||||
return False
|
||||
return True
|
||||
|
||||
wait_until(predicate, "find all RTT monitors")
|
||||
|
||||
def test_rtt_connection_is_disabled_poll(self):
|
||||
client = self.rs_or_single_client(serverMonitoringMode="poll")
|
||||
self.addCleanup(client.close)
|
||||
|
||||
self.assert_rtt_connection_is_disabled(client)
|
||||
|
||||
def test_rtt_connection_is_disabled_auto(self):
|
||||
@ -389,7 +410,6 @@ class TestServerMonitoringMode(IntegrationTest):
|
||||
for env in envs:
|
||||
with patch.dict("os.environ", env):
|
||||
client = self.rs_or_single_client(serverMonitoringMode="auto")
|
||||
self.addCleanup(client.close)
|
||||
self.assert_rtt_connection_is_disabled(client)
|
||||
|
||||
def assert_rtt_connection_is_disabled(self, client):
|
||||
@ -397,7 +417,10 @@ class TestServerMonitoringMode(IntegrationTest):
|
||||
for _, server in client._topology._servers.items():
|
||||
monitor = server._monitor
|
||||
self.assertFalse(monitor._stream)
|
||||
self.assertIsNone(monitor._rtt_monitor._executor._thread)
|
||||
if _IS_SYNC:
|
||||
self.assertIsNone(monitor._rtt_monitor._executor._thread)
|
||||
else:
|
||||
self.assertIsNone(monitor._rtt_monitor._executor._task)
|
||||
|
||||
|
||||
class MockTCPHandler(socketserver.BaseRequestHandler):
|
||||
@ -420,16 +443,46 @@ class TestHeartbeatStartOrdering(PyMongoTestCase):
|
||||
def test_heartbeat_start_ordering(self):
|
||||
events = []
|
||||
listener = HeartbeatEventsListListener(events)
|
||||
server = TCPServer(("localhost", 9999), MockTCPHandler)
|
||||
server.events = events
|
||||
server_thread = threading.Thread(target=server.handle_request_and_shutdown)
|
||||
server_thread.start()
|
||||
_c = self.simple_client(
|
||||
"mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,)
|
||||
)
|
||||
server_thread.join()
|
||||
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
|
||||
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
|
||||
|
||||
if _IS_SYNC:
|
||||
server = TCPServer(("localhost", 9999), MockTCPHandler)
|
||||
server.events = events
|
||||
server_thread = ConcurrentRunner(target=server.handle_request_and_shutdown)
|
||||
server_thread.start()
|
||||
_c = self.simple_client(
|
||||
"mongodb://localhost:9999",
|
||||
serverSelectionTimeoutMS=500,
|
||||
event_listeners=(listener,),
|
||||
)
|
||||
server_thread.join()
|
||||
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
|
||||
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
|
||||
|
||||
else:
|
||||
|
||||
def handle_client(reader: StreamReader, writer: StreamWriter):
|
||||
events.append("client connected")
|
||||
if (reader.read(1024)).strip():
|
||||
events.append("client hello received")
|
||||
writer.close()
|
||||
writer.wait_closed()
|
||||
|
||||
server = asyncio.start_server(handle_client, "localhost", 9999)
|
||||
server.events = events
|
||||
server.start_serving()
|
||||
_c = self.simple_client(
|
||||
"mongodb://localhost:9999",
|
||||
serverSelectionTimeoutMS=500,
|
||||
event_listeners=(listener,),
|
||||
)
|
||||
_c._connect()
|
||||
|
||||
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
|
||||
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
|
||||
|
||||
server.close()
|
||||
server.wait_closed()
|
||||
_c.close()
|
||||
|
||||
self.assertEqual(
|
||||
events,
|
||||
|
||||
@ -543,6 +543,14 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
self.skipTest("Implement PYTHON-1894")
|
||||
if "timeoutMS applied to entire download" in spec["description"]:
|
||||
self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime")
|
||||
if (
|
||||
"Error returned from connection pool clear with interruptInUseConnections=true is retryable"
|
||||
in spec["description"]
|
||||
and not _IS_SYNC
|
||||
):
|
||||
self.skipTest("PYTHON-5170 tests are flakey")
|
||||
if "Driver extends timeout while streaming" in spec["description"] and not _IS_SYNC:
|
||||
self.skipTest("PYTHON-5174 tests are flakey")
|
||||
|
||||
class_name = self.__class__.__name__.lower()
|
||||
description = spec["description"].lower()
|
||||
|
||||
@ -1078,3 +1078,19 @@ def create_async_event():
|
||||
|
||||
def create_event():
|
||||
return threading.Event()
|
||||
|
||||
|
||||
def async_create_barrier(N_TASKS, timeout: float | None = None):
|
||||
return asyncio.Barrier(N_TASKS)
|
||||
|
||||
|
||||
def create_barrier(N_TASKS, timeout: float | None = None):
|
||||
return threading.Barrier(N_TASKS, timeout=timeout)
|
||||
|
||||
|
||||
async def async_barrier_wait(barrier, timeout: float | None = None):
|
||||
await asyncio.wait_for(barrier.wait(), timeout=timeout)
|
||||
|
||||
|
||||
def barrier_wait(barrier, timeout: float | None = None):
|
||||
barrier.wait()
|
||||
|
||||
@ -124,6 +124,8 @@ replacements = {
|
||||
"AsyncMockPool": "MockPool",
|
||||
"StopAsyncIteration": "StopIteration",
|
||||
"create_async_event": "create_event",
|
||||
"async_create_barrier": "create_barrier",
|
||||
"async_barrier_wait": "barrier_wait",
|
||||
"async_joinall": "joinall",
|
||||
}
|
||||
|
||||
@ -213,6 +215,7 @@ converted_tests = [
|
||||
"test_custom_types.py",
|
||||
"test_database.py",
|
||||
"test_data_lake.py",
|
||||
"test_discovery_and_monitoring.py",
|
||||
"test_dns.py",
|
||||
"test_encryption.py",
|
||||
"test_examples.py",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user