From 8b668898b880e848a638f84c6d024e0640079f14 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 3 Apr 2025 12:05:45 -0700 Subject: [PATCH 1/5] PYTHON-5208 Add spec test for wait queue timeout errors do not clear the pool (#2199) Also stop running the ping command to advance session cluster times in the unified tests. --- test/asynchronous/unified_format.py | 16 ++- test/csot/waitQueueTimeout.json | 176 ++++++++++++++++++++++++++++ test/unified_format.py | 16 ++- 3 files changed, 190 insertions(+), 18 deletions(-) create mode 100644 test/csot/waitQueueTimeout.json diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index cc516ee82..9099efbf0 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -222,7 +222,6 @@ class EntityMapUtil: self._listeners: Dict[str, EventListenerUtil] = {} self._session_lsids: Dict[str, Mapping[str, Any]] = {} self.test: UnifiedSpecTestMixinV1 = test_class - self._cluster_time: Mapping[str, Any] = {} def __contains__(self, item): return item in self._entities @@ -421,13 +420,11 @@ class EntityMapUtil: # session has been closed. return self._session_lsids[session_name] - async def advance_cluster_times(self) -> None: + async def advance_cluster_times(self, cluster_time) -> None: """Manually synchronize entities when desired""" - if not self._cluster_time: - self._cluster_time = (await self.test.client.admin.command("ping")).get("$clusterTime") for entity in self._entities.values(): - if isinstance(entity, AsyncClientSession) and self._cluster_time: - entity.advance_cluster_time(self._cluster_time) + if isinstance(entity, AsyncClientSession) and cluster_time: + entity.advance_cluster_time(cluster_time) class UnifiedSpecTestMixinV1(AsyncIntegrationTest): @@ -1044,7 +1041,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest): async def _testOperation_createEntities(self, spec): await self.entity_map.create_entities_from_spec(spec["entities"], uri=self._uri) - await self.entity_map.advance_cluster_times() + await self.entity_map.advance_cluster_times(self._cluster_time) def _testOperation_assertSessionTransactionState(self, spec): session = self.entity_map[spec["session"]] @@ -1443,11 +1440,12 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest): await self.entity_map.create_entities_from_spec( self.TEST_SPEC.get("createEntities", []), uri=uri ) + self._cluster_time = None # process initialData if "initialData" in self.TEST_SPEC: await self.insert_initial_data(self.TEST_SPEC["initialData"]) - self._cluster_time = (await self.client.admin.command("ping")).get("$clusterTime") - await self.entity_map.advance_cluster_times() + self._cluster_time = self.client._topology.max_cluster_time() + await self.entity_map.advance_cluster_times(self._cluster_time) if "expectLogMessages" in spec: expect_log_messages = spec["expectLogMessages"] diff --git a/test/csot/waitQueueTimeout.json b/test/csot/waitQueueTimeout.json new file mode 100644 index 000000000..138d5cc16 --- /dev/null +++ b/test/csot/waitQueueTimeout.json @@ -0,0 +1,176 @@ +{ + "description": "WaitQueueTimeoutError does not clear the pool", + "schemaVersion": "1.9", + "runOnRequirements": [ + { + "minServerVersion": "4.4", + "topologies": [ + "single", + "replicaset", + "sharded" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "failPointClient", + "useMultipleMongoses": false + } + }, + { + "client": { + "id": "client", + "uriOptions": { + "maxPoolSize": 1, + "appname": "waitQueueTimeoutErrorTest" + }, + "useMultipleMongoses": false, + "observeEvents": [ + "commandStartedEvent", + "poolClearedEvent" + ] + } + }, + { + "database": { + "id": "database", + "client": "client", + "databaseName": "test" + } + } + ], + "tests": [ + { + "description": "WaitQueueTimeoutError does not clear the pool", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "ping" + ], + "blockConnection": true, + "blockTimeMS": 500, + "appName": "waitQueueTimeoutErrorTest" + } + } + } + }, + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "thread": { + "id": "thread0" + } + } + ] + } + }, + { + "name": "runOnThread", + "object": "testRunner", + "arguments": { + "thread": "thread0", + "operation": { + "name": "runCommand", + "object": "database", + "arguments": { + "command": { + "ping": 1 + }, + "commandName": "ping" + } + } + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "commandStartedEvent": { + "commandName": "ping" + } + }, + "count": 1 + } + }, + { + "name": "runCommand", + "object": "database", + "arguments": { + "timeoutMS": 100, + "command": { + "hello": 1 + }, + "commandName": "hello" + }, + "expectError": { + "isTimeoutError": true + } + }, + { + "name": "waitForThread", + "object": "testRunner", + "arguments": { + "thread": "thread0" + } + }, + { + "name": "runCommand", + "object": "database", + "arguments": { + "command": { + "hello": 1 + }, + "commandName": "hello" + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "command", + "events": [ + { + "commandStartedEvent": { + "commandName": "ping", + "databaseName": "test", + "command": { + "ping": 1 + } + } + }, + { + "commandStartedEvent": { + "commandName": "hello", + "databaseName": "test", + "command": { + "hello": 1 + } + } + } + ] + }, + { + "client": "client", + "eventType": "cmap", + "events": [] + } + ] + } + ] +} diff --git a/test/unified_format.py b/test/unified_format.py index fd7f92909..71d6cd50d 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -221,7 +221,6 @@ class EntityMapUtil: self._listeners: Dict[str, EventListenerUtil] = {} self._session_lsids: Dict[str, Mapping[str, Any]] = {} self.test: UnifiedSpecTestMixinV1 = test_class - self._cluster_time: Mapping[str, Any] = {} def __contains__(self, item): return item in self._entities @@ -420,13 +419,11 @@ class EntityMapUtil: # session has been closed. return self._session_lsids[session_name] - def advance_cluster_times(self) -> None: + def advance_cluster_times(self, cluster_time) -> None: """Manually synchronize entities when desired""" - if not self._cluster_time: - self._cluster_time = (self.test.client.admin.command("ping")).get("$clusterTime") for entity in self._entities.values(): - if isinstance(entity, ClientSession) and self._cluster_time: - entity.advance_cluster_time(self._cluster_time) + if isinstance(entity, ClientSession) and cluster_time: + entity.advance_cluster_time(cluster_time) class UnifiedSpecTestMixinV1(IntegrationTest): @@ -1035,7 +1032,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): def _testOperation_createEntities(self, spec): self.entity_map.create_entities_from_spec(spec["entities"], uri=self._uri) - self.entity_map.advance_cluster_times() + self.entity_map.advance_cluster_times(self._cluster_time) def _testOperation_assertSessionTransactionState(self, spec): session = self.entity_map[spec["session"]] @@ -1428,11 +1425,12 @@ class UnifiedSpecTestMixinV1(IntegrationTest): self._uri = uri self.entity_map = EntityMapUtil(self) self.entity_map.create_entities_from_spec(self.TEST_SPEC.get("createEntities", []), uri=uri) + self._cluster_time = None # process initialData if "initialData" in self.TEST_SPEC: self.insert_initial_data(self.TEST_SPEC["initialData"]) - self._cluster_time = (self.client.admin.command("ping")).get("$clusterTime") - self.entity_map.advance_cluster_times() + self._cluster_time = self.client._topology.max_cluster_time() + self.entity_map.advance_cluster_times(self._cluster_time) if "expectLogMessages" in spec: expect_log_messages = spec["expectLogMessages"] From b40223938c9646e0b1a8f60d22e36120991b394d Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 3 Apr 2025 15:32:47 -0400 Subject: [PATCH 2/5] PYTHON-5219 - Avoid awaiting coroutines when holding locks (#2250) --- pymongo/asynchronous/pool.py | 23 ++++++++++++++++------- pymongo/asynchronous/topology.py | 2 -- pymongo/synchronous/pool.py | 23 ++++++++++++++++------- pymongo/synchronous/topology.py | 2 -- 4 files changed, 32 insertions(+), 18 deletions(-) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 18644cf7d..a67cc5f3c 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -931,13 +931,15 @@ class Pool: return if self.opts.max_idle_time_seconds is not None: + close_conns = [] async with self.lock: while ( self.conns and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds ): - conn = self.conns.pop() - await conn.close_conn(ConnectionClosedReason.IDLE) + close_conns.append(self.conns.pop()) + for conn in close_conns: + await conn.close_conn(ConnectionClosedReason.IDLE) while True: async with self.size_cond: @@ -957,14 +959,18 @@ class Pool: self._pending += 1 incremented = True conn = await self.connect() + close_conn = False async with self.lock: # Close connection and return if the pool was reset during # socket creation or while acquiring the pool lock. if self.gen.get_overall() != reference_generation: - await conn.close_conn(ConnectionClosedReason.STALE) - return - self.conns.appendleft(conn) - self.active_contexts.discard(conn.cancel_context) + close_conn = True + if not close_conn: + self.conns.appendleft(conn) + self.active_contexts.discard(conn.cancel_context) + if close_conn: + await conn.close_conn(ConnectionClosedReason.STALE) + return finally: if incremented: # Notify after adding the socket to the pool. @@ -1343,17 +1349,20 @@ class Pool: error=ConnectionClosedReason.ERROR, ) else: + close_conn = False async with self.lock: # Hold the lock to ensure this section does not race with # Pool.reset(). if self.stale_generation(conn.generation, conn.service_id): - await conn.close_conn(ConnectionClosedReason.STALE) + close_conn = True else: conn.update_last_checkin_time() conn.update_is_writable(bool(self.is_writable)) self.conns.appendleft(conn) # Notify any threads waiting to create a connection. self._max_connecting_cond.notify() + if close_conn: + await conn.close_conn(ConnectionClosedReason.STALE) async with self.size_cond: if txn: diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 9de069af7..b315cc33b 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -244,8 +244,6 @@ class Topology: # Close servers and clear the pools. for server in self._servers.values(): await server.close() - if not _IS_SYNC: - self._monitor_tasks.append(server._monitor) # Reset the session pool to avoid duplicate sessions in # the child process. self._session_pool.reset() diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 1151776b9..224834af3 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -927,13 +927,15 @@ class Pool: return if self.opts.max_idle_time_seconds is not None: + close_conns = [] with self.lock: while ( self.conns and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds ): - conn = self.conns.pop() - conn.close_conn(ConnectionClosedReason.IDLE) + close_conns.append(self.conns.pop()) + for conn in close_conns: + conn.close_conn(ConnectionClosedReason.IDLE) while True: with self.size_cond: @@ -953,14 +955,18 @@ class Pool: self._pending += 1 incremented = True conn = self.connect() + close_conn = False with self.lock: # Close connection and return if the pool was reset during # socket creation or while acquiring the pool lock. if self.gen.get_overall() != reference_generation: - conn.close_conn(ConnectionClosedReason.STALE) - return - self.conns.appendleft(conn) - self.active_contexts.discard(conn.cancel_context) + close_conn = True + if not close_conn: + self.conns.appendleft(conn) + self.active_contexts.discard(conn.cancel_context) + if close_conn: + conn.close_conn(ConnectionClosedReason.STALE) + return finally: if incremented: # Notify after adding the socket to the pool. @@ -1339,17 +1345,20 @@ class Pool: error=ConnectionClosedReason.ERROR, ) else: + close_conn = False with self.lock: # Hold the lock to ensure this section does not race with # Pool.reset(). if self.stale_generation(conn.generation, conn.service_id): - conn.close_conn(ConnectionClosedReason.STALE) + close_conn = True else: conn.update_last_checkin_time() conn.update_is_writable(bool(self.is_writable)) self.conns.appendleft(conn) # Notify any threads waiting to create a connection. self._max_connecting_cond.notify() + if close_conn: + conn.close_conn(ConnectionClosedReason.STALE) with self.size_cond: if txn: diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index bccc8a2eb..7df475b4c 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -244,8 +244,6 @@ class Topology: # Close servers and clear the pools. for server in self._servers.values(): server.close() - if not _IS_SYNC: - self._monitor_tasks.append(server._monitor) # Reset the session pool to avoid duplicate sessions in # the child process. self._session_pool.reset() From e7c0814512ef4bc104bd17919acd80319460e1a0 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 3 Apr 2025 15:33:11 -0400 Subject: [PATCH 3/5] PYTHON-4557 - Fix write log messages for retried commands (#2260) --- pymongo/asynchronous/mongo_client.py | 2 +- pymongo/synchronous/mongo_client.py | 2 +- test/asynchronous/test_logger.py | 31 ++++++++++++++++++++++++++-- test/test_logger.py | 31 ++++++++++++++++++++++++++-- 4 files changed, 60 insertions(+), 6 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index a0ff8741a..7c8f7180b 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2840,7 +2840,7 @@ class _ClientConnectionRetryable(Generic[T]): _debug_log( _COMMAND_LOGGER, message=f"Retrying write attempt number {self._attempt_number}", - clientId=self._client.client_id, + clientId=self._client._topology_settings._topology_id, commandName=self._operation, operationId=self._operation_id, ) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index a674bfb66..14fdefcb6 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2826,7 +2826,7 @@ class _ClientConnectionRetryable(Generic[T]): _debug_log( _COMMAND_LOGGER, message=f"Retrying write attempt number {self._attempt_number}", - clientId=self._client.client_id, + clientId=self._client._topology_settings._topology_id, commandName=self._operation, operationId=self._operation_id, ) diff --git a/test/asynchronous/test_logger.py b/test/asynchronous/test_logger.py index 92c29e111..d024735fd 100644 --- a/test/asynchronous/test_logger.py +++ b/test/asynchronous/test_logger.py @@ -102,7 +102,14 @@ class TestLogger(AsyncIntegrationTest): await self.db.test.insert_one({"x": "1"}) async with self.fail_point( - {"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}} + { + "mode": {"times": 1}, + "data": { + "failCommands": ["find"], + "errorCode": 10107, + "errorLabels": ["RetryableWriteError"], + }, + } ): with self.assertLogs("pymongo.command", level="DEBUG") as cm: await self.db.test.find_one({"x": "1"}) @@ -110,7 +117,27 @@ class TestLogger(AsyncIntegrationTest): retry_messages = [ r.getMessage() for r in cm.records if "Retrying read attempt" in r.getMessage() ] - print(retry_messages) + self.assertEqual(len(retry_messages), 1) + + @async_client_context.require_failCommand_fail_point + @async_client_context.require_retryable_writes + async def test_logging_retry_write_attempts(self): + async with self.fail_point( + { + "mode": {"times": 1}, + "data": { + "errorCode": 10107, + "errorLabels": ["RetryableWriteError"], + "failCommands": ["insert"], + }, + } + ): + with self.assertLogs("pymongo.command", level="DEBUG") as cm: + await self.db.test.insert_one({"x": "1"}) + + retry_messages = [ + r.getMessage() for r in cm.records if "Retrying write attempt" in r.getMessage() + ] self.assertEqual(len(retry_messages), 1) diff --git a/test/test_logger.py b/test/test_logger.py index 398f768c9..a7d97927f 100644 --- a/test/test_logger.py +++ b/test/test_logger.py @@ -101,7 +101,14 @@ class TestLogger(IntegrationTest): self.db.test.insert_one({"x": "1"}) with self.fail_point( - {"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}} + { + "mode": {"times": 1}, + "data": { + "failCommands": ["find"], + "errorCode": 10107, + "errorLabels": ["RetryableWriteError"], + }, + } ): with self.assertLogs("pymongo.command", level="DEBUG") as cm: self.db.test.find_one({"x": "1"}) @@ -109,7 +116,27 @@ class TestLogger(IntegrationTest): retry_messages = [ r.getMessage() for r in cm.records if "Retrying read attempt" in r.getMessage() ] - print(retry_messages) + self.assertEqual(len(retry_messages), 1) + + @client_context.require_failCommand_fail_point + @client_context.require_retryable_writes + def test_logging_retry_write_attempts(self): + with self.fail_point( + { + "mode": {"times": 1}, + "data": { + "errorCode": 10107, + "errorLabels": ["RetryableWriteError"], + "failCommands": ["insert"], + }, + } + ): + with self.assertLogs("pymongo.command", level="DEBUG") as cm: + self.db.test.insert_one({"x": "1"}) + + retry_messages = [ + r.getMessage() for r in cm.records if "Retrying write attempt" in r.getMessage() + ] self.assertEqual(len(retry_messages), 1) From 1c813dc6489214bfab50aed04dfd63455bb4d231 Mon Sep 17 00:00:00 2001 From: "Jeffrey A. Clark" Date: Fri, 4 Apr 2025 13:09:04 -0400 Subject: [PATCH 4/5] PYTHON-4575 Allow valid SRV hostnames with less than 3 parts (#2234) --- doc/changelog.rst | 1 + pymongo/asynchronous/srv_resolver.py | 10 +-- pymongo/synchronous/srv_resolver.py | 10 +-- test/asynchronous/test_dns.py | 95 ++++++++++++++++++++++++++-- test/test_dns.py | 95 ++++++++++++++++++++++++++-- test/test_uri_parser.py | 1 + tools/synchro.py | 1 + 7 files changed, 193 insertions(+), 20 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 4d8b26a5e..1ab3bc49a 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -24,6 +24,7 @@ PyMongo 4.12 brings a number of changes including: :class:`~pymongo.read_preferences.SecondaryPreferred`, :class:`~pymongo.read_preferences.Nearest`. Support for ``hedge`` will be removed in PyMongo 5.0. - Removed PyOpenSSL support from the asynchronous API due to limitations of the CPython asyncio.Protocol SSL implementation. +- Allow valid SRV hostnames with less than 3 parts. Issues Resolved ............... diff --git a/pymongo/asynchronous/srv_resolver.py b/pymongo/asynchronous/srv_resolver.py index 8b811e5dc..f7c67af3e 100644 --- a/pymongo/asynchronous/srv_resolver.py +++ b/pymongo/asynchronous/srv_resolver.py @@ -90,14 +90,12 @@ class _SrvResolver: raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",)) except ValueError: pass - try: - self.__plist = self.__fqdn.split(".")[1:] + split_fqdn = self.__fqdn.split(".") + self.__plist = split_fqdn[1:] if len(split_fqdn) > 2 else split_fqdn except Exception: raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None self.__slen = len(self.__plist) - if self.__slen < 2: - raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) async def get_options(self) -> Optional[str]: from dns import resolver @@ -139,6 +137,10 @@ class _SrvResolver: # Validate hosts for node in nodes: + if self.__fqdn == node[0].lower(): + raise ConfigurationError( + "Invalid SRV host: return address is identical to SRV hostname" + ) try: nlist = node[0].lower().split(".")[1:][-self.__slen :] except Exception: diff --git a/pymongo/synchronous/srv_resolver.py b/pymongo/synchronous/srv_resolver.py index 1b36efd1c..cf7b0842a 100644 --- a/pymongo/synchronous/srv_resolver.py +++ b/pymongo/synchronous/srv_resolver.py @@ -90,14 +90,12 @@ class _SrvResolver: raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",)) except ValueError: pass - try: - self.__plist = self.__fqdn.split(".")[1:] + split_fqdn = self.__fqdn.split(".") + self.__plist = split_fqdn[1:] if len(split_fqdn) > 2 else split_fqdn except Exception: raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None self.__slen = len(self.__plist) - if self.__slen < 2: - raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) def get_options(self) -> Optional[str]: from dns import resolver @@ -139,6 +137,10 @@ class _SrvResolver: # Validate hosts for node in nodes: + if self.__fqdn == node[0].lower(): + raise ConfigurationError( + "Invalid SRV host: return address is identical to SRV hostname" + ) try: nlist = node[0].lower().split(".")[1:][-self.__slen :] except Exception: diff --git a/test/asynchronous/test_dns.py b/test/asynchronous/test_dns.py index d0e801e12..01c8d7b40 100644 --- a/test/asynchronous/test_dns.py +++ b/test/asynchronous/test_dns.py @@ -30,6 +30,7 @@ from test.asynchronous import ( unittest, ) from test.utils_shared import async_wait_until +from unittest.mock import MagicMock, patch from pymongo.asynchronous.uri_parser import parse_uri from pymongo.common import validate_read_preference_tags @@ -186,12 +187,6 @@ create_tests(TestDNSSharded) class TestParsingErrors(AsyncPyMongoTestCase): async def test_invalid_host(self): - with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"): - client = self.simple_client("mongodb+srv://mongodb") - await client.aconnect() - with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"): - client = self.simple_client("mongodb+srv://mongodb.com") - await client.aconnect() with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): client = self.simple_client("mongodb+srv://127.0.0.1") await client.aconnect() @@ -207,5 +202,93 @@ class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest): self.assertGreater(len(client.topology_description.server_descriptions()), 1) +class TestInitialDnsSeedlistDiscovery(AsyncPyMongoTestCase): + """ + Initial DNS Seedlist Discovery prose tests + https://github.com/mongodb/specifications/blob/0a7a8b5/source/initial-dns-seedlist-discovery/tests/README.md#prose-tests + """ + + async def run_initial_dns_seedlist_discovery_prose_tests(self, test_cases): + for case in test_cases: + with patch("dns.asyncresolver.resolve") as mock_resolver: + + async def mock_resolve(query, record_type, *args, **kwargs): + mock_srv = MagicMock() + mock_srv.target.to_text.return_value = case["mock_target"] + return [mock_srv] + + mock_resolver.side_effect = mock_resolve + domain = case["query"].split("._tcp.")[1] + connection_string = f"mongodb+srv://{domain}" + try: + await parse_uri(connection_string) + except ConfigurationError as e: + self.assertIn(case["expected_error"], str(e)) + else: + self.fail(f"ConfigurationError was not raised for query: {case['query']}") + + async def test_1_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self): + with patch("dns.asyncresolver.resolve"): + await parse_uri("mongodb+srv://localhost/") + await parse_uri("mongodb+srv://mongo.local/") + + async def test_2_throw_when_return_address_does_not_end_with_srv_domain(self): + test_cases = [ + { + "query": "_mongodb._tcp.localhost", + "mock_target": "localhost.mongodb", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.blogs.mongodb.com", + "mock_target": "blogs.evil.com", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.blogs.mongo.local", + "mock_target": "test_1.evil.com", + "expected_error": "Invalid SRV host", + }, + ] + await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + async def test_3_throw_when_return_address_is_identical_to_srv_hostname(self): + test_cases = [ + { + "query": "_mongodb._tcp.localhost", + "mock_target": "localhost", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.mongo.local", + "mock_target": "mongo.local", + "expected_error": "Invalid SRV host", + }, + ] + await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + async def test_4_throw_when_return_address_does_not_contain_dot_separating_shared_part_of_domain( + self + ): + test_cases = [ + { + "query": "_mongodb._tcp.localhost", + "mock_target": "test_1.cluster_1localhost", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.mongo.local", + "mock_target": "test_1.my_hostmongo.local", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.blogs.mongodb.com", + "mock_target": "cluster.testmongodb.com", + "expected_error": "Invalid SRV host", + }, + ] + await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_dns.py b/test/test_dns.py index 0290eb16d..9360f3f28 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -30,6 +30,7 @@ from test import ( unittest, ) from test.utils_shared import wait_until +from unittest.mock import MagicMock, patch from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError @@ -184,12 +185,6 @@ create_tests(TestDNSSharded) class TestParsingErrors(PyMongoTestCase): def test_invalid_host(self): - with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"): - client = self.simple_client("mongodb+srv://mongodb") - client._connect() - with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"): - client = self.simple_client("mongodb+srv://mongodb.com") - client._connect() with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): client = self.simple_client("mongodb+srv://127.0.0.1") client._connect() @@ -205,5 +200,93 @@ class TestCaseInsensitive(IntegrationTest): self.assertGreater(len(client.topology_description.server_descriptions()), 1) +class TestInitialDnsSeedlistDiscovery(PyMongoTestCase): + """ + Initial DNS Seedlist Discovery prose tests + https://github.com/mongodb/specifications/blob/0a7a8b5/source/initial-dns-seedlist-discovery/tests/README.md#prose-tests + """ + + def run_initial_dns_seedlist_discovery_prose_tests(self, test_cases): + for case in test_cases: + with patch("dns.resolver.resolve") as mock_resolver: + + def mock_resolve(query, record_type, *args, **kwargs): + mock_srv = MagicMock() + mock_srv.target.to_text.return_value = case["mock_target"] + return [mock_srv] + + mock_resolver.side_effect = mock_resolve + domain = case["query"].split("._tcp.")[1] + connection_string = f"mongodb+srv://{domain}" + try: + parse_uri(connection_string) + except ConfigurationError as e: + self.assertIn(case["expected_error"], str(e)) + else: + self.fail(f"ConfigurationError was not raised for query: {case['query']}") + + def test_1_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self): + with patch("dns.resolver.resolve"): + parse_uri("mongodb+srv://localhost/") + parse_uri("mongodb+srv://mongo.local/") + + def test_2_throw_when_return_address_does_not_end_with_srv_domain(self): + test_cases = [ + { + "query": "_mongodb._tcp.localhost", + "mock_target": "localhost.mongodb", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.blogs.mongodb.com", + "mock_target": "blogs.evil.com", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.blogs.mongo.local", + "mock_target": "test_1.evil.com", + "expected_error": "Invalid SRV host", + }, + ] + self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + def test_3_throw_when_return_address_is_identical_to_srv_hostname(self): + test_cases = [ + { + "query": "_mongodb._tcp.localhost", + "mock_target": "localhost", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.mongo.local", + "mock_target": "mongo.local", + "expected_error": "Invalid SRV host", + }, + ] + self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + def test_4_throw_when_return_address_does_not_contain_dot_separating_shared_part_of_domain( + self + ): + test_cases = [ + { + "query": "_mongodb._tcp.localhost", + "mock_target": "test_1.cluster_1localhost", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.mongo.local", + "mock_target": "test_1.my_hostmongo.local", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.blogs.mongodb.com", + "mock_target": "cluster.testmongodb.com", + "expected_error": "Invalid SRV host", + }, + ] + self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_uri_parser.py b/test/test_uri_parser.py index 0baefa0c3..d4d17ac21 100644 --- a/test/test_uri_parser.py +++ b/test/test_uri_parser.py @@ -24,6 +24,7 @@ from urllib.parse import quote_plus sys.path[0:0] = [""] from test import unittest +from unittest.mock import patch from bson.binary import JAVA_LEGACY from pymongo import ReadPreference diff --git a/tools/synchro.py b/tools/synchro.py index f451d09a2..37bf9bc74 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -133,6 +133,7 @@ replacements = { "async_joinall": "joinall", "_async_create_connection": "_create_connection", "pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts": "pymongo.synchronous.srv_resolver._SrvResolver.get_hosts", + "dns.asyncresolver.resolve": "dns.resolver.resolve", } docstring_replacements: dict[tuple[str, str], str] = { From 708ce16961f077b7a03fc97e66ccfaccaa0d847a Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 4 Apr 2025 13:22:22 -0400 Subject: [PATCH 5/5] PYTHON-4724 - Prohibit AsyncMongoClient from being used across multiple event loops (#2256) --- pymongo/asynchronous/mongo_client.py | 8 +++++ pymongo/synchronous/mongo_client.py | 8 +++++ pyproject.toml | 2 ++ test/asynchronous/test_async_loop_safety.py | 36 +++++++++++++++++++++ tools/synchro.py | 7 +++- 5 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 test/asynchronous/test_async_loop_safety.py diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 7c8f7180b..7744a75d9 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -878,6 +878,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): self._opened = False self._closed = False + self._loop: Optional[asyncio.AbstractEventLoop] = None if not is_srv: self._init_background() @@ -1709,6 +1710,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): If this client was created with "connect=False", calling _get_topology launches the connection process in the background. """ + if not _IS_SYNC: + if self._loop is None: + self._loop = asyncio.get_running_loop() + elif self._loop != asyncio.get_running_loop(): + raise RuntimeError( + "Cannot use AsyncMongoClient in different event loop. AsyncMongoClient uses low-level asyncio APIs that bind it to the event loop it was created on." + ) if not self._opened: if self._resolve_srv_info["is_srv"]: await self._resolve_srv() diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 14fdefcb6..1c0adb5d6 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -876,6 +876,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): self._opened = False self._closed = False + self._loop: Optional[asyncio.AbstractEventLoop] = None if not is_srv: self._init_background() @@ -1703,6 +1704,13 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): If this client was created with "connect=False", calling _get_topology launches the connection process in the background. """ + if not _IS_SYNC: + if self._loop is None: + self._loop = asyncio.get_running_loop() + elif self._loop != asyncio.get_running_loop(): + raise RuntimeError( + "Cannot use MongoClient in different event loop. MongoClient uses low-level asyncio APIs that bind it to the event loop it was created on." + ) if not self._opened: if self._resolve_srv_info["is_srv"]: self._resolve_srv() diff --git a/pyproject.toml b/pyproject.toml index 611cac13a..4da75b4c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,6 +117,8 @@ filterwarnings = [ "module:unclosed bool: """Return True for async tests that should not be converted to sync.""" - return f in ["test_locks.py", "test_concurrency.py", "test_async_cancellation.py"] + return f in [ + "test_locks.py", + "test_concurrency.py", + "test_async_cancellation.py", + "test_async_loop_safety.py", + ] test_files = [