Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2025-04-04 18:07:23 -05:00
commit 0dece2fff4
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
20 changed files with 535 additions and 63 deletions

View File

@ -24,6 +24,7 @@ PyMongo 4.12 brings a number of changes including:
:class:`~pymongo.read_preferences.SecondaryPreferred`,
:class:`~pymongo.read_preferences.Nearest`. Support for ``hedge`` will be removed in PyMongo 5.0.
- Removed PyOpenSSL support from the asynchronous API due to limitations of the CPython asyncio.Protocol SSL implementation.
- Allow valid SRV hostnames with less than 3 parts.
Issues Resolved
...............

View File

@ -878,6 +878,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self._opened = False
self._closed = False
self._loop: Optional[asyncio.AbstractEventLoop] = None
if not is_srv:
self._init_background()
@ -1709,6 +1710,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
If this client was created with "connect=False", calling _get_topology
launches the connection process in the background.
"""
if not _IS_SYNC:
if self._loop is None:
self._loop = asyncio.get_running_loop()
elif self._loop != asyncio.get_running_loop():
raise RuntimeError(
"Cannot use AsyncMongoClient in different event loop. AsyncMongoClient uses low-level asyncio APIs that bind it to the event loop it was created on."
)
if not self._opened:
if self._resolve_srv_info["is_srv"]:
await self._resolve_srv()
@ -2840,7 +2848,7 @@ class _ClientConnectionRetryable(Generic[T]):
_debug_log(
_COMMAND_LOGGER,
message=f"Retrying write attempt number {self._attempt_number}",
clientId=self._client.client_id,
clientId=self._client._topology_settings._topology_id,
commandName=self._operation,
operationId=self._operation_id,
)

View File

@ -931,13 +931,15 @@ class Pool:
return
if self.opts.max_idle_time_seconds is not None:
close_conns = []
async with self.lock:
while (
self.conns
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
):
conn = self.conns.pop()
await conn.close_conn(ConnectionClosedReason.IDLE)
close_conns.append(self.conns.pop())
for conn in close_conns:
await conn.close_conn(ConnectionClosedReason.IDLE)
while True:
async with self.size_cond:
@ -957,14 +959,18 @@ class Pool:
self._pending += 1
incremented = True
conn = await self.connect()
close_conn = False
async with self.lock:
# Close connection and return if the pool was reset during
# socket creation or while acquiring the pool lock.
if self.gen.get_overall() != reference_generation:
await conn.close_conn(ConnectionClosedReason.STALE)
return
self.conns.appendleft(conn)
self.active_contexts.discard(conn.cancel_context)
close_conn = True
if not close_conn:
self.conns.appendleft(conn)
self.active_contexts.discard(conn.cancel_context)
if close_conn:
await conn.close_conn(ConnectionClosedReason.STALE)
return
finally:
if incremented:
# Notify after adding the socket to the pool.
@ -1343,17 +1349,20 @@ class Pool:
error=ConnectionClosedReason.ERROR,
)
else:
close_conn = False
async with self.lock:
# Hold the lock to ensure this section does not race with
# Pool.reset().
if self.stale_generation(conn.generation, conn.service_id):
await conn.close_conn(ConnectionClosedReason.STALE)
close_conn = True
else:
conn.update_last_checkin_time()
conn.update_is_writable(bool(self.is_writable))
self.conns.appendleft(conn)
# Notify any threads waiting to create a connection.
self._max_connecting_cond.notify()
if close_conn:
await conn.close_conn(ConnectionClosedReason.STALE)
async with self.size_cond:
if txn:

View File

@ -90,14 +90,12 @@ class _SrvResolver:
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
except ValueError:
pass
try:
self.__plist = self.__fqdn.split(".")[1:]
split_fqdn = self.__fqdn.split(".")
self.__plist = split_fqdn[1:] if len(split_fqdn) > 2 else split_fqdn
except Exception:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
self.__slen = len(self.__plist)
if self.__slen < 2:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
async def get_options(self) -> Optional[str]:
from dns import resolver
@ -139,6 +137,10 @@ class _SrvResolver:
# Validate hosts
for node in nodes:
if self.__fqdn == node[0].lower():
raise ConfigurationError(
"Invalid SRV host: return address is identical to SRV hostname"
)
try:
nlist = node[0].lower().split(".")[1:][-self.__slen :]
except Exception:

View File

@ -244,8 +244,6 @@ class Topology:
# Close servers and clear the pools.
for server in self._servers.values():
await server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Reset the session pool to avoid duplicate sessions in
# the child process.
self._session_pool.reset()

View File

@ -876,6 +876,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self._opened = False
self._closed = False
self._loop: Optional[asyncio.AbstractEventLoop] = None
if not is_srv:
self._init_background()
@ -1703,6 +1704,13 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
If this client was created with "connect=False", calling _get_topology
launches the connection process in the background.
"""
if not _IS_SYNC:
if self._loop is None:
self._loop = asyncio.get_running_loop()
elif self._loop != asyncio.get_running_loop():
raise RuntimeError(
"Cannot use MongoClient in different event loop. MongoClient uses low-level asyncio APIs that bind it to the event loop it was created on."
)
if not self._opened:
if self._resolve_srv_info["is_srv"]:
self._resolve_srv()
@ -2826,7 +2834,7 @@ class _ClientConnectionRetryable(Generic[T]):
_debug_log(
_COMMAND_LOGGER,
message=f"Retrying write attempt number {self._attempt_number}",
clientId=self._client.client_id,
clientId=self._client._topology_settings._topology_id,
commandName=self._operation,
operationId=self._operation_id,
)

View File

@ -927,13 +927,15 @@ class Pool:
return
if self.opts.max_idle_time_seconds is not None:
close_conns = []
with self.lock:
while (
self.conns
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
):
conn = self.conns.pop()
conn.close_conn(ConnectionClosedReason.IDLE)
close_conns.append(self.conns.pop())
for conn in close_conns:
conn.close_conn(ConnectionClosedReason.IDLE)
while True:
with self.size_cond:
@ -953,14 +955,18 @@ class Pool:
self._pending += 1
incremented = True
conn = self.connect()
close_conn = False
with self.lock:
# Close connection and return if the pool was reset during
# socket creation or while acquiring the pool lock.
if self.gen.get_overall() != reference_generation:
conn.close_conn(ConnectionClosedReason.STALE)
return
self.conns.appendleft(conn)
self.active_contexts.discard(conn.cancel_context)
close_conn = True
if not close_conn:
self.conns.appendleft(conn)
self.active_contexts.discard(conn.cancel_context)
if close_conn:
conn.close_conn(ConnectionClosedReason.STALE)
return
finally:
if incremented:
# Notify after adding the socket to the pool.
@ -1339,17 +1345,20 @@ class Pool:
error=ConnectionClosedReason.ERROR,
)
else:
close_conn = False
with self.lock:
# Hold the lock to ensure this section does not race with
# Pool.reset().
if self.stale_generation(conn.generation, conn.service_id):
conn.close_conn(ConnectionClosedReason.STALE)
close_conn = True
else:
conn.update_last_checkin_time()
conn.update_is_writable(bool(self.is_writable))
self.conns.appendleft(conn)
# Notify any threads waiting to create a connection.
self._max_connecting_cond.notify()
if close_conn:
conn.close_conn(ConnectionClosedReason.STALE)
with self.size_cond:
if txn:

View File

@ -90,14 +90,12 @@ class _SrvResolver:
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
except ValueError:
pass
try:
self.__plist = self.__fqdn.split(".")[1:]
split_fqdn = self.__fqdn.split(".")
self.__plist = split_fqdn[1:] if len(split_fqdn) > 2 else split_fqdn
except Exception:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
self.__slen = len(self.__plist)
if self.__slen < 2:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
def get_options(self) -> Optional[str]:
from dns import resolver
@ -139,6 +137,10 @@ class _SrvResolver:
# Validate hosts
for node in nodes:
if self.__fqdn == node[0].lower():
raise ConfigurationError(
"Invalid SRV host: return address is identical to SRV hostname"
)
try:
nlist = node[0].lower().split(".")[1:][-self.__slen :]
except Exception:

View File

@ -244,8 +244,6 @@ class Topology:
# Close servers and clear the pools.
for server in self._servers.values():
server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Reset the session pool to avoid duplicate sessions in
# the child process.
self._session_pool.reset()

View File

@ -117,6 +117,8 @@ filterwarnings = [
"module:unclosed <ssl.SSLSocket:ResourceWarning",
"module:unclosed <socket object:ResourceWarning",
"module:unclosed transport:ResourceWarning",
# pytest-asyncio known issue: https://github.com/pytest-dev/pytest-asyncio/issues/724
"module:unclosed event loop:ResourceWarning",
# https://github.com/eventlet/eventlet/issues/818
"module:please use dns.resolver.Resolver.resolve:DeprecationWarning",
# https://github.com/dateutil/dateutil/issues/1314

View File

@ -0,0 +1,36 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test that the asynchronous API detects event loop changes and fails correctly."""
from __future__ import annotations
import asyncio
import unittest
from pymongo import AsyncMongoClient
class TestClientLoopSafety(unittest.TestCase):
def test_client_errors_on_different_loop(self):
client = AsyncMongoClient()
loop1 = asyncio.new_event_loop()
loop1.run_until_complete(client.aconnect())
loop2 = asyncio.new_event_loop()
with self.assertRaisesRegex(
RuntimeError, "Cannot use AsyncMongoClient in different event loop"
):
loop2.run_until_complete(client.aconnect())
loop1.run_until_complete(client.close())
loop1.close()
loop2.close()

View File

@ -30,6 +30,7 @@ from test.asynchronous import (
unittest,
)
from test.utils_shared import async_wait_until
from unittest.mock import MagicMock, patch
from pymongo.asynchronous.uri_parser import parse_uri
from pymongo.common import validate_read_preference_tags
@ -186,12 +187,6 @@ create_tests(TestDNSSharded)
class TestParsingErrors(AsyncPyMongoTestCase):
async def test_invalid_host(self):
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"):
client = self.simple_client("mongodb+srv://mongodb")
await client.aconnect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"):
client = self.simple_client("mongodb+srv://mongodb.com")
await client.aconnect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
client = self.simple_client("mongodb+srv://127.0.0.1")
await client.aconnect()
@ -207,5 +202,93 @@ class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest):
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
class TestInitialDnsSeedlistDiscovery(AsyncPyMongoTestCase):
"""
Initial DNS Seedlist Discovery prose tests
https://github.com/mongodb/specifications/blob/0a7a8b5/source/initial-dns-seedlist-discovery/tests/README.md#prose-tests
"""
async def run_initial_dns_seedlist_discovery_prose_tests(self, test_cases):
for case in test_cases:
with patch("dns.asyncresolver.resolve") as mock_resolver:
async def mock_resolve(query, record_type, *args, **kwargs):
mock_srv = MagicMock()
mock_srv.target.to_text.return_value = case["mock_target"]
return [mock_srv]
mock_resolver.side_effect = mock_resolve
domain = case["query"].split("._tcp.")[1]
connection_string = f"mongodb+srv://{domain}"
try:
await parse_uri(connection_string)
except ConfigurationError as e:
self.assertIn(case["expected_error"], str(e))
else:
self.fail(f"ConfigurationError was not raised for query: {case['query']}")
async def test_1_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self):
with patch("dns.asyncresolver.resolve"):
await parse_uri("mongodb+srv://localhost/")
await parse_uri("mongodb+srv://mongo.local/")
async def test_2_throw_when_return_address_does_not_end_with_srv_domain(self):
test_cases = [
{
"query": "_mongodb._tcp.localhost",
"mock_target": "localhost.mongodb",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.blogs.mongodb.com",
"mock_target": "blogs.evil.com",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.blogs.mongo.local",
"mock_target": "test_1.evil.com",
"expected_error": "Invalid SRV host",
},
]
await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
async def test_3_throw_when_return_address_is_identical_to_srv_hostname(self):
test_cases = [
{
"query": "_mongodb._tcp.localhost",
"mock_target": "localhost",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.mongo.local",
"mock_target": "mongo.local",
"expected_error": "Invalid SRV host",
},
]
await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
async def test_4_throw_when_return_address_does_not_contain_dot_separating_shared_part_of_domain(
self
):
test_cases = [
{
"query": "_mongodb._tcp.localhost",
"mock_target": "test_1.cluster_1localhost",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.mongo.local",
"mock_target": "test_1.my_hostmongo.local",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.blogs.mongodb.com",
"mock_target": "cluster.testmongodb.com",
"expected_error": "Invalid SRV host",
},
]
await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
if __name__ == "__main__":
unittest.main()

View File

@ -102,7 +102,14 @@ class TestLogger(AsyncIntegrationTest):
await self.db.test.insert_one({"x": "1"})
async with self.fail_point(
{"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}}
{
"mode": {"times": 1},
"data": {
"failCommands": ["find"],
"errorCode": 10107,
"errorLabels": ["RetryableWriteError"],
},
}
):
with self.assertLogs("pymongo.command", level="DEBUG") as cm:
await self.db.test.find_one({"x": "1"})
@ -110,7 +117,27 @@ class TestLogger(AsyncIntegrationTest):
retry_messages = [
r.getMessage() for r in cm.records if "Retrying read attempt" in r.getMessage()
]
print(retry_messages)
self.assertEqual(len(retry_messages), 1)
@async_client_context.require_failCommand_fail_point
@async_client_context.require_retryable_writes
async def test_logging_retry_write_attempts(self):
async with self.fail_point(
{
"mode": {"times": 1},
"data": {
"errorCode": 10107,
"errorLabels": ["RetryableWriteError"],
"failCommands": ["insert"],
},
}
):
with self.assertLogs("pymongo.command", level="DEBUG") as cm:
await self.db.test.insert_one({"x": "1"})
retry_messages = [
r.getMessage() for r in cm.records if "Retrying write attempt" in r.getMessage()
]
self.assertEqual(len(retry_messages), 1)

View File

@ -222,7 +222,6 @@ class EntityMapUtil:
self._listeners: Dict[str, EventListenerUtil] = {}
self._session_lsids: Dict[str, Mapping[str, Any]] = {}
self.test: UnifiedSpecTestMixinV1 = test_class
self._cluster_time: Mapping[str, Any] = {}
def __contains__(self, item):
return item in self._entities
@ -421,13 +420,11 @@ class EntityMapUtil:
# session has been closed.
return self._session_lsids[session_name]
async def advance_cluster_times(self) -> None:
async def advance_cluster_times(self, cluster_time) -> None:
"""Manually synchronize entities when desired"""
if not self._cluster_time:
self._cluster_time = (await self.test.client.admin.command("ping")).get("$clusterTime")
for entity in self._entities.values():
if isinstance(entity, AsyncClientSession) and self._cluster_time:
entity.advance_cluster_time(self._cluster_time)
if isinstance(entity, AsyncClientSession) and cluster_time:
entity.advance_cluster_time(cluster_time)
class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
@ -1044,7 +1041,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
async def _testOperation_createEntities(self, spec):
await self.entity_map.create_entities_from_spec(spec["entities"], uri=self._uri)
await self.entity_map.advance_cluster_times()
await self.entity_map.advance_cluster_times(self._cluster_time)
def _testOperation_assertSessionTransactionState(self, spec):
session = self.entity_map[spec["session"]]
@ -1443,11 +1440,12 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
await self.entity_map.create_entities_from_spec(
self.TEST_SPEC.get("createEntities", []), uri=uri
)
self._cluster_time = None
# process initialData
if "initialData" in self.TEST_SPEC:
await self.insert_initial_data(self.TEST_SPEC["initialData"])
self._cluster_time = (await self.client.admin.command("ping")).get("$clusterTime")
await self.entity_map.advance_cluster_times()
self._cluster_time = self.client._topology.max_cluster_time()
await self.entity_map.advance_cluster_times(self._cluster_time)
if "expectLogMessages" in spec:
expect_log_messages = spec["expectLogMessages"]

View File

@ -0,0 +1,176 @@
{
"description": "WaitQueueTimeoutError does not clear the pool",
"schemaVersion": "1.9",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"topologies": [
"single",
"replicaset",
"sharded"
]
}
],
"createEntities": [
{
"client": {
"id": "failPointClient",
"useMultipleMongoses": false
}
},
{
"client": {
"id": "client",
"uriOptions": {
"maxPoolSize": 1,
"appname": "waitQueueTimeoutErrorTest"
},
"useMultipleMongoses": false,
"observeEvents": [
"commandStartedEvent",
"poolClearedEvent"
]
}
},
{
"database": {
"id": "database",
"client": "client",
"databaseName": "test"
}
}
],
"tests": [
{
"description": "WaitQueueTimeoutError does not clear the pool",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "failPointClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"ping"
],
"blockConnection": true,
"blockTimeMS": 500,
"appName": "waitQueueTimeoutErrorTest"
}
}
}
},
{
"name": "createEntities",
"object": "testRunner",
"arguments": {
"entities": [
{
"thread": {
"id": "thread0"
}
}
]
}
},
{
"name": "runOnThread",
"object": "testRunner",
"arguments": {
"thread": "thread0",
"operation": {
"name": "runCommand",
"object": "database",
"arguments": {
"command": {
"ping": 1
},
"commandName": "ping"
}
}
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"client": "client",
"event": {
"commandStartedEvent": {
"commandName": "ping"
}
},
"count": 1
}
},
{
"name": "runCommand",
"object": "database",
"arguments": {
"timeoutMS": 100,
"command": {
"hello": 1
},
"commandName": "hello"
},
"expectError": {
"isTimeoutError": true
}
},
{
"name": "waitForThread",
"object": "testRunner",
"arguments": {
"thread": "thread0"
}
},
{
"name": "runCommand",
"object": "database",
"arguments": {
"command": {
"hello": 1
},
"commandName": "hello"
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "command",
"events": [
{
"commandStartedEvent": {
"commandName": "ping",
"databaseName": "test",
"command": {
"ping": 1
}
}
},
{
"commandStartedEvent": {
"commandName": "hello",
"databaseName": "test",
"command": {
"hello": 1
}
}
}
]
},
{
"client": "client",
"eventType": "cmap",
"events": []
}
]
}
]
}

View File

@ -30,6 +30,7 @@ from test import (
unittest,
)
from test.utils_shared import wait_until
from unittest.mock import MagicMock, patch
from pymongo.common import validate_read_preference_tags
from pymongo.errors import ConfigurationError
@ -184,12 +185,6 @@ create_tests(TestDNSSharded)
class TestParsingErrors(PyMongoTestCase):
def test_invalid_host(self):
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"):
client = self.simple_client("mongodb+srv://mongodb")
client._connect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"):
client = self.simple_client("mongodb+srv://mongodb.com")
client._connect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
client = self.simple_client("mongodb+srv://127.0.0.1")
client._connect()
@ -205,5 +200,93 @@ class TestCaseInsensitive(IntegrationTest):
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
class TestInitialDnsSeedlistDiscovery(PyMongoTestCase):
"""
Initial DNS Seedlist Discovery prose tests
https://github.com/mongodb/specifications/blob/0a7a8b5/source/initial-dns-seedlist-discovery/tests/README.md#prose-tests
"""
def run_initial_dns_seedlist_discovery_prose_tests(self, test_cases):
for case in test_cases:
with patch("dns.resolver.resolve") as mock_resolver:
def mock_resolve(query, record_type, *args, **kwargs):
mock_srv = MagicMock()
mock_srv.target.to_text.return_value = case["mock_target"]
return [mock_srv]
mock_resolver.side_effect = mock_resolve
domain = case["query"].split("._tcp.")[1]
connection_string = f"mongodb+srv://{domain}"
try:
parse_uri(connection_string)
except ConfigurationError as e:
self.assertIn(case["expected_error"], str(e))
else:
self.fail(f"ConfigurationError was not raised for query: {case['query']}")
def test_1_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self):
with patch("dns.resolver.resolve"):
parse_uri("mongodb+srv://localhost/")
parse_uri("mongodb+srv://mongo.local/")
def test_2_throw_when_return_address_does_not_end_with_srv_domain(self):
test_cases = [
{
"query": "_mongodb._tcp.localhost",
"mock_target": "localhost.mongodb",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.blogs.mongodb.com",
"mock_target": "blogs.evil.com",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.blogs.mongo.local",
"mock_target": "test_1.evil.com",
"expected_error": "Invalid SRV host",
},
]
self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
def test_3_throw_when_return_address_is_identical_to_srv_hostname(self):
test_cases = [
{
"query": "_mongodb._tcp.localhost",
"mock_target": "localhost",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.mongo.local",
"mock_target": "mongo.local",
"expected_error": "Invalid SRV host",
},
]
self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
def test_4_throw_when_return_address_does_not_contain_dot_separating_shared_part_of_domain(
self
):
test_cases = [
{
"query": "_mongodb._tcp.localhost",
"mock_target": "test_1.cluster_1localhost",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.mongo.local",
"mock_target": "test_1.my_hostmongo.local",
"expected_error": "Invalid SRV host",
},
{
"query": "_mongodb._tcp.blogs.mongodb.com",
"mock_target": "cluster.testmongodb.com",
"expected_error": "Invalid SRV host",
},
]
self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
if __name__ == "__main__":
unittest.main()

View File

@ -101,7 +101,14 @@ class TestLogger(IntegrationTest):
self.db.test.insert_one({"x": "1"})
with self.fail_point(
{"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}}
{
"mode": {"times": 1},
"data": {
"failCommands": ["find"],
"errorCode": 10107,
"errorLabels": ["RetryableWriteError"],
},
}
):
with self.assertLogs("pymongo.command", level="DEBUG") as cm:
self.db.test.find_one({"x": "1"})
@ -109,7 +116,27 @@ class TestLogger(IntegrationTest):
retry_messages = [
r.getMessage() for r in cm.records if "Retrying read attempt" in r.getMessage()
]
print(retry_messages)
self.assertEqual(len(retry_messages), 1)
@client_context.require_failCommand_fail_point
@client_context.require_retryable_writes
def test_logging_retry_write_attempts(self):
with self.fail_point(
{
"mode": {"times": 1},
"data": {
"errorCode": 10107,
"errorLabels": ["RetryableWriteError"],
"failCommands": ["insert"],
},
}
):
with self.assertLogs("pymongo.command", level="DEBUG") as cm:
self.db.test.insert_one({"x": "1"})
retry_messages = [
r.getMessage() for r in cm.records if "Retrying write attempt" in r.getMessage()
]
self.assertEqual(len(retry_messages), 1)

View File

@ -24,6 +24,7 @@ from urllib.parse import quote_plus
sys.path[0:0] = [""]
from test import unittest
from unittest.mock import patch
from bson.binary import JAVA_LEGACY
from pymongo import ReadPreference

View File

@ -221,7 +221,6 @@ class EntityMapUtil:
self._listeners: Dict[str, EventListenerUtil] = {}
self._session_lsids: Dict[str, Mapping[str, Any]] = {}
self.test: UnifiedSpecTestMixinV1 = test_class
self._cluster_time: Mapping[str, Any] = {}
def __contains__(self, item):
return item in self._entities
@ -420,13 +419,11 @@ class EntityMapUtil:
# session has been closed.
return self._session_lsids[session_name]
def advance_cluster_times(self) -> None:
def advance_cluster_times(self, cluster_time) -> None:
"""Manually synchronize entities when desired"""
if not self._cluster_time:
self._cluster_time = (self.test.client.admin.command("ping")).get("$clusterTime")
for entity in self._entities.values():
if isinstance(entity, ClientSession) and self._cluster_time:
entity.advance_cluster_time(self._cluster_time)
if isinstance(entity, ClientSession) and cluster_time:
entity.advance_cluster_time(cluster_time)
class UnifiedSpecTestMixinV1(IntegrationTest):
@ -1035,7 +1032,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
def _testOperation_createEntities(self, spec):
self.entity_map.create_entities_from_spec(spec["entities"], uri=self._uri)
self.entity_map.advance_cluster_times()
self.entity_map.advance_cluster_times(self._cluster_time)
def _testOperation_assertSessionTransactionState(self, spec):
session = self.entity_map[spec["session"]]
@ -1428,11 +1425,12 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
self._uri = uri
self.entity_map = EntityMapUtil(self)
self.entity_map.create_entities_from_spec(self.TEST_SPEC.get("createEntities", []), uri=uri)
self._cluster_time = None
# process initialData
if "initialData" in self.TEST_SPEC:
self.insert_initial_data(self.TEST_SPEC["initialData"])
self._cluster_time = (self.client.admin.command("ping")).get("$clusterTime")
self.entity_map.advance_cluster_times()
self._cluster_time = self.client._topology.max_cluster_time()
self.entity_map.advance_cluster_times(self._cluster_time)
if "expectLogMessages" in spec:
expect_log_messages = spec["expectLogMessages"]

View File

@ -133,6 +133,7 @@ replacements = {
"async_joinall": "joinall",
"_async_create_connection": "_create_connection",
"pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts": "pymongo.synchronous.srv_resolver._SrvResolver.get_hosts",
"dns.asyncresolver.resolve": "dns.resolver.resolve",
}
docstring_replacements: dict[tuple[str, str], str] = {
@ -179,7 +180,12 @@ gridfs_files = [
def async_only_test(f: str) -> bool:
"""Return True for async tests that should not be converted to sync."""
return f in ["test_locks.py", "test_concurrency.py", "test_async_cancellation.py"]
return f in [
"test_locks.py",
"test_concurrency.py",
"test_async_cancellation.py",
"test_async_loop_safety.py",
]
test_files = [