Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
0dece2fff4
@ -24,6 +24,7 @@ PyMongo 4.12 brings a number of changes including:
|
||||
:class:`~pymongo.read_preferences.SecondaryPreferred`,
|
||||
:class:`~pymongo.read_preferences.Nearest`. Support for ``hedge`` will be removed in PyMongo 5.0.
|
||||
- Removed PyOpenSSL support from the asynchronous API due to limitations of the CPython asyncio.Protocol SSL implementation.
|
||||
- Allow valid SRV hostnames with less than 3 parts.
|
||||
|
||||
Issues Resolved
|
||||
...............
|
||||
|
||||
@ -878,6 +878,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
if not is_srv:
|
||||
self._init_background()
|
||||
|
||||
@ -1709,6 +1710,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
If this client was created with "connect=False", calling _get_topology
|
||||
launches the connection process in the background.
|
||||
"""
|
||||
if not _IS_SYNC:
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
elif self._loop != asyncio.get_running_loop():
|
||||
raise RuntimeError(
|
||||
"Cannot use AsyncMongoClient in different event loop. AsyncMongoClient uses low-level asyncio APIs that bind it to the event loop it was created on."
|
||||
)
|
||||
if not self._opened:
|
||||
if self._resolve_srv_info["is_srv"]:
|
||||
await self._resolve_srv()
|
||||
@ -2840,7 +2848,7 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
message=f"Retrying write attempt number {self._attempt_number}",
|
||||
clientId=self._client.client_id,
|
||||
clientId=self._client._topology_settings._topology_id,
|
||||
commandName=self._operation,
|
||||
operationId=self._operation_id,
|
||||
)
|
||||
|
||||
@ -931,13 +931,15 @@ class Pool:
|
||||
return
|
||||
|
||||
if self.opts.max_idle_time_seconds is not None:
|
||||
close_conns = []
|
||||
async with self.lock:
|
||||
while (
|
||||
self.conns
|
||||
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
|
||||
):
|
||||
conn = self.conns.pop()
|
||||
await conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
close_conns.append(self.conns.pop())
|
||||
for conn in close_conns:
|
||||
await conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
|
||||
while True:
|
||||
async with self.size_cond:
|
||||
@ -957,14 +959,18 @@ class Pool:
|
||||
self._pending += 1
|
||||
incremented = True
|
||||
conn = await self.connect()
|
||||
close_conn = False
|
||||
async with self.lock:
|
||||
# Close connection and return if the pool was reset during
|
||||
# socket creation or while acquiring the pool lock.
|
||||
if self.gen.get_overall() != reference_generation:
|
||||
await conn.close_conn(ConnectionClosedReason.STALE)
|
||||
return
|
||||
self.conns.appendleft(conn)
|
||||
self.active_contexts.discard(conn.cancel_context)
|
||||
close_conn = True
|
||||
if not close_conn:
|
||||
self.conns.appendleft(conn)
|
||||
self.active_contexts.discard(conn.cancel_context)
|
||||
if close_conn:
|
||||
await conn.close_conn(ConnectionClosedReason.STALE)
|
||||
return
|
||||
finally:
|
||||
if incremented:
|
||||
# Notify after adding the socket to the pool.
|
||||
@ -1343,17 +1349,20 @@ class Pool:
|
||||
error=ConnectionClosedReason.ERROR,
|
||||
)
|
||||
else:
|
||||
close_conn = False
|
||||
async with self.lock:
|
||||
# Hold the lock to ensure this section does not race with
|
||||
# Pool.reset().
|
||||
if self.stale_generation(conn.generation, conn.service_id):
|
||||
await conn.close_conn(ConnectionClosedReason.STALE)
|
||||
close_conn = True
|
||||
else:
|
||||
conn.update_last_checkin_time()
|
||||
conn.update_is_writable(bool(self.is_writable))
|
||||
self.conns.appendleft(conn)
|
||||
# Notify any threads waiting to create a connection.
|
||||
self._max_connecting_cond.notify()
|
||||
if close_conn:
|
||||
await conn.close_conn(ConnectionClosedReason.STALE)
|
||||
|
||||
async with self.size_cond:
|
||||
if txn:
|
||||
|
||||
@ -90,14 +90,12 @@ class _SrvResolver:
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.__plist = self.__fqdn.split(".")[1:]
|
||||
split_fqdn = self.__fqdn.split(".")
|
||||
self.__plist = split_fqdn[1:] if len(split_fqdn) > 2 else split_fqdn
|
||||
except Exception:
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
|
||||
self.__slen = len(self.__plist)
|
||||
if self.__slen < 2:
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
|
||||
|
||||
async def get_options(self) -> Optional[str]:
|
||||
from dns import resolver
|
||||
@ -139,6 +137,10 @@ class _SrvResolver:
|
||||
|
||||
# Validate hosts
|
||||
for node in nodes:
|
||||
if self.__fqdn == node[0].lower():
|
||||
raise ConfigurationError(
|
||||
"Invalid SRV host: return address is identical to SRV hostname"
|
||||
)
|
||||
try:
|
||||
nlist = node[0].lower().split(".")[1:][-self.__slen :]
|
||||
except Exception:
|
||||
|
||||
@ -244,8 +244,6 @@ class Topology:
|
||||
# Close servers and clear the pools.
|
||||
for server in self._servers.values():
|
||||
await server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
# Reset the session pool to avoid duplicate sessions in
|
||||
# the child process.
|
||||
self._session_pool.reset()
|
||||
|
||||
@ -876,6 +876,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
if not is_srv:
|
||||
self._init_background()
|
||||
|
||||
@ -1703,6 +1704,13 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
If this client was created with "connect=False", calling _get_topology
|
||||
launches the connection process in the background.
|
||||
"""
|
||||
if not _IS_SYNC:
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
elif self._loop != asyncio.get_running_loop():
|
||||
raise RuntimeError(
|
||||
"Cannot use MongoClient in different event loop. MongoClient uses low-level asyncio APIs that bind it to the event loop it was created on."
|
||||
)
|
||||
if not self._opened:
|
||||
if self._resolve_srv_info["is_srv"]:
|
||||
self._resolve_srv()
|
||||
@ -2826,7 +2834,7 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
message=f"Retrying write attempt number {self._attempt_number}",
|
||||
clientId=self._client.client_id,
|
||||
clientId=self._client._topology_settings._topology_id,
|
||||
commandName=self._operation,
|
||||
operationId=self._operation_id,
|
||||
)
|
||||
|
||||
@ -927,13 +927,15 @@ class Pool:
|
||||
return
|
||||
|
||||
if self.opts.max_idle_time_seconds is not None:
|
||||
close_conns = []
|
||||
with self.lock:
|
||||
while (
|
||||
self.conns
|
||||
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
|
||||
):
|
||||
conn = self.conns.pop()
|
||||
conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
close_conns.append(self.conns.pop())
|
||||
for conn in close_conns:
|
||||
conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
|
||||
while True:
|
||||
with self.size_cond:
|
||||
@ -953,14 +955,18 @@ class Pool:
|
||||
self._pending += 1
|
||||
incremented = True
|
||||
conn = self.connect()
|
||||
close_conn = False
|
||||
with self.lock:
|
||||
# Close connection and return if the pool was reset during
|
||||
# socket creation or while acquiring the pool lock.
|
||||
if self.gen.get_overall() != reference_generation:
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
return
|
||||
self.conns.appendleft(conn)
|
||||
self.active_contexts.discard(conn.cancel_context)
|
||||
close_conn = True
|
||||
if not close_conn:
|
||||
self.conns.appendleft(conn)
|
||||
self.active_contexts.discard(conn.cancel_context)
|
||||
if close_conn:
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
return
|
||||
finally:
|
||||
if incremented:
|
||||
# Notify after adding the socket to the pool.
|
||||
@ -1339,17 +1345,20 @@ class Pool:
|
||||
error=ConnectionClosedReason.ERROR,
|
||||
)
|
||||
else:
|
||||
close_conn = False
|
||||
with self.lock:
|
||||
# Hold the lock to ensure this section does not race with
|
||||
# Pool.reset().
|
||||
if self.stale_generation(conn.generation, conn.service_id):
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
close_conn = True
|
||||
else:
|
||||
conn.update_last_checkin_time()
|
||||
conn.update_is_writable(bool(self.is_writable))
|
||||
self.conns.appendleft(conn)
|
||||
# Notify any threads waiting to create a connection.
|
||||
self._max_connecting_cond.notify()
|
||||
if close_conn:
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
|
||||
with self.size_cond:
|
||||
if txn:
|
||||
|
||||
@ -90,14 +90,12 @@ class _SrvResolver:
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.__plist = self.__fqdn.split(".")[1:]
|
||||
split_fqdn = self.__fqdn.split(".")
|
||||
self.__plist = split_fqdn[1:] if len(split_fqdn) > 2 else split_fqdn
|
||||
except Exception:
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
|
||||
self.__slen = len(self.__plist)
|
||||
if self.__slen < 2:
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
|
||||
|
||||
def get_options(self) -> Optional[str]:
|
||||
from dns import resolver
|
||||
@ -139,6 +137,10 @@ class _SrvResolver:
|
||||
|
||||
# Validate hosts
|
||||
for node in nodes:
|
||||
if self.__fqdn == node[0].lower():
|
||||
raise ConfigurationError(
|
||||
"Invalid SRV host: return address is identical to SRV hostname"
|
||||
)
|
||||
try:
|
||||
nlist = node[0].lower().split(".")[1:][-self.__slen :]
|
||||
except Exception:
|
||||
|
||||
@ -244,8 +244,6 @@ class Topology:
|
||||
# Close servers and clear the pools.
|
||||
for server in self._servers.values():
|
||||
server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
# Reset the session pool to avoid duplicate sessions in
|
||||
# the child process.
|
||||
self._session_pool.reset()
|
||||
|
||||
@ -117,6 +117,8 @@ filterwarnings = [
|
||||
"module:unclosed <ssl.SSLSocket:ResourceWarning",
|
||||
"module:unclosed <socket object:ResourceWarning",
|
||||
"module:unclosed transport:ResourceWarning",
|
||||
# pytest-asyncio known issue: https://github.com/pytest-dev/pytest-asyncio/issues/724
|
||||
"module:unclosed event loop:ResourceWarning",
|
||||
# https://github.com/eventlet/eventlet/issues/818
|
||||
"module:please use dns.resolver.Resolver.resolve:DeprecationWarning",
|
||||
# https://github.com/dateutil/dateutil/issues/1314
|
||||
|
||||
36
test/asynchronous/test_async_loop_safety.py
Normal file
36
test/asynchronous/test_async_loop_safety.py
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright 2025-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test that the asynchronous API detects event loop changes and fails correctly."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from pymongo import AsyncMongoClient
|
||||
|
||||
|
||||
class TestClientLoopSafety(unittest.TestCase):
|
||||
def test_client_errors_on_different_loop(self):
|
||||
client = AsyncMongoClient()
|
||||
loop1 = asyncio.new_event_loop()
|
||||
loop1.run_until_complete(client.aconnect())
|
||||
loop2 = asyncio.new_event_loop()
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Cannot use AsyncMongoClient in different event loop"
|
||||
):
|
||||
loop2.run_until_complete(client.aconnect())
|
||||
loop1.run_until_complete(client.close())
|
||||
loop1.close()
|
||||
loop2.close()
|
||||
@ -30,6 +30,7 @@ from test.asynchronous import (
|
||||
unittest,
|
||||
)
|
||||
from test.utils_shared import async_wait_until
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from pymongo.asynchronous.uri_parser import parse_uri
|
||||
from pymongo.common import validate_read_preference_tags
|
||||
@ -186,12 +187,6 @@ create_tests(TestDNSSharded)
|
||||
|
||||
class TestParsingErrors(AsyncPyMongoTestCase):
|
||||
async def test_invalid_host(self):
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"):
|
||||
client = self.simple_client("mongodb+srv://mongodb")
|
||||
await client.aconnect()
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"):
|
||||
client = self.simple_client("mongodb+srv://mongodb.com")
|
||||
await client.aconnect()
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
|
||||
client = self.simple_client("mongodb+srv://127.0.0.1")
|
||||
await client.aconnect()
|
||||
@ -207,5 +202,93 @@ class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest):
|
||||
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
||||
|
||||
|
||||
class TestInitialDnsSeedlistDiscovery(AsyncPyMongoTestCase):
|
||||
"""
|
||||
Initial DNS Seedlist Discovery prose tests
|
||||
https://github.com/mongodb/specifications/blob/0a7a8b5/source/initial-dns-seedlist-discovery/tests/README.md#prose-tests
|
||||
"""
|
||||
|
||||
async def run_initial_dns_seedlist_discovery_prose_tests(self, test_cases):
|
||||
for case in test_cases:
|
||||
with patch("dns.asyncresolver.resolve") as mock_resolver:
|
||||
|
||||
async def mock_resolve(query, record_type, *args, **kwargs):
|
||||
mock_srv = MagicMock()
|
||||
mock_srv.target.to_text.return_value = case["mock_target"]
|
||||
return [mock_srv]
|
||||
|
||||
mock_resolver.side_effect = mock_resolve
|
||||
domain = case["query"].split("._tcp.")[1]
|
||||
connection_string = f"mongodb+srv://{domain}"
|
||||
try:
|
||||
await parse_uri(connection_string)
|
||||
except ConfigurationError as e:
|
||||
self.assertIn(case["expected_error"], str(e))
|
||||
else:
|
||||
self.fail(f"ConfigurationError was not raised for query: {case['query']}")
|
||||
|
||||
async def test_1_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self):
|
||||
with patch("dns.asyncresolver.resolve"):
|
||||
await parse_uri("mongodb+srv://localhost/")
|
||||
await parse_uri("mongodb+srv://mongo.local/")
|
||||
|
||||
async def test_2_throw_when_return_address_does_not_end_with_srv_domain(self):
|
||||
test_cases = [
|
||||
{
|
||||
"query": "_mongodb._tcp.localhost",
|
||||
"mock_target": "localhost.mongodb",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
{
|
||||
"query": "_mongodb._tcp.blogs.mongodb.com",
|
||||
"mock_target": "blogs.evil.com",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
{
|
||||
"query": "_mongodb._tcp.blogs.mongo.local",
|
||||
"mock_target": "test_1.evil.com",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
]
|
||||
await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
|
||||
|
||||
async def test_3_throw_when_return_address_is_identical_to_srv_hostname(self):
|
||||
test_cases = [
|
||||
{
|
||||
"query": "_mongodb._tcp.localhost",
|
||||
"mock_target": "localhost",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
{
|
||||
"query": "_mongodb._tcp.mongo.local",
|
||||
"mock_target": "mongo.local",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
]
|
||||
await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
|
||||
|
||||
async def test_4_throw_when_return_address_does_not_contain_dot_separating_shared_part_of_domain(
|
||||
self
|
||||
):
|
||||
test_cases = [
|
||||
{
|
||||
"query": "_mongodb._tcp.localhost",
|
||||
"mock_target": "test_1.cluster_1localhost",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
{
|
||||
"query": "_mongodb._tcp.mongo.local",
|
||||
"mock_target": "test_1.my_hostmongo.local",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
{
|
||||
"query": "_mongodb._tcp.blogs.mongodb.com",
|
||||
"mock_target": "cluster.testmongodb.com",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
]
|
||||
await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -102,7 +102,14 @@ class TestLogger(AsyncIntegrationTest):
|
||||
await self.db.test.insert_one({"x": "1"})
|
||||
|
||||
async with self.fail_point(
|
||||
{"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}}
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {
|
||||
"failCommands": ["find"],
|
||||
"errorCode": 10107,
|
||||
"errorLabels": ["RetryableWriteError"],
|
||||
},
|
||||
}
|
||||
):
|
||||
with self.assertLogs("pymongo.command", level="DEBUG") as cm:
|
||||
await self.db.test.find_one({"x": "1"})
|
||||
@ -110,7 +117,27 @@ class TestLogger(AsyncIntegrationTest):
|
||||
retry_messages = [
|
||||
r.getMessage() for r in cm.records if "Retrying read attempt" in r.getMessage()
|
||||
]
|
||||
print(retry_messages)
|
||||
self.assertEqual(len(retry_messages), 1)
|
||||
|
||||
@async_client_context.require_failCommand_fail_point
|
||||
@async_client_context.require_retryable_writes
|
||||
async def test_logging_retry_write_attempts(self):
|
||||
async with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {
|
||||
"errorCode": 10107,
|
||||
"errorLabels": ["RetryableWriteError"],
|
||||
"failCommands": ["insert"],
|
||||
},
|
||||
}
|
||||
):
|
||||
with self.assertLogs("pymongo.command", level="DEBUG") as cm:
|
||||
await self.db.test.insert_one({"x": "1"})
|
||||
|
||||
retry_messages = [
|
||||
r.getMessage() for r in cm.records if "Retrying write attempt" in r.getMessage()
|
||||
]
|
||||
self.assertEqual(len(retry_messages), 1)
|
||||
|
||||
|
||||
|
||||
@ -222,7 +222,6 @@ class EntityMapUtil:
|
||||
self._listeners: Dict[str, EventListenerUtil] = {}
|
||||
self._session_lsids: Dict[str, Mapping[str, Any]] = {}
|
||||
self.test: UnifiedSpecTestMixinV1 = test_class
|
||||
self._cluster_time: Mapping[str, Any] = {}
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self._entities
|
||||
@ -421,13 +420,11 @@ class EntityMapUtil:
|
||||
# session has been closed.
|
||||
return self._session_lsids[session_name]
|
||||
|
||||
async def advance_cluster_times(self) -> None:
|
||||
async def advance_cluster_times(self, cluster_time) -> None:
|
||||
"""Manually synchronize entities when desired"""
|
||||
if not self._cluster_time:
|
||||
self._cluster_time = (await self.test.client.admin.command("ping")).get("$clusterTime")
|
||||
for entity in self._entities.values():
|
||||
if isinstance(entity, AsyncClientSession) and self._cluster_time:
|
||||
entity.advance_cluster_time(self._cluster_time)
|
||||
if isinstance(entity, AsyncClientSession) and cluster_time:
|
||||
entity.advance_cluster_time(cluster_time)
|
||||
|
||||
|
||||
class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
@ -1044,7 +1041,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
|
||||
async def _testOperation_createEntities(self, spec):
|
||||
await self.entity_map.create_entities_from_spec(spec["entities"], uri=self._uri)
|
||||
await self.entity_map.advance_cluster_times()
|
||||
await self.entity_map.advance_cluster_times(self._cluster_time)
|
||||
|
||||
def _testOperation_assertSessionTransactionState(self, spec):
|
||||
session = self.entity_map[spec["session"]]
|
||||
@ -1443,11 +1440,12 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
await self.entity_map.create_entities_from_spec(
|
||||
self.TEST_SPEC.get("createEntities", []), uri=uri
|
||||
)
|
||||
self._cluster_time = None
|
||||
# process initialData
|
||||
if "initialData" in self.TEST_SPEC:
|
||||
await self.insert_initial_data(self.TEST_SPEC["initialData"])
|
||||
self._cluster_time = (await self.client.admin.command("ping")).get("$clusterTime")
|
||||
await self.entity_map.advance_cluster_times()
|
||||
self._cluster_time = self.client._topology.max_cluster_time()
|
||||
await self.entity_map.advance_cluster_times(self._cluster_time)
|
||||
|
||||
if "expectLogMessages" in spec:
|
||||
expect_log_messages = spec["expectLogMessages"]
|
||||
|
||||
176
test/csot/waitQueueTimeout.json
Normal file
176
test/csot/waitQueueTimeout.json
Normal file
@ -0,0 +1,176 @@
|
||||
{
|
||||
"description": "WaitQueueTimeoutError does not clear the pool",
|
||||
"schemaVersion": "1.9",
|
||||
"runOnRequirements": [
|
||||
{
|
||||
"minServerVersion": "4.4",
|
||||
"topologies": [
|
||||
"single",
|
||||
"replicaset",
|
||||
"sharded"
|
||||
]
|
||||
}
|
||||
],
|
||||
"createEntities": [
|
||||
{
|
||||
"client": {
|
||||
"id": "failPointClient",
|
||||
"useMultipleMongoses": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"client": {
|
||||
"id": "client",
|
||||
"uriOptions": {
|
||||
"maxPoolSize": 1,
|
||||
"appname": "waitQueueTimeoutErrorTest"
|
||||
},
|
||||
"useMultipleMongoses": false,
|
||||
"observeEvents": [
|
||||
"commandStartedEvent",
|
||||
"poolClearedEvent"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"database": {
|
||||
"id": "database",
|
||||
"client": "client",
|
||||
"databaseName": "test"
|
||||
}
|
||||
}
|
||||
],
|
||||
"tests": [
|
||||
{
|
||||
"description": "WaitQueueTimeoutError does not clear the pool",
|
||||
"operations": [
|
||||
{
|
||||
"name": "failPoint",
|
||||
"object": "testRunner",
|
||||
"arguments": {
|
||||
"client": "failPointClient",
|
||||
"failPoint": {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {
|
||||
"times": 1
|
||||
},
|
||||
"data": {
|
||||
"failCommands": [
|
||||
"ping"
|
||||
],
|
||||
"blockConnection": true,
|
||||
"blockTimeMS": 500,
|
||||
"appName": "waitQueueTimeoutErrorTest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "createEntities",
|
||||
"object": "testRunner",
|
||||
"arguments": {
|
||||
"entities": [
|
||||
{
|
||||
"thread": {
|
||||
"id": "thread0"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "runOnThread",
|
||||
"object": "testRunner",
|
||||
"arguments": {
|
||||
"thread": "thread0",
|
||||
"operation": {
|
||||
"name": "runCommand",
|
||||
"object": "database",
|
||||
"arguments": {
|
||||
"command": {
|
||||
"ping": 1
|
||||
},
|
||||
"commandName": "ping"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "waitForEvent",
|
||||
"object": "testRunner",
|
||||
"arguments": {
|
||||
"client": "client",
|
||||
"event": {
|
||||
"commandStartedEvent": {
|
||||
"commandName": "ping"
|
||||
}
|
||||
},
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "runCommand",
|
||||
"object": "database",
|
||||
"arguments": {
|
||||
"timeoutMS": 100,
|
||||
"command": {
|
||||
"hello": 1
|
||||
},
|
||||
"commandName": "hello"
|
||||
},
|
||||
"expectError": {
|
||||
"isTimeoutError": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "waitForThread",
|
||||
"object": "testRunner",
|
||||
"arguments": {
|
||||
"thread": "thread0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "runCommand",
|
||||
"object": "database",
|
||||
"arguments": {
|
||||
"command": {
|
||||
"hello": 1
|
||||
},
|
||||
"commandName": "hello"
|
||||
}
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client",
|
||||
"eventType": "command",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"commandName": "ping",
|
||||
"databaseName": "test",
|
||||
"command": {
|
||||
"ping": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"commandName": "hello",
|
||||
"databaseName": "test",
|
||||
"command": {
|
||||
"hello": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"client": "client",
|
||||
"eventType": "cmap",
|
||||
"events": []
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -30,6 +30,7 @@ from test import (
|
||||
unittest,
|
||||
)
|
||||
from test.utils_shared import wait_until
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from pymongo.common import validate_read_preference_tags
|
||||
from pymongo.errors import ConfigurationError
|
||||
@ -184,12 +185,6 @@ create_tests(TestDNSSharded)
|
||||
|
||||
class TestParsingErrors(PyMongoTestCase):
|
||||
def test_invalid_host(self):
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"):
|
||||
client = self.simple_client("mongodb+srv://mongodb")
|
||||
client._connect()
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"):
|
||||
client = self.simple_client("mongodb+srv://mongodb.com")
|
||||
client._connect()
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
|
||||
client = self.simple_client("mongodb+srv://127.0.0.1")
|
||||
client._connect()
|
||||
@ -205,5 +200,93 @@ class TestCaseInsensitive(IntegrationTest):
|
||||
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
||||
|
||||
|
||||
class TestInitialDnsSeedlistDiscovery(PyMongoTestCase):
|
||||
"""
|
||||
Initial DNS Seedlist Discovery prose tests
|
||||
https://github.com/mongodb/specifications/blob/0a7a8b5/source/initial-dns-seedlist-discovery/tests/README.md#prose-tests
|
||||
"""
|
||||
|
||||
def run_initial_dns_seedlist_discovery_prose_tests(self, test_cases):
|
||||
for case in test_cases:
|
||||
with patch("dns.resolver.resolve") as mock_resolver:
|
||||
|
||||
def mock_resolve(query, record_type, *args, **kwargs):
|
||||
mock_srv = MagicMock()
|
||||
mock_srv.target.to_text.return_value = case["mock_target"]
|
||||
return [mock_srv]
|
||||
|
||||
mock_resolver.side_effect = mock_resolve
|
||||
domain = case["query"].split("._tcp.")[1]
|
||||
connection_string = f"mongodb+srv://{domain}"
|
||||
try:
|
||||
parse_uri(connection_string)
|
||||
except ConfigurationError as e:
|
||||
self.assertIn(case["expected_error"], str(e))
|
||||
else:
|
||||
self.fail(f"ConfigurationError was not raised for query: {case['query']}")
|
||||
|
||||
def test_1_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self):
|
||||
with patch("dns.resolver.resolve"):
|
||||
parse_uri("mongodb+srv://localhost/")
|
||||
parse_uri("mongodb+srv://mongo.local/")
|
||||
|
||||
def test_2_throw_when_return_address_does_not_end_with_srv_domain(self):
|
||||
test_cases = [
|
||||
{
|
||||
"query": "_mongodb._tcp.localhost",
|
||||
"mock_target": "localhost.mongodb",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
{
|
||||
"query": "_mongodb._tcp.blogs.mongodb.com",
|
||||
"mock_target": "blogs.evil.com",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
{
|
||||
"query": "_mongodb._tcp.blogs.mongo.local",
|
||||
"mock_target": "test_1.evil.com",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
]
|
||||
self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
|
||||
|
||||
def test_3_throw_when_return_address_is_identical_to_srv_hostname(self):
|
||||
test_cases = [
|
||||
{
|
||||
"query": "_mongodb._tcp.localhost",
|
||||
"mock_target": "localhost",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
{
|
||||
"query": "_mongodb._tcp.mongo.local",
|
||||
"mock_target": "mongo.local",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
]
|
||||
self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
|
||||
|
||||
def test_4_throw_when_return_address_does_not_contain_dot_separating_shared_part_of_domain(
|
||||
self
|
||||
):
|
||||
test_cases = [
|
||||
{
|
||||
"query": "_mongodb._tcp.localhost",
|
||||
"mock_target": "test_1.cluster_1localhost",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
{
|
||||
"query": "_mongodb._tcp.mongo.local",
|
||||
"mock_target": "test_1.my_hostmongo.local",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
{
|
||||
"query": "_mongodb._tcp.blogs.mongodb.com",
|
||||
"mock_target": "cluster.testmongodb.com",
|
||||
"expected_error": "Invalid SRV host",
|
||||
},
|
||||
]
|
||||
self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -101,7 +101,14 @@ class TestLogger(IntegrationTest):
|
||||
self.db.test.insert_one({"x": "1"})
|
||||
|
||||
with self.fail_point(
|
||||
{"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}}
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {
|
||||
"failCommands": ["find"],
|
||||
"errorCode": 10107,
|
||||
"errorLabels": ["RetryableWriteError"],
|
||||
},
|
||||
}
|
||||
):
|
||||
with self.assertLogs("pymongo.command", level="DEBUG") as cm:
|
||||
self.db.test.find_one({"x": "1"})
|
||||
@ -109,7 +116,27 @@ class TestLogger(IntegrationTest):
|
||||
retry_messages = [
|
||||
r.getMessage() for r in cm.records if "Retrying read attempt" in r.getMessage()
|
||||
]
|
||||
print(retry_messages)
|
||||
self.assertEqual(len(retry_messages), 1)
|
||||
|
||||
@client_context.require_failCommand_fail_point
|
||||
@client_context.require_retryable_writes
|
||||
def test_logging_retry_write_attempts(self):
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {
|
||||
"errorCode": 10107,
|
||||
"errorLabels": ["RetryableWriteError"],
|
||||
"failCommands": ["insert"],
|
||||
},
|
||||
}
|
||||
):
|
||||
with self.assertLogs("pymongo.command", level="DEBUG") as cm:
|
||||
self.db.test.insert_one({"x": "1"})
|
||||
|
||||
retry_messages = [
|
||||
r.getMessage() for r in cm.records if "Retrying write attempt" in r.getMessage()
|
||||
]
|
||||
self.assertEqual(len(retry_messages), 1)
|
||||
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ from urllib.parse import quote_plus
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from bson.binary import JAVA_LEGACY
|
||||
from pymongo import ReadPreference
|
||||
|
||||
@ -221,7 +221,6 @@ class EntityMapUtil:
|
||||
self._listeners: Dict[str, EventListenerUtil] = {}
|
||||
self._session_lsids: Dict[str, Mapping[str, Any]] = {}
|
||||
self.test: UnifiedSpecTestMixinV1 = test_class
|
||||
self._cluster_time: Mapping[str, Any] = {}
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self._entities
|
||||
@ -420,13 +419,11 @@ class EntityMapUtil:
|
||||
# session has been closed.
|
||||
return self._session_lsids[session_name]
|
||||
|
||||
def advance_cluster_times(self) -> None:
|
||||
def advance_cluster_times(self, cluster_time) -> None:
|
||||
"""Manually synchronize entities when desired"""
|
||||
if not self._cluster_time:
|
||||
self._cluster_time = (self.test.client.admin.command("ping")).get("$clusterTime")
|
||||
for entity in self._entities.values():
|
||||
if isinstance(entity, ClientSession) and self._cluster_time:
|
||||
entity.advance_cluster_time(self._cluster_time)
|
||||
if isinstance(entity, ClientSession) and cluster_time:
|
||||
entity.advance_cluster_time(cluster_time)
|
||||
|
||||
|
||||
class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
@ -1035,7 +1032,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
|
||||
def _testOperation_createEntities(self, spec):
|
||||
self.entity_map.create_entities_from_spec(spec["entities"], uri=self._uri)
|
||||
self.entity_map.advance_cluster_times()
|
||||
self.entity_map.advance_cluster_times(self._cluster_time)
|
||||
|
||||
def _testOperation_assertSessionTransactionState(self, spec):
|
||||
session = self.entity_map[spec["session"]]
|
||||
@ -1428,11 +1425,12 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
self._uri = uri
|
||||
self.entity_map = EntityMapUtil(self)
|
||||
self.entity_map.create_entities_from_spec(self.TEST_SPEC.get("createEntities", []), uri=uri)
|
||||
self._cluster_time = None
|
||||
# process initialData
|
||||
if "initialData" in self.TEST_SPEC:
|
||||
self.insert_initial_data(self.TEST_SPEC["initialData"])
|
||||
self._cluster_time = (self.client.admin.command("ping")).get("$clusterTime")
|
||||
self.entity_map.advance_cluster_times()
|
||||
self._cluster_time = self.client._topology.max_cluster_time()
|
||||
self.entity_map.advance_cluster_times(self._cluster_time)
|
||||
|
||||
if "expectLogMessages" in spec:
|
||||
expect_log_messages = spec["expectLogMessages"]
|
||||
|
||||
@ -133,6 +133,7 @@ replacements = {
|
||||
"async_joinall": "joinall",
|
||||
"_async_create_connection": "_create_connection",
|
||||
"pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts": "pymongo.synchronous.srv_resolver._SrvResolver.get_hosts",
|
||||
"dns.asyncresolver.resolve": "dns.resolver.resolve",
|
||||
}
|
||||
|
||||
docstring_replacements: dict[tuple[str, str], str] = {
|
||||
@ -179,7 +180,12 @@ gridfs_files = [
|
||||
|
||||
def async_only_test(f: str) -> bool:
|
||||
"""Return True for async tests that should not be converted to sync."""
|
||||
return f in ["test_locks.py", "test_concurrency.py", "test_async_cancellation.py"]
|
||||
return f in [
|
||||
"test_locks.py",
|
||||
"test_concurrency.py",
|
||||
"test_async_cancellation.py",
|
||||
"test_async_loop_safety.py",
|
||||
]
|
||||
|
||||
|
||||
test_files = [
|
||||
|
||||
Loading…
Reference in New Issue
Block a user