From 01f659cd8bd6ae970b044c3043c9ed2ca6d89bf4 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 30 Jan 2025 12:34:59 -0800 Subject: [PATCH 01/28] PYTHON-5071 Use one event loop for all asyncio tests (#2086) --- test/__init__.py | 121 +++++++++++++++++++++++--------- test/asynchronous/__init__.py | 125 ++++++++++++++++++++++++---------- 2 files changed, 176 insertions(+), 70 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index d3a63db2d..b49eee99a 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio import gc +import inspect import logging import multiprocessing import os @@ -30,6 +31,33 @@ import traceback import unittest import warnings from asyncio import iscoroutinefunction + +from pymongo.uri_parser import parse_uri + +try: + import ipaddress + + HAVE_IPADDRESS = True +except ImportError: + HAVE_IPADDRESS = False +from contextlib import contextmanager +from functools import partial, wraps +from typing import Any, Callable, Dict, Generator, overload +from unittest import SkipTest +from urllib.parse import quote_plus + +import pymongo +import pymongo.errors +from bson.son import SON +from pymongo.common import partition_node +from pymongo.hello import HelloCompat +from pymongo.server_api import ServerApi +from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] +from pymongo.synchronous.database import Database +from pymongo.synchronous.mongo_client import MongoClient + +sys.path[0:0] = [""] + from test.helpers import ( COMPRESSORS, IS_SRV, @@ -52,31 +80,7 @@ from test.helpers import ( sanitize_cmd, sanitize_reply, ) - -from pymongo.uri_parser import parse_uri - -try: - import ipaddress - - HAVE_IPADDRESS = True -except ImportError: - HAVE_IPADDRESS = False -from contextlib import contextmanager -from functools import partial, wraps from test.version import Version -from typing import Any, Callable, Dict, Generator, overload -from unittest import SkipTest -from urllib.parse import quote_plus - -import pymongo -import pymongo.errors -from bson.son import SON -from pymongo.common import partition_node -from pymongo.hello import HelloCompat -from pymongo.server_api import ServerApi -from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.synchronous.database import Database -from pymongo.synchronous.mongo_client import MongoClient _IS_SYNC = True @@ -863,18 +867,66 @@ class ClientContext: # Reusable client context client_context = ClientContext() +# Global event loop for async tests. +LOOP = None -def reset_client_context(): - if _IS_SYNC: - # sync tests don't need to reset a client context - return - elif client_context.client is not None: - client_context.client.close() - client_context.client = None - client_context._init_client() + +def get_loop() -> asyncio.AbstractEventLoop: + """Get the test suite's global event loop.""" + global LOOP + if LOOP is None: + try: + LOOP = asyncio.get_running_loop() + except RuntimeError: + # no running event loop, fallback to get_event_loop. + try: + # Ignore DeprecationWarning: There is no current event loop + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + LOOP = asyncio.get_event_loop() + except RuntimeError: + LOOP = asyncio.new_event_loop() + asyncio.set_event_loop(LOOP) + return LOOP class PyMongoTestCase(unittest.TestCase): + if not _IS_SYNC: + # An async TestCase that uses a single event loop for all tests. + # Inspired by TestCase. + def setUp(self): + pass + + def tearDown(self): + pass + + def addCleanup(self, func, /, *args, **kwargs): + self.addCleanup(*(func, *args), **kwargs) + + def _callSetUp(self): + self.setUp() + self._callAsync(self.setUp) + + def _callTestMethod(self, method): + self._callMaybeAsync(method) + + def _callTearDown(self): + self._callAsync(self.tearDown) + self.tearDown() + + def _callCleanup(self, function, *args, **kwargs): + self._callMaybeAsync(function, *args, **kwargs) + + def _callAsync(self, func, /, *args, **kwargs): + assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" + return get_loop().run_until_complete(func(*args, **kwargs)) + + def _callMaybeAsync(self, func, /, *args, **kwargs): + if inspect.iscoroutinefunction(func): + return get_loop().run_until_complete(func(*args, **kwargs)) + else: + return func(*args, **kwargs) + def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) @@ -1136,8 +1188,6 @@ class IntegrationTest(PyMongoTestCase): @client_context.require_connection def setUp(self) -> None: - if not _IS_SYNC: - reset_client_context() if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): @@ -1186,6 +1236,9 @@ class MockClientTest(UnitTest): def setup(): + if not _IS_SYNC: + # Set up the event loop. + get_loop() client_context.init() warnings.resetwarnings() warnings.simplefilter("always") diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 73e282474..76fae407d 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio import gc +import inspect import logging import multiprocessing import os @@ -30,6 +31,33 @@ import traceback import unittest import warnings from asyncio import iscoroutinefunction + +from pymongo.uri_parser import parse_uri + +try: + import ipaddress + + HAVE_IPADDRESS = True +except ImportError: + HAVE_IPADDRESS = False +from contextlib import asynccontextmanager, contextmanager +from functools import partial, wraps +from typing import Any, Callable, Dict, Generator, overload +from unittest import SkipTest +from urllib.parse import quote_plus + +import pymongo +import pymongo.errors +from bson.son import SON +from pymongo.asynchronous.database import AsyncDatabase +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.common import partition_node +from pymongo.hello import HelloCompat +from pymongo.server_api import ServerApi +from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] + +sys.path[0:0] = [""] + from test.helpers import ( COMPRESSORS, IS_SRV, @@ -52,31 +80,7 @@ from test.helpers import ( sanitize_cmd, sanitize_reply, ) - -from pymongo.uri_parser import parse_uri - -try: - import ipaddress - - HAVE_IPADDRESS = True -except ImportError: - HAVE_IPADDRESS = False -from contextlib import asynccontextmanager, contextmanager -from functools import partial, wraps from test.version import Version -from typing import Any, Callable, Dict, Generator, overload -from unittest import SkipTest -from urllib.parse import quote_plus - -import pymongo -import pymongo.errors -from bson.son import SON -from pymongo.asynchronous.database import AsyncDatabase -from pymongo.asynchronous.mongo_client import AsyncMongoClient -from pymongo.common import partition_node -from pymongo.hello import HelloCompat -from pymongo.server_api import ServerApi -from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] _IS_SYNC = False @@ -865,18 +869,66 @@ class AsyncClientContext: # Reusable client context async_client_context = AsyncClientContext() - -async def reset_client_context(): - if _IS_SYNC: - # sync tests don't need to reset a client context - return - elif async_client_context.client is not None: - await async_client_context.client.close() - async_client_context.client = None - await async_client_context._init_client() +# Global event loop for async tests. +LOOP = None -class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): +def get_loop() -> asyncio.AbstractEventLoop: + """Get the test suite's global event loop.""" + global LOOP + if LOOP is None: + try: + LOOP = asyncio.get_running_loop() + except RuntimeError: + # no running event loop, fallback to get_event_loop. + try: + # Ignore DeprecationWarning: There is no current event loop + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + LOOP = asyncio.get_event_loop() + except RuntimeError: + LOOP = asyncio.new_event_loop() + asyncio.set_event_loop(LOOP) + return LOOP + + +class AsyncPyMongoTestCase(unittest.TestCase): + if not _IS_SYNC: + # An async TestCase that uses a single event loop for all tests. + # Inspired by IsolatedAsyncioTestCase. + async def asyncSetUp(self): + pass + + async def asyncTearDown(self): + pass + + def addAsyncCleanup(self, func, /, *args, **kwargs): + self.addCleanup(*(func, *args), **kwargs) + + def _callSetUp(self): + self.setUp() + self._callAsync(self.asyncSetUp) + + def _callTestMethod(self, method): + self._callMaybeAsync(method) + + def _callTearDown(self): + self._callAsync(self.asyncTearDown) + self.tearDown() + + def _callCleanup(self, function, *args, **kwargs): + self._callMaybeAsync(function, *args, **kwargs) + + def _callAsync(self, func, /, *args, **kwargs): + assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" + return get_loop().run_until_complete(func(*args, **kwargs)) + + def _callMaybeAsync(self, func, /, *args, **kwargs): + if inspect.iscoroutinefunction(func): + return get_loop().run_until_complete(func(*args, **kwargs)) + else: + return func(*args, **kwargs) + def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) @@ -1154,8 +1206,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): @async_client_context.require_connection async def asyncSetUp(self) -> None: - if not _IS_SYNC: - await reset_client_context() if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): @@ -1204,6 +1254,9 @@ class AsyncMockClientTest(AsyncUnitTest): async def async_setup(): + if not _IS_SYNC: + # Set up the event loop. + get_loop() await async_client_context.init() warnings.resetwarnings() warnings.simplefilter("always") From 94b9a54c8ef829307c04d984858386a4476986e8 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Thu, 30 Jan 2025 13:10:01 -0800 Subject: [PATCH 02/28] PYTHON-5083 Convert test.test_gridfs_spec to async (#2104) --- test/asynchronous/test_gridfs_spec.py | 39 +++++++++++++++++++++++++++ test/test_gridfs_spec.py | 8 +++++- tools/synchro.py | 1 + 3 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 test/asynchronous/test_gridfs_spec.py diff --git a/test/asynchronous/test_gridfs_spec.py b/test/asynchronous/test_gridfs_spec.py new file mode 100644 index 000000000..f3dc14fbd --- /dev/null +++ b/test/asynchronous/test_gridfs_spec.py @@ -0,0 +1,39 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the AsyncGridFS unified spec tests.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "gridfs") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "gridfs") + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_gridfs_spec.py b/test/test_gridfs_spec.py index 6840b6ae0..e84e19725 100644 --- a/test/test_gridfs_spec.py +++ b/test/test_gridfs_spec.py @@ -17,14 +17,20 @@ from __future__ import annotations import os import sys +from pathlib import Path sys.path[0:0] = [""] from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "gridfs") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "gridfs") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "gridfs") # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/tools/synchro.py b/tools/synchro.py index 897e5e801..e20a8facd 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -207,6 +207,7 @@ converted_tests = [ "test_data_lake.py", "test_encryption.py", "test_grid_file.py", + "test_gridfs_spec.py", "test_logger.py", "test_monitoring.py", "test_raw_bson.py", From 2909e1fc8a1937b3f2ae50f9df17521b623688d1 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 30 Jan 2025 16:15:18 -0500 Subject: [PATCH 03/28] PYTHON-5085 - Convert test.test_index_management to async (#2101) --- test/asynchronous/test_index_management.py | 383 +++++++++++++++++++++ test/test_index_management.py | 59 ++-- tools/synchro.py | 1 + 3 files changed, 418 insertions(+), 25 deletions(-) create mode 100644 test/asynchronous/test_index_management.py diff --git a/test/asynchronous/test_index_management.py b/test/asynchronous/test_index_management.py new file mode 100644 index 000000000..2920c48b2 --- /dev/null +++ b/test/asynchronous/test_index_management.py @@ -0,0 +1,383 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the auth spec tests.""" +from __future__ import annotations + +import asyncio +import os +import pathlib +import sys +import time +import uuid +from typing import Any, Mapping + +import pytest + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest +from test.asynchronous.unified_format import generate_test_classes +from test.utils import AllowListEventListener, OvertCommandListener + +from pymongo.errors import OperationFailure +from pymongo.operations import SearchIndexModel +from pymongo.read_concern import ReadConcern +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + +pytestmark = pytest.mark.index_management + +# Location of JSON test specifications. +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "index_management") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "index_management") + +_NAME = "test-search-index" + + +class TestCreateSearchIndex(AsyncIntegrationTest): + async def test_inputs(self): + if not os.environ.get("TEST_INDEX_MANAGEMENT"): + raise unittest.SkipTest("Skipping index management tests") + listener = AllowListEventListener("createSearchIndexes") + client = self.simple_client(event_listeners=[listener]) + coll = client.test.test + await coll.drop() + definition = dict(mappings=dict(dynamic=True)) + model_kwarg_list: list[Mapping[str, Any]] = [ + dict(definition=definition, name=None), + dict(definition=definition, name="test"), + ] + for model_kwargs in model_kwarg_list: + model = SearchIndexModel(**model_kwargs) + with self.assertRaises(OperationFailure): + await coll.create_search_index(model) + with self.assertRaises(OperationFailure): + await coll.create_search_index(model_kwargs) + + listener.reset() + with self.assertRaises(OperationFailure): + await coll.create_search_index({"definition": definition, "arbitraryOption": 1}) + self.assertEqual( + {"definition": definition, "arbitraryOption": 1}, + listener.events[0].command["indexes"][0], + ) + + listener.reset() + with self.assertRaises(OperationFailure): + await coll.create_search_index({"definition": definition, "type": "search"}) + self.assertEqual( + {"definition": definition, "type": "search"}, listener.events[0].command["indexes"][0] + ) + + +class SearchIndexIntegrationBase(AsyncPyMongoTestCase): + db_name = "test_search_index_base" + + @classmethod + def setUpClass(cls) -> None: + if not os.environ.get("TEST_INDEX_MANAGEMENT"): + raise unittest.SkipTest("Skipping index management tests") + cls.url = os.environ.get("MONGODB_URI") + cls.username = os.environ["DB_USER"] + cls.password = os.environ["DB_PASSWORD"] + cls.listener = OvertCommandListener() + + async def asyncSetUp(self) -> None: + self.client = self.simple_client( + self.url, + username=self.username, + password=self.password, + event_listeners=[self.listener], + ) + await self.client.drop_database(_NAME) + self.db = self.client[self.db_name] + + async def asyncTearDown(self): + await self.client.drop_database(_NAME) + + async def wait_for_ready(self, coll, name=_NAME, predicate=None): + """Wait for a search index to be ready.""" + indices: list[Mapping[str, Any]] = [] + if predicate is None: + predicate = lambda index: index.get("queryable") is True + + while True: + indices = await (await coll.list_search_indexes(name)).to_list() + if len(indices) and predicate(indices[0]): + return indices[0] + await asyncio.sleep(5) + + +class TestSearchIndexIntegration(SearchIndexIntegrationBase): + db_name = "test_search_index" + + async def test_comment_field(self): + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``). + coll0 = self.db[f"col{uuid.uuid4()}"] + await coll0.insert_one({}) + + # Create a new search index on ``coll0`` that implicitly passes its type. + search_definition = {"mappings": {"dynamic": False}} + self.listener.reset() + implicit_search_resp = await coll0.create_search_index( + model={"name": _NAME + "-implicit", "definition": search_definition}, comment="foo" + ) + event = self.listener.events[0] + self.assertEqual(event.command["comment"], "foo") + + # Get the index definition. + self.listener.reset() + await (await coll0.list_search_indexes(name=implicit_search_resp, comment="foo")).next() + event = self.listener.events[0] + self.assertEqual(event.command["comment"], "foo") + + +class TestSearchIndexProse(SearchIndexIntegrationBase): + db_name = "test_search_index_prose" + + async def test_case_1(self): + """Driver can successfully create and list search indexes.""" + + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``). + coll0 = self.db[f"col{uuid.uuid4()}"] + + # Create a new search index on ``coll0`` with the ``createSearchIndex`` helper. Use the following definition: + model = {"name": _NAME, "definition": {"mappings": {"dynamic": False}}} + await coll0.insert_one({}) + resp = await coll0.create_search_index(model) + + # Assert that the command returns the name of the index: ``"test-search-index"``. + self.assertEqual(resp, _NAME) + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied and store the value in a variable ``index``: + # An index with the ``name`` of ``test-search-index`` is present and the index has a field ``queryable`` with a value of ``true``. + index = await self.wait_for_ready(coll0) + + # . Assert that ``index`` has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': false } }`` + self.assertIn("latestDefinition", index) + self.assertEqual(index["latestDefinition"], model["definition"]) + + async def test_case_2(self): + """Driver can successfully create multiple indexes in batch.""" + + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``). + coll0 = self.db[f"col{uuid.uuid4()}"] + await coll0.insert_one({}) + + # Create two new search indexes on ``coll0`` with the ``createSearchIndexes`` helper. + name1 = "test-search-index-1" + name2 = "test-search-index-2" + definition = {"mappings": {"dynamic": False}} + index_definitions: list[dict[str, Any]] = [ + {"name": name1, "definition": definition}, + {"name": name2, "definition": definition}, + ] + await coll0.create_search_indexes( + [SearchIndexModel(i["definition"], i["name"]) for i in index_definitions] + ) + + # .Assert that the command returns an array containing the new indexes' names: ``["test-search-index-1", "test-search-index-2"]``. + indices = await (await coll0.list_search_indexes()).to_list() + names = [i["name"] for i in indices] + self.assertIn(name1, names) + self.assertIn(name2, names) + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied. + # An index with the ``name`` of ``test-search-index-1`` is present and index has a field ``queryable`` with the value of ``true``. Store result in ``index1``. + # An index with the ``name`` of ``test-search-index-2`` is present and index has a field ``queryable`` with the value of ``true``. Store result in ``index2``. + index1 = await self.wait_for_ready(coll0, name1) + index2 = await self.wait_for_ready(coll0, name2) + + # Assert that ``index1`` and ``index2`` have the property ``latestDefinition`` whose value is ``{ "mappings" : { "dynamic" : false } }`` + for index in [index1, index2]: + self.assertIn("latestDefinition", index) + self.assertEqual(index["latestDefinition"], definition) + + async def test_case_3(self): + """Driver can successfully drop search indexes.""" + + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``). + coll0 = self.db[f"col{uuid.uuid4()}"] + await coll0.insert_one({}) + + # Create a new search index on ``coll0``. + model = {"name": _NAME, "definition": {"mappings": {"dynamic": False}}} + resp = await coll0.create_search_index(model) + + # Assert that the command returns the name of the index: ``"test-search-index"``. + self.assertEqual(resp, "test-search-index") + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied: + # An index with the ``name`` of ``test-search-index`` is present and index has a field ``queryable`` with the value of ``true``. + await self.wait_for_ready(coll0) + + # Run a ``dropSearchIndex`` on ``coll0``, using ``test-search-index`` for the name. + await coll0.drop_search_index(_NAME) + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until ``listSearchIndexes`` returns an empty array. + t0 = time.time() + while True: + indices = await (await coll0.list_search_indexes()).to_list() + if indices: + break + if (time.time() - t0) / 60 > 5: + raise TimeoutError("Timed out waiting for index deletion") + await asyncio.sleep(5) + + async def test_case_4(self): + """Driver can update a search index.""" + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``). + coll0 = self.db[f"col{uuid.uuid4()}"] + await coll0.insert_one({}) + + # Create a new search index on ``coll0``. + model = {"name": _NAME, "definition": {"mappings": {"dynamic": False}}} + resp = await coll0.create_search_index(model) + + # Assert that the command returns the name of the index: ``"test-search-index"``. + self.assertEqual(resp, _NAME) + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied: + # An index with the ``name`` of ``test-search-index`` is present and index has a field ``queryable`` with the value of ``true``. + await self.wait_for_ready(coll0) + + # Run a ``updateSearchIndex`` on ``coll0``. + # Assert that the command does not error and the server responds with a success. + model2: dict[str, Any] = {"name": _NAME, "definition": {"mappings": {"dynamic": True}}} + await coll0.update_search_index(_NAME, model2["definition"]) + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied: + # An index with the ``name`` of ``test-search-index`` is present. This index is referred to as ``index``. + # The index has a field ``queryable`` with a value of ``true`` and has a field ``status`` with the value of ``READY``. + predicate = lambda index: index.get("queryable") is True and index.get("status") == "READY" + await self.wait_for_ready(coll0, predicate=predicate) + + # Assert that an index is present with the name ``test-search-index`` and the definition has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': true } }``. + index = (await (await coll0.list_search_indexes(_NAME)).to_list())[0] + self.assertIn("latestDefinition", index) + self.assertEqual(index["latestDefinition"], model2["definition"]) + + async def test_case_5(self): + """``dropSearchIndex`` suppresses namespace not found errors.""" + # Create a driver-side collection object for a randomly generated collection name. Do not create this collection on the server. + coll0 = self.db[f"col{uuid.uuid4()}"] + + # Run a ``dropSearchIndex`` command and assert that no error is thrown. + await coll0.drop_search_index("foo") + + async def test_case_6(self): + """Driver can successfully create and list search indexes with non-default readConcern and writeConcern.""" + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``). + coll0 = self.db[f"col{uuid.uuid4()}"] + await coll0.insert_one({}) + + # Apply a write concern ``WriteConcern(w=1)`` and a read concern with ``ReadConcern(level="majority")`` to ``coll0``. + coll0 = coll0.with_options( + write_concern=WriteConcern(w="1"), read_concern=ReadConcern(level="majority") + ) + + # Create a new search index on ``coll0`` with the ``createSearchIndex`` helper. + name = "test-search-index-case6" + model = {"name": name, "definition": {"mappings": {"dynamic": False}}} + resp = await coll0.create_search_index(model) + + # Assert that the command returns the name of the index: ``"test-search-index-case6"``. + self.assertEqual(resp, name) + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied and store the value in a variable ``index``: + # - An index with the ``name`` of ``test-search-index-case6`` is present and the index has a field ``queryable`` with a value of ``true``. + index = await self.wait_for_ready(coll0, name) + + # Assert that ``index`` has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': false } }`` + self.assertIn("latestDefinition", index) + self.assertEqual(index["latestDefinition"], model["definition"]) + + async def test_case_7(self): + """Driver handles index types.""" + + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``). + coll0 = self.db[f"col{uuid.uuid4()}"] + await coll0.insert_one({}) + + # Use these search and vector search definitions for indexes. + search_definition = {"mappings": {"dynamic": False}} + vector_search_definition = { + "fields": [ + { + "type": "vector", + "path": "plot_embedding", + "numDimensions": 1536, + "similarity": "euclidean", + }, + ] + } + + # Create a new search index on ``coll0`` that implicitly passes its type. + implicit_search_resp = await coll0.create_search_index( + model={"name": _NAME + "-implicit", "definition": search_definition} + ) + + # Get the index definition. + resp = await (await coll0.list_search_indexes(name=implicit_search_resp)).next() + + # Assert that the index model contains the correct index type: ``"search"``. + self.assertEqual(resp["type"], "search") + + # Create a new search index on ``coll0`` that explicitly passes its type. + explicit_search_resp = await coll0.create_search_index( + model={"name": _NAME + "-explicit", "type": "search", "definition": search_definition} + ) + + # Get the index definition. + resp = await (await coll0.list_search_indexes(name=explicit_search_resp)).next() + + # Assert that the index model contains the correct index type: ``"search"``. + self.assertEqual(resp["type"], "search") + + # Create a new vector search index on ``coll0`` that explicitly passes its type. + explicit_vector_resp = await coll0.create_search_index( + model={ + "name": _NAME + "-vector", + "type": "vectorSearch", + "definition": vector_search_definition, + } + ) + + # Get the index definition. + resp = await (await coll0.list_search_indexes(name=explicit_vector_resp)).next() + + # Assert that the index model contains the correct index type: ``"vectorSearch"``. + self.assertEqual(resp["type"], "vectorSearch") + + # Catch the error raised when trying to create a vector search index without specifying the type + with self.assertRaises(OperationFailure) as e: + await coll0.create_search_index( + model={"name": _NAME + "-error", "definition": vector_search_definition} + ) + self.assertIn("Attribute mappings missing.", e.exception.details["errmsg"]) + + +globals().update( + generate_test_classes( + _TEST_PATH, + module=__name__, + ) +) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_index_management.py b/test/test_index_management.py index 6ca726e2e..5135e43f1 100644 --- a/test/test_index_management.py +++ b/test/test_index_management.py @@ -15,7 +15,9 @@ """Run the auth spec tests.""" from __future__ import annotations +import asyncio import os +import pathlib import sys import time import uuid @@ -27,16 +29,22 @@ sys.path[0:0] = [""] from test import IntegrationTest, PyMongoTestCase, unittest from test.unified_format import generate_test_classes -from test.utils import AllowListEventListener, EventListener, OvertCommandListener +from test.utils import AllowListEventListener, OvertCommandListener from pymongo.errors import OperationFailure from pymongo.operations import SearchIndexModel from pymongo.read_concern import ReadConcern from pymongo.write_concern import WriteConcern +_IS_SYNC = True + pytestmark = pytest.mark.index_management -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "index_management") +# Location of JSON test specifications. +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "index_management") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "index_management") _NAME = "test-search-index" @@ -82,23 +90,25 @@ class SearchIndexIntegrationBase(PyMongoTestCase): @classmethod def setUpClass(cls) -> None: - super().setUpClass() if not os.environ.get("TEST_INDEX_MANAGEMENT"): raise unittest.SkipTest("Skipping index management tests") - url = os.environ.get("MONGODB_URI") - username = os.environ["DB_USER"] - password = os.environ["DB_PASSWORD"] - cls.listener = listener = OvertCommandListener() - cls.client = cls.unmanaged_simple_client( - url, username=username, password=password, event_listeners=[listener] - ) - cls.client.drop_database(_NAME) - cls.db = cls.client[cls.db_name] + cls.url = os.environ.get("MONGODB_URI") + cls.username = os.environ["DB_USER"] + cls.password = os.environ["DB_PASSWORD"] + cls.listener = OvertCommandListener() - @classmethod - def tearDownClass(cls): - cls.client.drop_database(_NAME) - cls.client.close() + def setUp(self) -> None: + self.client = self.simple_client( + self.url, + username=self.username, + password=self.password, + event_listeners=[self.listener], + ) + self.client.drop_database(_NAME) + self.db = self.client[self.db_name] + + def tearDown(self): + self.client.drop_database(_NAME) def wait_for_ready(self, coll, name=_NAME, predicate=None): """Wait for a search index to be ready.""" @@ -107,10 +117,9 @@ class SearchIndexIntegrationBase(PyMongoTestCase): predicate = lambda index: index.get("queryable") is True while True: - indices = list(coll.list_search_indexes(name)) + indices = (coll.list_search_indexes(name)).to_list() if len(indices) and predicate(indices[0]): return indices[0] - break time.sleep(5) @@ -133,7 +142,7 @@ class TestSearchIndexIntegration(SearchIndexIntegrationBase): # Get the index definition. self.listener.reset() - coll0.list_search_indexes(name=implicit_search_resp, comment="foo").next() + (coll0.list_search_indexes(name=implicit_search_resp, comment="foo")).next() event = self.listener.events[0] self.assertEqual(event.command["comment"], "foo") @@ -183,7 +192,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase): ) # .Assert that the command returns an array containing the new indexes' names: ``["test-search-index-1", "test-search-index-2"]``. - indices = list(coll0.list_search_indexes()) + indices = (coll0.list_search_indexes()).to_list() names = [i["name"] for i in indices] self.assertIn(name1, names) self.assertIn(name2, names) @@ -223,7 +232,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase): # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until ``listSearchIndexes`` returns an empty array. t0 = time.time() while True: - indices = list(coll0.list_search_indexes()) + indices = (coll0.list_search_indexes()).to_list() if indices: break if (time.time() - t0) / 60 > 5: @@ -259,7 +268,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase): self.wait_for_ready(coll0, predicate=predicate) # Assert that an index is present with the name ``test-search-index`` and the definition has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': true } }``. - index = list(coll0.list_search_indexes(_NAME))[0] + index = ((coll0.list_search_indexes(_NAME)).to_list())[0] self.assertIn("latestDefinition", index) self.assertEqual(index["latestDefinition"], model2["definition"]) @@ -324,7 +333,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase): ) # Get the index definition. - resp = coll0.list_search_indexes(name=implicit_search_resp).next() + resp = (coll0.list_search_indexes(name=implicit_search_resp)).next() # Assert that the index model contains the correct index type: ``"search"``. self.assertEqual(resp["type"], "search") @@ -335,7 +344,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase): ) # Get the index definition. - resp = coll0.list_search_indexes(name=explicit_search_resp).next() + resp = (coll0.list_search_indexes(name=explicit_search_resp)).next() # Assert that the index model contains the correct index type: ``"search"``. self.assertEqual(resp["type"], "search") @@ -350,7 +359,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase): ) # Get the index definition. - resp = coll0.list_search_indexes(name=explicit_vector_resp).next() + resp = (coll0.list_search_indexes(name=explicit_vector_resp)).next() # Assert that the index model contains the correct index type: ``"vectorSearch"``. self.assertEqual(resp["type"], "vectorSearch") diff --git a/tools/synchro.py b/tools/synchro.py index e20a8facd..08281c73d 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -206,6 +206,7 @@ converted_tests = [ "test_database.py", "test_data_lake.py", "test_encryption.py", + "test_index_management.py", "test_grid_file.py", "test_gridfs_spec.py", "test_logger.py", From 0a1471d8f99c5f48cf7937f7008f942c8eb6c5e4 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 30 Jan 2025 16:29:52 -0500 Subject: [PATCH 04/28] PYTHON-5084 - Convert test.test_heartbeat_monitoring to async (#2100) --- .../asynchronous/test_heartbeat_monitoring.py | 97 +++++++++++++++++++ test/test_client.py | 2 +- test/test_heartbeat_monitoring.py | 34 ++++--- test/utils.py | 59 ++++++++++- tools/synchro.py | 3 + 5 files changed, 177 insertions(+), 18 deletions(-) create mode 100644 test/asynchronous/test_heartbeat_monitoring.py diff --git a/test/asynchronous/test_heartbeat_monitoring.py b/test/asynchronous/test_heartbeat_monitoring.py new file mode 100644 index 000000000..ff595a814 --- /dev/null +++ b/test/asynchronous/test_heartbeat_monitoring.py @@ -0,0 +1,97 @@ +# Copyright 2016-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the monitoring of the server heartbeats.""" +from __future__ import annotations + +import sys + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest +from test.utils import AsyncMockPool, HeartbeatEventListener, async_wait_until + +from pymongo.asynchronous.monitor import Monitor +from pymongo.errors import ConnectionFailure +from pymongo.hello import Hello, HelloCompat + +_IS_SYNC = False + + +class TestHeartbeatMonitoring(AsyncIntegrationTest): + async def create_mock_monitor(self, responses, uri, expected_results): + listener = HeartbeatEventListener() + with client_knobs( + heartbeat_frequency=0.1, min_heartbeat_interval=0.1, events_queue_frequency=0.1 + ): + + class MockMonitor(Monitor): + async def _check_with_socket(self, *args, **kwargs): + if isinstance(responses[1], Exception): + raise responses[1] + return Hello(responses[1]), 99 + + _ = await self.async_single_client( + h=uri, + event_listeners=(listener,), + _monitor_class=MockMonitor, + _pool_class=AsyncMockPool, + connect=True, + ) + + expected_len = len(expected_results) + # Wait for *at least* expected_len number of results. The + # monitor thread may run multiple times during the execution + # of this test. + await async_wait_until( + lambda: len(listener.events) >= expected_len, "publish all events" + ) + + # zip gives us len(expected_results) pairs. + for expected, actual in zip(expected_results, listener.events): + self.assertEqual(expected, actual.__class__.__name__) + self.assertEqual(actual.connection_id, responses[0]) + if expected != "ServerHeartbeatStartedEvent": + if isinstance(actual.reply, Hello): + self.assertEqual(actual.duration, 99) + self.assertEqual(actual.reply._doc, responses[1]) + else: + self.assertEqual(actual.reply, responses[1]) + + async def test_standalone(self): + responses = ( + ("a", 27017), + {HelloCompat.LEGACY_CMD: True, "maxWireVersion": 4, "minWireVersion": 0, "ok": 1}, + ) + uri = "mongodb://a:27017" + expected_results = ["ServerHeartbeatStartedEvent", "ServerHeartbeatSucceededEvent"] + + await self.create_mock_monitor(responses, uri, expected_results) + + async def test_standalone_error(self): + responses = (("a", 27017), ConnectionFailure("SPECIAL MESSAGE")) + uri = "mongodb://a:27017" + # _check_with_socket failing results in a second attempt. + expected_results = [ + "ServerHeartbeatStartedEvent", + "ServerHeartbeatFailedEvent", + "ServerHeartbeatStartedEvent", + "ServerHeartbeatFailedEvent", + ] + + await self.create_mock_monitor(responses, uri, expected_results) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_client.py b/test/test_client.py index 2a33077f5..cdc7691c2 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -2399,7 +2399,7 @@ class TestMongoClientFailover(MockClientTest): # MongoClient discovers it's alone. The first attempt raises either # ServerSelectionTimeoutError or AutoReconnect (from - # AsyncMockPool.get_socket). + # MockPool.get_socket). with self.assertRaises(AutoReconnect): c.db.collection.find_one() diff --git a/test/test_heartbeat_monitoring.py b/test/test_heartbeat_monitoring.py index 5e203a33b..0523d0ba4 100644 --- a/test/test_heartbeat_monitoring.py +++ b/test/test_heartbeat_monitoring.py @@ -26,6 +26,8 @@ from pymongo.errors import ConnectionFailure from pymongo.hello import Hello, HelloCompat from pymongo.synchronous.monitor import Monitor +_IS_SYNC = True + class TestHeartbeatMonitoring(IntegrationTest): def create_mock_monitor(self, responses, uri, expected_results): @@ -40,8 +42,12 @@ class TestHeartbeatMonitoring(IntegrationTest): raise responses[1] return Hello(responses[1]), 99 - m = self.single_client( - h=uri, event_listeners=(listener,), _monitor_class=MockMonitor, _pool_class=MockPool + _ = self.single_client( + h=uri, + event_listeners=(listener,), + _monitor_class=MockMonitor, + _pool_class=MockPool, + connect=True, ) expected_len = len(expected_results) @@ -50,20 +56,16 @@ class TestHeartbeatMonitoring(IntegrationTest): # of this test. wait_until(lambda: len(listener.events) >= expected_len, "publish all events") - try: - # zip gives us len(expected_results) pairs. - for expected, actual in zip(expected_results, listener.events): - self.assertEqual(expected, actual.__class__.__name__) - self.assertEqual(actual.connection_id, responses[0]) - if expected != "ServerHeartbeatStartedEvent": - if isinstance(actual.reply, Hello): - self.assertEqual(actual.duration, 99) - self.assertEqual(actual.reply._doc, responses[1]) - else: - self.assertEqual(actual.reply, responses[1]) - - finally: - m.close() + # zip gives us len(expected_results) pairs. + for expected, actual in zip(expected_results, listener.events): + self.assertEqual(expected, actual.__class__.__name__) + self.assertEqual(actual.connection_id, responses[0]) + if expected != "ServerHeartbeatStartedEvent": + if isinstance(actual.reply, Hello): + self.assertEqual(actual.duration, 99) + self.assertEqual(actual.reply._doc, responses[1]) + else: + self.assertEqual(actual.reply, responses[1]) def test_standalone(self): responses = ( diff --git a/test/utils.py b/test/utils.py index 69154bc63..91000a636 100644 --- a/test/utils.py +++ b/test/utils.py @@ -43,7 +43,7 @@ from pymongo.cursor_shared import CursorType from pymongo.errors import ConfigurationError, OperationFailure from pymongo.hello import HelloCompat from pymongo.helpers_shared import _SENSITIVE_COMMANDS -from pymongo.lock import _create_lock +from pymongo.lock import _async_create_lock, _create_lock from pymongo.monitoring import ( ConnectionCheckedInEvent, ConnectionCheckedOutEvent, @@ -312,6 +312,22 @@ class HeartbeatEventsListListener(HeartbeatEventListener): self.event_list.append("serverHeartbeatFailedEvent") +class AsyncMockConnection: + def __init__(self): + self.cancel_context = _CancellationContext() + self.more_to_come = False + self.id = random.randint(0, 100) + + def close_conn(self, reason): + pass + + def __aenter__(self): + return self + + def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + class MockConnection: def __init__(self): self.cancel_context = _CancellationContext() @@ -328,6 +344,47 @@ class MockConnection: pass +class AsyncMockPool: + def __init__(self, address, options, handshake=True, client_id=None): + self.gen = _PoolGeneration() + self._lock = _async_create_lock() + self.opts = options + self.operation_count = 0 + self.conns = [] + + def stale_generation(self, gen, service_id): + return self.gen.stale(gen, service_id) + + @contextlib.asynccontextmanager + async def checkout(self, handler=None): + yield AsyncMockConnection() + + async def checkin(self, *args, **kwargs): + pass + + async def _reset(self, service_id=None): + async with self._lock: + self.gen.inc(service_id) + + async def ready(self): + pass + + async def reset(self, service_id=None, interrupt_connections=False): + await self._reset() + + async def reset_without_pause(self): + await self._reset() + + async def close(self): + await self._reset() + + async def update_is_writable(self, is_writable): + pass + + async def remove_stale_sockets(self, *args, **kwargs): + pass + + class MockPool: def __init__(self, address, options, handshake=True, client_id=None): self.gen = _PoolGeneration() diff --git a/tools/synchro.py b/tools/synchro.py index 08281c73d..74b7c8053 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -119,6 +119,8 @@ replacements = { "_async_create_lock": "_create_lock", "_async_create_condition": "_create_condition", "_async_cond_wait": "_cond_wait", + "AsyncMockConnection": "MockConnection", + "AsyncMockPool": "MockPool", } docstring_replacements: dict[tuple[str, str], str] = { @@ -206,6 +208,7 @@ converted_tests = [ "test_database.py", "test_data_lake.py", "test_encryption.py", + "test_heartbeat_monitoring.py", "test_index_management.py", "test_grid_file.py", "test_gridfs_spec.py", From c8d3afdefd627d60dc47201681e4fcd65289815e Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 30 Jan 2025 16:30:04 -0500 Subject: [PATCH 05/28] PYTHON-5086 - Convert test.json_util integration test to async (#2102) --- .../test_json_util_integration.py | 28 +++++++++++++++++++ test/test_json_util.py | 23 ++------------- test/test_json_util_integration.py | 28 +++++++++++++++++++ tools/synchro.py | 1 + 4 files changed, 59 insertions(+), 21 deletions(-) create mode 100644 test/asynchronous/test_json_util_integration.py create mode 100644 test/test_json_util_integration.py diff --git a/test/asynchronous/test_json_util_integration.py b/test/asynchronous/test_json_util_integration.py new file mode 100644 index 000000000..4c02792d8 --- /dev/null +++ b/test/asynchronous/test_json_util_integration.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from test.asynchronous import AsyncIntegrationTest +from typing import Any, List, MutableMapping + +from bson import Binary, Code, DBRef, ObjectId, json_util +from bson.binary import USER_DEFINED_SUBTYPE + +_IS_SYNC = False + + +class TestJsonUtilRoundtrip(AsyncIntegrationTest): + async def test_cursor(self): + db = self.db + + await db.drop_collection("test") + docs: List[MutableMapping[str, Any]] = [ + {"foo": [1, 2]}, + {"bar": {"hello": "world"}}, + {"code": Code("function x() { return 1; }")}, + {"bin": Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)}, + {"dbref": {"_ref": DBRef("simple", ObjectId("509b8db456c02c5ab7e63c34"))}}, + ] + + await db.test.insert_many(docs) + reloaded_docs = json_util.loads(json_util.dumps(await (db.test.find()).to_list())) + for doc in docs: + self.assertTrue(doc in reloaded_docs) diff --git a/test/test_json_util.py b/test/test_json_util.py index 821ca76da..8aed4a82b 100644 --- a/test/test_json_util.py +++ b/test/test_json_util.py @@ -21,13 +21,13 @@ import re import sys import uuid from collections import OrderedDict -from typing import Any, List, MutableMapping, Tuple, Type +from typing import Any, Tuple, Type from bson.codec_options import CodecOptions, DatetimeConversion sys.path[0:0] = [""] -from test import IntegrationTest, unittest +from test import unittest from bson import EPOCH_AWARE, EPOCH_NAIVE, SON, DatetimeMS, json_util from bson.binary import ( @@ -636,24 +636,5 @@ class TestJsonUtil(unittest.TestCase): self.assertEqual(json_util.dumps(MyBinary(b"bin", USER_DEFINED_SUBTYPE)), expected_json) -class TestJsonUtilRoundtrip(IntegrationTest): - def test_cursor(self): - db = self.db - - db.drop_collection("test") - docs: List[MutableMapping[str, Any]] = [ - {"foo": [1, 2]}, - {"bar": {"hello": "world"}}, - {"code": Code("function x() { return 1; }")}, - {"bin": Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)}, - {"dbref": {"_ref": DBRef("simple", ObjectId("509b8db456c02c5ab7e63c34"))}}, - ] - - db.test.insert_many(docs) - reloaded_docs = json_util.loads(json_util.dumps(db.test.find())) - for doc in docs: - self.assertTrue(doc in reloaded_docs) - - if __name__ == "__main__": unittest.main() diff --git a/test/test_json_util_integration.py b/test/test_json_util_integration.py new file mode 100644 index 000000000..acab4f318 --- /dev/null +++ b/test/test_json_util_integration.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from test import IntegrationTest +from typing import Any, List, MutableMapping + +from bson import Binary, Code, DBRef, ObjectId, json_util +from bson.binary import USER_DEFINED_SUBTYPE + +_IS_SYNC = True + + +class TestJsonUtilRoundtrip(IntegrationTest): + def test_cursor(self): + db = self.db + + db.drop_collection("test") + docs: List[MutableMapping[str, Any]] = [ + {"foo": [1, 2]}, + {"bar": {"hello": "world"}}, + {"code": Code("function x() { return 1; }")}, + {"bin": Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)}, + {"dbref": {"_ref": DBRef("simple", ObjectId("509b8db456c02c5ab7e63c34"))}}, + ] + + db.test.insert_many(docs) + reloaded_docs = json_util.loads(json_util.dumps((db.test.find()).to_list())) + for doc in docs: + self.assertTrue(doc in reloaded_docs) diff --git a/tools/synchro.py b/tools/synchro.py index 74b7c8053..dc272929a 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -211,6 +211,7 @@ converted_tests = [ "test_heartbeat_monitoring.py", "test_index_management.py", "test_grid_file.py", + "test_json_util_integration.py", "test_gridfs_spec.py", "test_logger.py", "test_monitoring.py", From 19fdf7ccebf4c5ec45f54a976abeecb4ebcae1da Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 31 Jan 2025 11:39:48 -0500 Subject: [PATCH 06/28] PYTHON-5093 - Convert test.test_read_concern to async (#2109) --- test/asynchronous/test_read_concern.py | 122 +++++++++++++++++++++++++ test/test_read_concern.py | 12 ++- tools/synchro.py | 1 + 3 files changed, 130 insertions(+), 5 deletions(-) create mode 100644 test/asynchronous/test_read_concern.py diff --git a/test/asynchronous/test_read_concern.py b/test/asynchronous/test_read_concern.py new file mode 100644 index 000000000..fbc07a5c3 --- /dev/null +++ b/test/asynchronous/test_read_concern.py @@ -0,0 +1,122 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the read_concern module.""" +from __future__ import annotations + +import sys +import unittest + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context +from test.utils import OvertCommandListener + +from bson.son import SON +from pymongo.errors import OperationFailure +from pymongo.read_concern import ReadConcern + +_IS_SYNC = False + + +class TestReadConcern(AsyncIntegrationTest): + listener: OvertCommandListener + + @async_client_context.require_connection + async def asyncSetUp(self): + await super().asyncSetUp() + self.listener = OvertCommandListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test + await async_client_context.client.pymongo_test.create_collection("coll") + + async def asyncTearDown(self): + await async_client_context.client.pymongo_test.drop_collection("coll") + + def test_read_concern(self): + rc = ReadConcern() + self.assertIsNone(rc.level) + self.assertTrue(rc.ok_for_legacy) + + rc = ReadConcern("majority") + self.assertEqual("majority", rc.level) + self.assertFalse(rc.ok_for_legacy) + + rc = ReadConcern("local") + self.assertEqual("local", rc.level) + self.assertTrue(rc.ok_for_legacy) + + self.assertRaises(TypeError, ReadConcern, 42) + + async def test_read_concern_uri(self): + uri = f"mongodb://{await async_client_context.pair}/?readConcernLevel=majority" + client = await self.async_rs_or_single_client(uri, connect=False) + self.assertEqual(ReadConcern("majority"), client.read_concern) + + async def test_invalid_read_concern(self): + coll = self.db.get_collection("coll", read_concern=ReadConcern("unknown")) + # We rely on the server to validate read concern. + with self.assertRaises(OperationFailure): + await coll.find_one() + + async def test_find_command(self): + # readConcern not sent in command if not specified. + coll = self.db.coll + await coll.find({"field": "value"}).to_list() + self.assertNotIn("readConcern", self.listener.started_events[0].command) + + self.listener.reset() + + # Explicitly set readConcern to 'local'. + coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) + await coll.find({"field": "value"}).to_list() + self.assertEqualCommand( + SON( + [ + ("find", "coll"), + ("filter", {"field": "value"}), + ("readConcern", {"level": "local"}), + ] + ), + self.listener.started_events[0].command, + ) + + async def test_command_cursor(self): + # readConcern not sent in command if not specified. + coll = self.db.coll + await (await coll.aggregate([{"$match": {"field": "value"}}])).to_list() + self.assertNotIn("readConcern", self.listener.started_events[0].command) + + self.listener.reset() + + # Explicitly set readConcern to 'local'. + coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) + await (await coll.aggregate([{"$match": {"field": "value"}}])).to_list() + self.assertEqual({"level": "local"}, self.listener.started_events[0].command["readConcern"]) + + async def test_aggregate_out(self): + coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) + await ( + await coll.aggregate([{"$match": {"field": "value"}}, {"$out": "output_collection"}]) + ).to_list() + + # Aggregate with $out supports readConcern MongoDB 4.2 onwards. + if async_client_context.version >= (4, 1): + self.assertIn("readConcern", self.listener.started_events[0].command) + else: + self.assertNotIn("readConcern", self.listener.started_events[0].command) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_read_concern.py b/test/test_read_concern.py index f7c090142..8ec9865ea 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -27,6 +27,8 @@ from bson.son import SON from pymongo.errors import OperationFailure from pymongo.read_concern import ReadConcern +_IS_SYNC = True + class TestReadConcern(IntegrationTest): listener: OvertCommandListener @@ -71,14 +73,14 @@ class TestReadConcern(IntegrationTest): def test_find_command(self): # readConcern not sent in command if not specified. coll = self.db.coll - tuple(coll.find({"field": "value"})) + coll.find({"field": "value"}).to_list() self.assertNotIn("readConcern", self.listener.started_events[0].command) self.listener.reset() # Explicitly set readConcern to 'local'. coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) - tuple(coll.find({"field": "value"})) + coll.find({"field": "value"}).to_list() self.assertEqualCommand( SON( [ @@ -93,19 +95,19 @@ class TestReadConcern(IntegrationTest): def test_command_cursor(self): # readConcern not sent in command if not specified. coll = self.db.coll - tuple(coll.aggregate([{"$match": {"field": "value"}}])) + (coll.aggregate([{"$match": {"field": "value"}}])).to_list() self.assertNotIn("readConcern", self.listener.started_events[0].command) self.listener.reset() # Explicitly set readConcern to 'local'. coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) - tuple(coll.aggregate([{"$match": {"field": "value"}}])) + (coll.aggregate([{"$match": {"field": "value"}}])).to_list() self.assertEqual({"level": "local"}, self.listener.started_events[0].command["readConcern"]) def test_aggregate_out(self): coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) - tuple(coll.aggregate([{"$match": {"field": "value"}}, {"$out": "output_collection"}])) + (coll.aggregate([{"$match": {"field": "value"}}, {"$out": "output_collection"}])).to_list() # Aggregate with $out supports readConcern MongoDB 4.2 onwards. if client_context.version >= (4, 1): diff --git a/tools/synchro.py b/tools/synchro.py index dc272929a..59c5e1ad4 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -216,6 +216,7 @@ converted_tests = [ "test_logger.py", "test_monitoring.py", "test_raw_bson.py", + "test_read_concern.py", "test_retryable_reads.py", "test_retryable_writes.py", "test_session.py", From 8f6249e2f9528895d1cb7f9d760095df60c58e96 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 31 Jan 2025 11:40:05 -0500 Subject: [PATCH 07/28] PYTHON-5091 - Convert test.test_on_demand_csfle to async (#2108) --- test/asynchronous/test_on_demand_csfle.py | 115 ++++++++++++++++++++++ test/test_on_demand_csfle.py | 16 ++- tools/synchro.py | 1 + 3 files changed, 123 insertions(+), 9 deletions(-) create mode 100644 test/asynchronous/test_on_demand_csfle.py diff --git a/test/asynchronous/test_on_demand_csfle.py b/test/asynchronous/test_on_demand_csfle.py new file mode 100644 index 000000000..617e2ed8d --- /dev/null +++ b/test/asynchronous/test_on_demand_csfle.py @@ -0,0 +1,115 @@ +# Copyright 2022-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test client side encryption with on demand credentials.""" +from __future__ import annotations + +import os +import sys +import unittest + +import pytest + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context + +from bson.codec_options import CodecOptions +from pymongo.asynchronous.encryption import ( + _HAVE_PYMONGOCRYPT, + AsyncClientEncryption, + EncryptionError, +) + +_IS_SYNC = False + +pytestmark = pytest.mark.csfle + + +class TestonDemandGCPCredentials(AsyncIntegrationTest): + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") + @async_client_context.require_version_min(4, 2, -1) + async def asyncSetUp(self): + await super().asyncSetUp() + self.master_key = { + "projectId": "devprod-drivers", + "location": "global", + "keyRing": "key-ring-csfle", + "keyName": "key-name-csfle", + } + + @unittest.skipIf(not os.getenv("TEST_FLE_GCP_AUTO"), "Not testing FLE GCP auto") + async def test_01_failure(self): + if os.environ["SUCCESS"].lower() == "true": + self.skipTest("Expecting success") + self.client_encryption = AsyncClientEncryption( + kms_providers={"gcp": {}}, + key_vault_namespace="keyvault.datakeys", + key_vault_client=async_client_context.client, + codec_options=CodecOptions(), + ) + with self.assertRaises(EncryptionError): + await self.client_encryption.create_data_key("gcp", self.master_key) + + @unittest.skipIf(not os.getenv("TEST_FLE_GCP_AUTO"), "Not testing FLE GCP auto") + async def test_02_success(self): + if os.environ["SUCCESS"].lower() == "false": + self.skipTest("Expecting failure") + self.client_encryption = AsyncClientEncryption( + kms_providers={"gcp": {}}, + key_vault_namespace="keyvault.datakeys", + key_vault_client=async_client_context.client, + codec_options=CodecOptions(), + ) + await self.client_encryption.create_data_key("gcp", self.master_key) + + +class TestonDemandAzureCredentials(AsyncIntegrationTest): + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") + @async_client_context.require_version_min(4, 2, -1) + async def asyncSetUp(self): + await super().asyncSetUp() + self.master_key = { + "keyVaultEndpoint": os.environ["KEY_VAULT_ENDPOINT"], + "keyName": os.environ["KEY_NAME"], + } + + @unittest.skipIf(not os.getenv("TEST_FLE_AZURE_AUTO"), "Not testing FLE Azure auto") + async def test_01_failure(self): + if os.environ["SUCCESS"].lower() == "true": + self.skipTest("Expecting success") + self.client_encryption = AsyncClientEncryption( + kms_providers={"azure": {}}, + key_vault_namespace="keyvault.datakeys", + key_vault_client=async_client_context.client, + codec_options=CodecOptions(), + ) + with self.assertRaises(EncryptionError): + await self.client_encryption.create_data_key("azure", self.master_key) + + @unittest.skipIf(not os.getenv("TEST_FLE_AZURE_AUTO"), "Not testing FLE Azure auto") + async def test_02_success(self): + if os.environ["SUCCESS"].lower() == "false": + self.skipTest("Expecting failure") + self.client_encryption = AsyncClientEncryption( + kms_providers={"azure": {}}, + key_vault_namespace="keyvault.datakeys", + key_vault_client=async_client_context.client, + codec_options=CodecOptions(), + ) + await self.client_encryption.create_data_key("azure", self.master_key) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/test_on_demand_csfle.py b/test/test_on_demand_csfle.py index 023feca8c..023d44f64 100644 --- a/test/test_on_demand_csfle.py +++ b/test/test_on_demand_csfle.py @@ -26,18 +26,20 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context from bson.codec_options import CodecOptions -from pymongo.synchronous.encryption import _HAVE_PYMONGOCRYPT, ClientEncryption, EncryptionError +from pymongo.synchronous.encryption import ( + _HAVE_PYMONGOCRYPT, + ClientEncryption, + EncryptionError, +) + +_IS_SYNC = True pytestmark = pytest.mark.csfle class TestonDemandGCPCredentials(IntegrationTest): - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @client_context.require_version_min(4, 2, -1) - def setUpClass(cls): - super().setUpClass() - def setUp(self): super().setUp() self.master_key = { @@ -74,12 +76,8 @@ class TestonDemandGCPCredentials(IntegrationTest): class TestonDemandAzureCredentials(IntegrationTest): - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @client_context.require_version_min(4, 2, -1) - def setUpClass(cls): - super().setUpClass() - def setUp(self): super().setUp() self.master_key = { diff --git a/tools/synchro.py b/tools/synchro.py index 59c5e1ad4..2cfce1f01 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -215,6 +215,7 @@ converted_tests = [ "test_gridfs_spec.py", "test_logger.py", "test_monitoring.py", + "test_on_demand_csfle.py", "test_raw_bson.py", "test_read_concern.py", "test_retryable_reads.py", From c42f3d64213610be63abfa51c0225c6c12c5a6ba Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Fri, 31 Jan 2025 08:43:35 -0800 Subject: [PATCH 08/28] PYTHON-5079 Convert test.test_dns to async (#2096) --- test/asynchronous/test_dns.py | 221 ++++++++++++++++++++++++++++++++++ test/test_dns.py | 45 +++++-- tools/synchro.py | 1 + 3 files changed, 256 insertions(+), 11 deletions(-) create mode 100644 test/asynchronous/test_dns.py diff --git a/test/asynchronous/test_dns.py b/test/asynchronous/test_dns.py new file mode 100644 index 000000000..e24e0fb5c --- /dev/null +++ b/test/asynchronous/test_dns.py @@ -0,0 +1,221 @@ +# Copyright 2017 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the SRV support tests.""" +from __future__ import annotations + +import glob +import json +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + async_client_context, + unittest, +) +from test.utils import async_wait_until + +from pymongo.common import validate_read_preference_tags +from pymongo.errors import ConfigurationError +from pymongo.uri_parser import parse_uri, split_hosts + +_IS_SYNC = False + + +class TestDNSRepl(AsyncPyMongoTestCase): + if _IS_SYNC: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent, "srv_seedlist", "replica-set" + ) + else: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "replica-set" + ) + load_balanced = False + + @async_client_context.require_replica_set + def asyncSetUp(self): + pass + + +class TestDNSLoadBalanced(AsyncPyMongoTestCase): + if _IS_SYNC: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent, "srv_seedlist", "load-balanced" + ) + else: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "load-balanced" + ) + load_balanced = True + + @async_client_context.require_load_balancer + def asyncSetUp(self): + pass + + +class TestDNSSharded(AsyncPyMongoTestCase): + if _IS_SYNC: + TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "srv_seedlist", "sharded") + else: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "sharded" + ) + load_balanced = False + + @async_client_context.require_mongos + def asyncSetUp(self): + pass + + +def create_test(test_case): + async def run_test(self): + uri = test_case["uri"] + seeds = test_case.get("seeds") + num_seeds = test_case.get("numSeeds", len(seeds or [])) + hosts = test_case.get("hosts") + num_hosts = test_case.get("numHosts", len(hosts or [])) + + options = test_case.get("options", {}) + if "ssl" in options: + options["tls"] = options.pop("ssl") + parsed_options = test_case.get("parsed_options") + # See DRIVERS-1324, unless tls is explicitly set to False we need TLS. + needs_tls = not (options and (options.get("ssl") is False or options.get("tls") is False)) + if needs_tls and not async_client_context.tls: + self.skipTest("this test requires a TLS cluster") + if not needs_tls and async_client_context.tls: + self.skipTest("this test requires a non-TLS cluster") + + if seeds: + seeds = split_hosts(",".join(seeds)) + if hosts: + hosts = frozenset(split_hosts(",".join(hosts))) + + if seeds or num_seeds: + result = parse_uri(uri, validate=True) + if seeds is not None: + self.assertEqual(sorted(result["nodelist"]), sorted(seeds)) + if num_seeds is not None: + self.assertEqual(len(result["nodelist"]), num_seeds) + if options: + opts = result["options"] + if "readpreferencetags" in opts: + rpts = validate_read_preference_tags( + "readPreferenceTags", opts.pop("readpreferencetags") + ) + opts["readPreferenceTags"] = rpts + self.assertEqual(result["options"], options) + if parsed_options: + for opt, expected in parsed_options.items(): + if opt == "user": + self.assertEqual(result["username"], expected) + elif opt == "password": + self.assertEqual(result["password"], expected) + elif opt == "auth_database" or opt == "db": + self.assertEqual(result["database"], expected) + + hostname = next(iter(async_client_context.client.nodes))[0] + # The replica set members must be configured as 'localhost'. + if hostname == "localhost": + copts = async_client_context.default_client_options.copy() + # Remove tls since SRV parsing should add it automatically. + copts.pop("tls", None) + if async_client_context.tls: + # Our test certs don't support the SRV hosts used in these + # tests. + copts["tlsAllowInvalidHostnames"] = True + + client = self.simple_client(uri, **copts) + if client._options.connect: + await client.aconnect() + if num_seeds is not None: + self.assertEqual(len(client._topology_settings.seeds), num_seeds) + if hosts is not None: + await async_wait_until( + lambda: hosts == client.nodes, "match test hosts to client nodes" + ) + if num_hosts is not None: + await async_wait_until( + lambda: num_hosts == len(client.nodes), "wait to connect to num_hosts" + ) + if test_case.get("ping", True): + await client.admin.command("ping") + # XXX: we should block until SRV poller runs at least once + # and re-run these assertions. + else: + try: + parse_uri(uri) + except (ConfigurationError, ValueError): + pass + else: + self.fail("failed to raise an exception") + + return run_test + + +def create_tests(cls): + for filename in glob.glob(os.path.join(cls.TEST_PATH, "*.json")): + test_suffix, _ = os.path.splitext(os.path.basename(filename)) + with open(filename) as dns_test_file: + test_method = create_test(json.load(dns_test_file)) + setattr(cls, "test_" + test_suffix, test_method) + + +create_tests(TestDNSRepl) +create_tests(TestDNSLoadBalanced) +create_tests(TestDNSSharded) + + +class TestParsingErrors(AsyncPyMongoTestCase): + async def test_invalid_host(self): + self.assertRaisesRegex( + ConfigurationError, + "Invalid URI host: mongodb is not", + self.simple_client, + "mongodb+srv://mongodb", + ) + self.assertRaisesRegex( + ConfigurationError, + "Invalid URI host: mongodb.com is not", + self.simple_client, + "mongodb+srv://mongodb.com", + ) + self.assertRaisesRegex( + ConfigurationError, + "Invalid URI host: an IP address is not", + self.simple_client, + "mongodb+srv://127.0.0.1", + ) + self.assertRaisesRegex( + ConfigurationError, + "Invalid URI host: an IP address is not", + self.simple_client, + "mongodb+srv://[::1]", + ) + + +class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest): + async def test_connect_case_insensitive(self): + client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") + self.assertGreater(len(client.topology_description.server_descriptions()), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_dns.py b/test/test_dns.py index f2185efb1..6f4736fd5 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -18,22 +18,35 @@ from __future__ import annotations import glob import json import os +import pathlib import sys sys.path[0:0] = [""] -from test import IntegrationTest, PyMongoTestCase, client_context, unittest +from test import ( + IntegrationTest, + PyMongoTestCase, + client_context, + unittest, +) from test.utils import wait_until from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError from pymongo.uri_parser import parse_uri, split_hosts +_IS_SYNC = True + class TestDNSRepl(PyMongoTestCase): - TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "replica-set" - ) + if _IS_SYNC: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent, "srv_seedlist", "replica-set" + ) + else: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "replica-set" + ) load_balanced = False @client_context.require_replica_set @@ -42,9 +55,14 @@ class TestDNSRepl(PyMongoTestCase): class TestDNSLoadBalanced(PyMongoTestCase): - TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "load-balanced" - ) + if _IS_SYNC: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent, "srv_seedlist", "load-balanced" + ) + else: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "load-balanced" + ) load_balanced = True @client_context.require_load_balancer @@ -53,7 +71,12 @@ class TestDNSLoadBalanced(PyMongoTestCase): class TestDNSSharded(PyMongoTestCase): - TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "sharded") + if _IS_SYNC: + TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "srv_seedlist", "sharded") + else: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "sharded" + ) load_balanced = False @client_context.require_mongos @@ -119,7 +142,9 @@ def create_test(test_case): # tests. copts["tlsAllowInvalidHostnames"] = True - client = PyMongoTestCase.unmanaged_simple_client(uri, **copts) + client = self.simple_client(uri, **copts) + if client._options.connect: + client._connect() if num_seeds is not None: self.assertEqual(len(client._topology_settings.seeds), num_seeds) if hosts is not None: @@ -132,7 +157,6 @@ def create_test(test_case): client.admin.command("ping") # XXX: we should block until SRV poller runs at least once # and re-run these assertions. - client.close() else: try: parse_uri(uri) @@ -188,7 +212,6 @@ class TestParsingErrors(PyMongoTestCase): class TestCaseInsensitive(IntegrationTest): def test_connect_case_insensitive(self): client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") - self.addCleanup(client.close) self.assertGreater(len(client.topology_description.server_descriptions()), 1) diff --git a/tools/synchro.py b/tools/synchro.py index 2cfce1f01..ef82db756 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -207,6 +207,7 @@ converted_tests = [ "test_custom_types.py", "test_database.py", "test_data_lake.py", + "test_dns.py", "test_encryption.py", "test_heartbeat_monitoring.py", "test_index_management.py", From 44d1d40d6574d1cb51479bdf68a766d0ade37079 Mon Sep 17 00:00:00 2001 From: The Light <59693377+tejaschauhan36912@users.noreply.github.com> Date: Sat, 1 Feb 2025 01:31:58 +0530 Subject: [PATCH 09/28] PYTHON-5115 Type validation errors should include the invalid type name (#2085) Co-authored-by: Iris Ho --- bson/__init__.py | 2 +- bson/binary.py | 6 ++-- bson/code.py | 4 +-- bson/codec_options.py | 12 ++++++-- bson/dbref.py | 4 +-- bson/decimal128.py | 2 +- bson/timestamp.py | 4 +-- gridfs/asynchronous/grid_file.py | 12 +++++--- gridfs/synchronous/grid_file.py | 12 +++++--- pymongo/__init__.py | 2 +- pymongo/_asyncio_lock.py | 2 +- pymongo/_azure_helpers.py | 2 +- pymongo/asynchronous/auth.py | 2 +- pymongo/asynchronous/auth_oidc.py | 4 ++- pymongo/asynchronous/client_session.py | 12 ++++++-- pymongo/asynchronous/collection.py | 10 +++---- pymongo/asynchronous/command_cursor.py | 6 ++-- pymongo/asynchronous/cursor.py | 28 +++++++++--------- pymongo/asynchronous/database.py | 8 ++++-- pymongo/asynchronous/encryption.py | 8 ++++-- pymongo/asynchronous/mongo_client.py | 12 +++++--- pymongo/auth_shared.py | 4 +-- pymongo/collation.py | 2 +- pymongo/common.py | 40 ++++++++++++++++---------- pymongo/compression_support.py | 2 +- pymongo/driver_info.py | 2 +- pymongo/encryption_options.py | 4 ++- pymongo/helpers_shared.py | 9 ++++-- pymongo/monitoring.py | 8 ++++-- pymongo/read_concern.py | 2 +- pymongo/ssl_support.py | 2 +- pymongo/synchronous/auth.py | 2 +- pymongo/synchronous/auth_oidc.py | 4 ++- pymongo/synchronous/client_session.py | 12 ++++++-- pymongo/synchronous/collection.py | 10 +++---- pymongo/synchronous/command_cursor.py | 6 ++-- pymongo/synchronous/cursor.py | 28 +++++++++--------- pymongo/synchronous/database.py | 8 ++++-- pymongo/synchronous/encryption.py | 8 ++++-- pymongo/synchronous/mongo_client.py | 12 +++++--- pymongo/uri_parser.py | 10 +++---- pymongo/write_concern.py | 4 +-- 42 files changed, 204 insertions(+), 129 deletions(-) diff --git a/bson/__init__.py b/bson/__init__.py index fc6efe0d5..790ac06ef 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -1386,7 +1386,7 @@ def is_valid(bson: bytes) -> bool: :param bson: the data to be validated """ if not isinstance(bson, bytes): - raise TypeError("BSON data must be an instance of a subclass of bytes") + raise TypeError(f"BSON data must be an instance of a subclass of bytes, not {type(bson)}") try: _bson_to_dict(bson, DEFAULT_CODEC_OPTIONS) diff --git a/bson/binary.py b/bson/binary.py index f90dce226..aab59cccb 100644 --- a/bson/binary.py +++ b/bson/binary.py @@ -290,7 +290,7 @@ class Binary(bytes): subtype: int = BINARY_SUBTYPE, ) -> Binary: if not isinstance(subtype, int): - raise TypeError("subtype must be an instance of int") + raise TypeError(f"subtype must be an instance of int, not {type(subtype)}") if subtype >= 256 or subtype < 0: raise ValueError("subtype must be contained in [0, 256)") # Support any type that implements the buffer protocol. @@ -321,7 +321,7 @@ class Binary(bytes): .. versionadded:: 3.11 """ if not isinstance(uuid, UUID): - raise TypeError("uuid must be an instance of uuid.UUID") + raise TypeError(f"uuid must be an instance of uuid.UUID, not {type(uuid)}") if uuid_representation not in ALL_UUID_REPRESENTATIONS: raise ValueError( @@ -470,7 +470,7 @@ class Binary(bytes): """ if self.subtype != VECTOR_SUBTYPE: - raise ValueError(f"Cannot decode subtype {self.subtype} as a vector.") + raise ValueError(f"Cannot decode subtype {self.subtype} as a vector") position = 0 dtype, padding = struct.unpack_from(" Code: if not isinstance(code, str): - raise TypeError("code must be an instance of str") + raise TypeError(f"code must be an instance of str, not {type(code)}") self = str.__new__(cls, code) @@ -67,7 +67,7 @@ class Code(str): if scope is not None: if not isinstance(scope, _Mapping): - raise TypeError("scope must be an instance of dict") + raise TypeError(f"scope must be an instance of dict, not {type(scope)}") if self.__scope is not None: self.__scope.update(scope) # type: ignore else: diff --git a/bson/codec_options.py b/bson/codec_options.py index 3a0b83b7b..258a777a1 100644 --- a/bson/codec_options.py +++ b/bson/codec_options.py @@ -401,17 +401,23 @@ else: "uuid_representation must be a value from bson.binary.UuidRepresentation" ) if not isinstance(unicode_decode_error_handler, str): - raise ValueError("unicode_decode_error_handler must be a string") + raise ValueError( + f"unicode_decode_error_handler must be a string, not {type(unicode_decode_error_handler)}" + ) if tzinfo is not None: if not isinstance(tzinfo, datetime.tzinfo): - raise TypeError("tzinfo must be an instance of datetime.tzinfo") + raise TypeError( + f"tzinfo must be an instance of datetime.tzinfo, not {type(tzinfo)}" + ) if not tz_aware: raise ValueError("cannot specify tzinfo without also setting tz_aware=True") type_registry = type_registry or TypeRegistry() if not isinstance(type_registry, TypeRegistry): - raise TypeError("type_registry must be an instance of TypeRegistry") + raise TypeError( + f"type_registry must be an instance of TypeRegistry, not {type(type_registry)}" + ) return tuple.__new__( cls, diff --git a/bson/dbref.py b/bson/dbref.py index 6c21b8162..40bdb73cf 100644 --- a/bson/dbref.py +++ b/bson/dbref.py @@ -56,9 +56,9 @@ class DBRef: .. seealso:: The MongoDB documentation on `dbrefs `_. """ if not isinstance(collection, str): - raise TypeError("collection must be an instance of str") + raise TypeError(f"collection must be an instance of str, not {type(collection)}") if database is not None and not isinstance(database, str): - raise TypeError("database must be an instance of str") + raise TypeError(f"database must be an instance of str, not {type(database)}") self.__collection = collection self.__id = id diff --git a/bson/decimal128.py b/bson/decimal128.py index 016afb5eb..92c054d87 100644 --- a/bson/decimal128.py +++ b/bson/decimal128.py @@ -277,7 +277,7 @@ class Decimal128: point in Binary Integer Decimal (BID) format). """ if not isinstance(value, bytes): - raise TypeError("value must be an instance of bytes") + raise TypeError(f"value must be an instance of bytes, not {type(value)}") if len(value) != 16: raise ValueError("value must be exactly 16 bytes") return cls((_UNPACK_64(value[8:])[0], _UNPACK_64(value[:8])[0])) # type: ignore diff --git a/bson/timestamp.py b/bson/timestamp.py index 3e76e7baa..949bd7b36 100644 --- a/bson/timestamp.py +++ b/bson/timestamp.py @@ -58,9 +58,9 @@ class Timestamp: time = time - offset time = int(calendar.timegm(time.timetuple())) if not isinstance(time, int): - raise TypeError("time must be an instance of int") + raise TypeError(f"time must be an instance of int, not {type(time)}") if not isinstance(inc, int): - raise TypeError("inc must be an instance of int") + raise TypeError(f"inc must be an instance of int, not {type(inc)}") if not 0 <= time < UPPERBOUND: raise ValueError("time must be contained in [0, 2**32)") if not 0 <= inc < UPPERBOUND: diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py index a49d51d30..d15713c51 100644 --- a/gridfs/asynchronous/grid_file.py +++ b/gridfs/asynchronous/grid_file.py @@ -100,7 +100,7 @@ class AsyncGridFS: .. seealso:: The MongoDB documentation on `gridfs `_. """ if not isinstance(database, AsyncDatabase): - raise TypeError("database must be an instance of Database") + raise TypeError(f"database must be an instance of Database, not {type(database)}") database = _clear_entity_type_registry(database) @@ -503,7 +503,7 @@ class AsyncGridFSBucket: .. seealso:: The MongoDB documentation on `gridfs `_. """ if not isinstance(db, AsyncDatabase): - raise TypeError("database must be an instance of AsyncDatabase") + raise TypeError(f"database must be an instance of AsyncDatabase, not {type(db)}") db = _clear_entity_type_registry(db) @@ -1082,7 +1082,9 @@ class AsyncGridIn: :attr:`~pymongo.collection.AsyncCollection.write_concern` """ if not isinstance(root_collection, AsyncCollection): - raise TypeError("root_collection must be an instance of AsyncCollection") + raise TypeError( + f"root_collection must be an instance of AsyncCollection, not {type(root_collection)}" + ) if not root_collection.write_concern.acknowledged: raise ConfigurationError("root_collection must use acknowledged write_concern") @@ -1436,7 +1438,9 @@ class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore from the server. Metadata is fetched when first needed. """ if not isinstance(root_collection, AsyncCollection): - raise TypeError("root_collection must be an instance of AsyncCollection") + raise TypeError( + f"root_collection must be an instance of AsyncCollection, not {type(root_collection)}" + ) _disallow_transactions(session) root_collection = _clear_entity_type_registry(root_collection) diff --git a/gridfs/synchronous/grid_file.py b/gridfs/synchronous/grid_file.py index 655f05f57..ea0b53cfb 100644 --- a/gridfs/synchronous/grid_file.py +++ b/gridfs/synchronous/grid_file.py @@ -100,7 +100,7 @@ class GridFS: .. seealso:: The MongoDB documentation on `gridfs `_. """ if not isinstance(database, Database): - raise TypeError("database must be an instance of Database") + raise TypeError(f"database must be an instance of Database, not {type(database)}") database = _clear_entity_type_registry(database) @@ -501,7 +501,7 @@ class GridFSBucket: .. seealso:: The MongoDB documentation on `gridfs `_. """ if not isinstance(db, Database): - raise TypeError("database must be an instance of Database") + raise TypeError(f"database must be an instance of Database, not {type(db)}") db = _clear_entity_type_registry(db) @@ -1076,7 +1076,9 @@ class GridIn: :attr:`~pymongo.collection.Collection.write_concern` """ if not isinstance(root_collection, Collection): - raise TypeError("root_collection must be an instance of Collection") + raise TypeError( + f"root_collection must be an instance of Collection, not {type(root_collection)}" + ) if not root_collection.write_concern.acknowledged: raise ConfigurationError("root_collection must use acknowledged write_concern") @@ -1426,7 +1428,9 @@ class GridOut(GRIDOUT_BASE_CLASS): # type: ignore from the server. Metadata is fetched when first needed. """ if not isinstance(root_collection, Collection): - raise TypeError("root_collection must be an instance of Collection") + raise TypeError( + f"root_collection must be an instance of Collection, not {type(root_collection)}" + ) _disallow_transactions(session) root_collection = _clear_entity_type_registry(root_collection) diff --git a/pymongo/__init__.py b/pymongo/__init__.py index 58f6ff338..8d6def160 100644 --- a/pymongo/__init__.py +++ b/pymongo/__init__.py @@ -160,7 +160,7 @@ def timeout(seconds: Optional[float]) -> ContextManager[None]: .. versionadded:: 4.2 """ if not isinstance(seconds, (int, float, type(None))): - raise TypeError("timeout must be None, an int, or a float") + raise TypeError(f"timeout must be None, an int, or a float, not {type(seconds)}") if seconds and seconds < 0: raise ValueError("timeout cannot be negative") if seconds is not None: diff --git a/pymongo/_asyncio_lock.py b/pymongo/_asyncio_lock.py index 669b0f63a..a9c409d48 100644 --- a/pymongo/_asyncio_lock.py +++ b/pymongo/_asyncio_lock.py @@ -160,7 +160,7 @@ class Lock(_ContextManagerMixin, _LoopBoundMixin): self._locked = False self._wake_up_first() else: - raise RuntimeError("Lock is not acquired.") + raise RuntimeError("Lock is not acquired") def _wake_up_first(self) -> None: """Ensure that the first waiter will wake up.""" diff --git a/pymongo/_azure_helpers.py b/pymongo/_azure_helpers.py index 704c561cd..6e86ab567 100644 --- a/pymongo/_azure_helpers.py +++ b/pymongo/_azure_helpers.py @@ -46,7 +46,7 @@ def _get_azure_response( try: data = json.loads(body) except Exception: - raise ValueError("Azure IMDS response must be in JSON format.") from None + raise ValueError("Azure IMDS response must be in JSON format") from None for key in ["access_token", "expires_in"]: if not data.get(key): diff --git a/pymongo/asynchronous/auth.py b/pymongo/asynchronous/auth.py index b1e6d0125..8cc4edf19 100644 --- a/pymongo/asynchronous/auth.py +++ b/pymongo/asynchronous/auth.py @@ -161,7 +161,7 @@ def _password_digest(username: str, password: str) -> str: if len(password) == 0: raise ValueError("password can't be empty") if not isinstance(username, str): - raise TypeError("username must be an instance of str") + raise TypeError(f"username must be an instance of str, not {type(username)}") md5hash = hashlib.md5() # noqa: S324 data = f"{username}:mongo:{password}" diff --git a/pymongo/asynchronous/auth_oidc.py b/pymongo/asynchronous/auth_oidc.py index f1c15045d..38346648c 100644 --- a/pymongo/asynchronous/auth_oidc.py +++ b/pymongo/asynchronous/auth_oidc.py @@ -213,7 +213,9 @@ class _OIDCAuthenticator: ) resp = cb.fetch(context) if not isinstance(resp, OIDCCallbackResult): - raise ValueError("Callback result must be of type OIDCCallbackResult") + raise ValueError( + f"Callback result must be of type OIDCCallbackResult, not {type(resp)}" + ) self.refresh_token = resp.refresh_token self.access_token = resp.access_token self.token_gen_id += 1 diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index d80495d80..4c5171a35 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -310,7 +310,9 @@ class TransactionOptions: ) if max_commit_time_ms is not None: if not isinstance(max_commit_time_ms, int): - raise TypeError("max_commit_time_ms must be an integer or None") + raise TypeError( + f"max_commit_time_ms must be an integer or None, not {type(max_commit_time_ms)}" + ) @property def read_concern(self) -> Optional[ReadConcern]: @@ -902,7 +904,9 @@ class AsyncClientSession: another `AsyncClientSession` instance. """ if not isinstance(cluster_time, _Mapping): - raise TypeError("cluster_time must be a subclass of collections.Mapping") + raise TypeError( + f"cluster_time must be a subclass of collections.Mapping, not {type(cluster_time)}" + ) if not isinstance(cluster_time.get("clusterTime"), Timestamp): raise ValueError("Invalid cluster_time") self._advance_cluster_time(cluster_time) @@ -923,7 +927,9 @@ class AsyncClientSession: another `AsyncClientSession` instance. """ if not isinstance(operation_time, Timestamp): - raise TypeError("operation_time must be an instance of bson.timestamp.Timestamp") + raise TypeError( + f"operation_time must be an instance of bson.timestamp.Timestamp, not {type(operation_time)}" + ) self._advance_operation_time(operation_time) def _process_response(self, reply: Mapping[str, Any]) -> None: diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 9b7342362..e83a39143 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -228,7 +228,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): read_concern or database.read_concern, ) if not isinstance(name, str): - raise TypeError("name must be an instance of str") + raise TypeError(f"name must be an instance of str, not {type(name)}") from pymongo.asynchronous.database import AsyncDatabase if not isinstance(database, AsyncDatabase): @@ -2475,7 +2475,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): name = helpers_shared._gen_index_name(index_or_name) if not isinstance(name, str): - raise TypeError("index_or_name must be an instance of str or list") + raise TypeError(f"index_or_name must be an instance of str or list, not {type(name)}") cmd = {"dropIndexes": self._name, "index": name} cmd.update(kwargs) @@ -3078,7 +3078,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): """ if not isinstance(new_name, str): - raise TypeError("new_name must be an instance of str") + raise TypeError(f"new_name must be an instance of str, not {type(new_name)}") if not new_name or ".." in new_name: raise InvalidName("collection names cannot be empty") @@ -3148,7 +3148,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): """ if not isinstance(key, str): - raise TypeError("key must be an instance of str") + raise TypeError(f"key must be an instance of str, not {type(key)}") cmd = {"distinct": self._name, "key": key} if filter is not None: if "query" in kwargs: @@ -3196,7 +3196,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): common.validate_is_mapping("filter", filter) if not isinstance(return_document, bool): raise ValueError( - "return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER" + f"return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER, not {type(return_document)}" ) collation = validate_collation_or_none(kwargs.pop("collation", None)) cmd = {"findAndModify": self._name, "query": filter, "new": return_document} diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py index 5a4559bd7..353c5e91c 100644 --- a/pymongo/asynchronous/command_cursor.py +++ b/pymongo/asynchronous/command_cursor.py @@ -94,7 +94,9 @@ class AsyncCommandCursor(Generic[_DocumentType]): self.batch_size(batch_size) if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None: - raise TypeError("max_await_time_ms must be an integer or None") + raise TypeError( + f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}" + ) def __del__(self) -> None: self._die_no_lock() @@ -115,7 +117,7 @@ class AsyncCommandCursor(Generic[_DocumentType]): :param batch_size: The size of each batch of results requested. """ if not isinstance(batch_size, int): - raise TypeError("batch_size must be an integer") + raise TypeError(f"batch_size must be an integer, not {type(batch_size)}") if batch_size < 0: raise ValueError("batch_size must be >= 0") diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 8193e5328..9101197ce 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -146,9 +146,9 @@ class AsyncCursor(Generic[_DocumentType]): spec: Mapping[str, Any] = filter or {} validate_is_mapping("filter", spec) if not isinstance(skip, int): - raise TypeError("skip must be an instance of int") + raise TypeError(f"skip must be an instance of int, not {type(skip)}") if not isinstance(limit, int): - raise TypeError("limit must be an instance of int") + raise TypeError(f"limit must be an instance of int, not {type(limit)}") validate_boolean("no_cursor_timeout", no_cursor_timeout) if no_cursor_timeout and not self._explicit_session: warnings.warn( @@ -171,7 +171,7 @@ class AsyncCursor(Generic[_DocumentType]): validate_boolean("allow_partial_results", allow_partial_results) validate_boolean("oplog_replay", oplog_replay) if not isinstance(batch_size, int): - raise TypeError("batch_size must be an integer") + raise TypeError(f"batch_size must be an integer, not {type(batch_size)}") if batch_size < 0: raise ValueError("batch_size must be >= 0") # Only set if allow_disk_use is provided by the user, else None. @@ -388,7 +388,7 @@ class AsyncCursor(Generic[_DocumentType]): cursor.add_option(2) """ if not isinstance(mask, int): - raise TypeError("mask must be an int") + raise TypeError(f"mask must be an int, not {type(mask)}") self._check_okay_to_chain() if mask & _QUERY_OPTIONS["exhaust"]: @@ -408,7 +408,7 @@ class AsyncCursor(Generic[_DocumentType]): cursor.remove_option(2) """ if not isinstance(mask, int): - raise TypeError("mask must be an int") + raise TypeError(f"mask must be an int, not {type(mask)}") self._check_okay_to_chain() if mask & _QUERY_OPTIONS["exhaust"]: @@ -432,7 +432,7 @@ class AsyncCursor(Generic[_DocumentType]): .. versionadded:: 3.11 """ if not isinstance(allow_disk_use, bool): - raise TypeError("allow_disk_use must be a bool") + raise TypeError(f"allow_disk_use must be a bool, not {type(allow_disk_use)}") self._check_okay_to_chain() self._allow_disk_use = allow_disk_use @@ -451,7 +451,7 @@ class AsyncCursor(Generic[_DocumentType]): .. seealso:: The MongoDB documentation on `limit `_. """ if not isinstance(limit, int): - raise TypeError("limit must be an integer") + raise TypeError(f"limit must be an integer, not {type(limit)}") if self._exhaust: raise InvalidOperation("Can't use limit and exhaust together.") self._check_okay_to_chain() @@ -479,7 +479,7 @@ class AsyncCursor(Generic[_DocumentType]): :param batch_size: The size of each batch of results requested. """ if not isinstance(batch_size, int): - raise TypeError("batch_size must be an integer") + raise TypeError(f"batch_size must be an integer, not {type(batch_size)}") if batch_size < 0: raise ValueError("batch_size must be >= 0") self._check_okay_to_chain() @@ -499,7 +499,7 @@ class AsyncCursor(Generic[_DocumentType]): :param skip: the number of results to skip """ if not isinstance(skip, int): - raise TypeError("skip must be an integer") + raise TypeError(f"skip must be an integer, not {type(skip)}") if skip < 0: raise ValueError("skip must be >= 0") self._check_okay_to_chain() @@ -520,7 +520,7 @@ class AsyncCursor(Generic[_DocumentType]): :param max_time_ms: the time limit after which the operation is aborted """ if not isinstance(max_time_ms, int) and max_time_ms is not None: - raise TypeError("max_time_ms must be an integer or None") + raise TypeError(f"max_time_ms must be an integer or None, not {type(max_time_ms)}") self._check_okay_to_chain() self._max_time_ms = max_time_ms @@ -543,7 +543,9 @@ class AsyncCursor(Generic[_DocumentType]): .. versionadded:: 3.2 """ if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None: - raise TypeError("max_await_time_ms must be an integer or None") + raise TypeError( + f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}" + ) self._check_okay_to_chain() # Ignore max_await_time_ms if not tailable or await_data is False. @@ -679,7 +681,7 @@ class AsyncCursor(Generic[_DocumentType]): .. versionadded:: 2.7 """ if not isinstance(spec, (list, tuple)): - raise TypeError("spec must be an instance of list or tuple") + raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}") self._check_okay_to_chain() self._max = dict(spec) @@ -701,7 +703,7 @@ class AsyncCursor(Generic[_DocumentType]): .. versionadded:: 2.7 """ if not isinstance(spec, (list, tuple)): - raise TypeError("spec must be an instance of list or tuple") + raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}") self._check_okay_to_chain() self._min = dict(spec) diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index 98a0a6ff3..4aba9ab0e 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -122,7 +122,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): from pymongo.asynchronous.mongo_client import AsyncMongoClient if not isinstance(name, str): - raise TypeError("name must be an instance of str") + raise TypeError(f"name must be an instance of str, not {type(name)}") if not isinstance(client, AsyncMongoClient): # This is for compatibility with mocked and subclassed types, such as in Motor. @@ -1310,7 +1310,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): name = name.name if not isinstance(name, str): - raise TypeError("name_or_collection must be an instance of str") + raise TypeError(f"name_or_collection must be an instance of str, not {type(name)}") encrypted_fields = await self._get_encrypted_fields( {"encryptedFields": encrypted_fields}, name, @@ -1374,7 +1374,9 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): name = name.name if not isinstance(name, str): - raise TypeError("name_or_collection must be an instance of str or AsyncCollection") + raise TypeError( + f"name_or_collection must be an instance of str or AsyncCollection, not {type(name)}" + ) cmd = {"validate": name, "scandata": scandata, "full": full} if comment is not None: cmd["comment"] = comment diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 98ab68527..f777104cf 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -322,7 +322,9 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc] raw_doc = RawBSONDocument(data_key, _KEY_VAULT_OPTS) data_key_id = raw_doc.get("_id") if not isinstance(data_key_id, Binary) or data_key_id.subtype != UUID_SUBTYPE: - raise TypeError("data_key _id must be Binary with a UUID subtype") + raise TypeError( + f"data_key _id must be Binary with a UUID subtype, not {type(data_key_id)}" + ) assert self.key_vault_coll is not None await self.key_vault_coll.insert_one(raw_doc) @@ -644,7 +646,9 @@ class AsyncClientEncryption(Generic[_DocumentType]): ) if not isinstance(codec_options, CodecOptions): - raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") + raise TypeError( + f"codec_options must be an instance of bson.codec_options.CodecOptions, not {type(codec_options)}" + ) if not isinstance(key_vault_client, AsyncMongoClient): # This is for compatibility with mocked and subclassed types, such as in Motor. diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 1600e5062..cf7de19c2 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -750,7 +750,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): if port is None: port = self.PORT if not isinstance(port, int): - raise TypeError("port must be an instance of int") + raise TypeError(f"port must be an instance of int, not {type(port)}") # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -1971,7 +1971,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): The cursor is closed synchronously on the current thread. """ if not isinstance(cursor_id, int): - raise TypeError("cursor_id must be an instance of int") + raise TypeError(f"cursor_id must be an instance of int, not {type(cursor_id)}") try: if conn_mgr: @@ -2093,7 +2093,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): """If provided session is None, lend a temporary session.""" if session is not None: if not isinstance(session, client_session.AsyncClientSession): - raise ValueError("'session' argument must be an AsyncClientSession or None.") + raise ValueError( + f"'session' argument must be an AsyncClientSession or None, not {type(session)}" + ) # Don't call end_session. yield session return @@ -2247,7 +2249,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): name = name.name if not isinstance(name, str): - raise TypeError("name_or_database must be an instance of str or a AsyncDatabase") + raise TypeError( + f"name_or_database must be an instance of str or a AsyncDatabase, not {type(name)}" + ) async with await self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn: await self[name]._command( diff --git a/pymongo/auth_shared.py b/pymongo/auth_shared.py index 9534bd74a..410521d73 100644 --- a/pymongo/auth_shared.py +++ b/pymongo/auth_shared.py @@ -107,7 +107,7 @@ def _build_credentials_tuple( ) -> MongoCredential: """Build and return a mechanism specific credentials tuple.""" if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: - raise ConfigurationError(f"{mech} requires a username.") + raise ConfigurationError(f"{mech} requires a username") if mech == "GSSAPI": if source is not None and source != "$external": raise ValueError("authentication source must be $external or None for GSSAPI") @@ -219,7 +219,7 @@ def _build_credentials_tuple( else: source_database = source or database or "admin" if passwd is None: - raise ConfigurationError("A password is required.") + raise ConfigurationError("A password is required") return MongoCredential(mech, source_database, user, passwd, None, _Cache()) diff --git a/pymongo/collation.py b/pymongo/collation.py index 9adcb2e40..fc84b937f 100644 --- a/pymongo/collation.py +++ b/pymongo/collation.py @@ -223,4 +223,4 @@ def validate_collation_or_none( return value.document if isinstance(value, dict): return value - raise TypeError("collation must be a dict, an instance of collation.Collation, or None.") + raise TypeError("collation must be a dict, an instance of collation.Collation, or None") diff --git a/pymongo/common.py b/pymongo/common.py index b442da6a3..4be7a3122 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -202,7 +202,7 @@ def validate_integer(option: str, value: Any) -> int: return int(value) except ValueError: raise ValueError(f"The value of {option} must be an integer") from None - raise TypeError(f"Wrong type for {option}, value must be an integer") + raise TypeError(f"Wrong type for {option}, value must be an integer, not {type(value)}") def validate_positive_integer(option: str, value: Any) -> int: @@ -250,7 +250,7 @@ def validate_string(option: str, value: Any) -> str: """Validates that 'value' is an instance of `str`.""" if isinstance(value, str): return value - raise TypeError(f"Wrong type for {option}, value must be an instance of str") + raise TypeError(f"Wrong type for {option}, value must be an instance of str, not {type(value)}") def validate_string_or_none(option: str, value: Any) -> Optional[str]: @@ -269,7 +269,9 @@ def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]: return int(value) except ValueError: return value - raise TypeError(f"Wrong type for {option}, value must be an integer or a string") + raise TypeError( + f"Wrong type for {option}, value must be an integer or a string, not {type(value)}" + ) def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]: @@ -282,7 +284,9 @@ def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[in except ValueError: return value return validate_non_negative_integer(option, val) - raise TypeError(f"Wrong type for {option}, value must be an non negative integer or a string") + raise TypeError( + f"Wrong type for {option}, value must be an non negative integer or a string, not {type(value)}" + ) def validate_positive_float(option: str, value: Any) -> float: @@ -365,7 +369,7 @@ def validate_max_staleness(option: str, value: Any) -> int: def validate_read_preference(dummy: Any, value: Any) -> _ServerMode: """Validate a read preference.""" if not isinstance(value, _ServerMode): - raise TypeError(f"{value!r} is not a read preference.") + raise TypeError(f"{value!r} is not a read preference") return value @@ -441,7 +445,9 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni props: dict[str, Any] = {} if not isinstance(value, str): if not isinstance(value, dict): - raise ValueError("Auth mechanism properties must be given as a string or a dictionary") + raise ValueError( + f"Auth mechanism properties must be given as a string or a dictionary, not {type(value)}" + ) for key, value in value.items(): # noqa: B020 if isinstance(value, str): props[key] = value @@ -453,7 +459,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni from pymongo.auth_oidc_shared import OIDCCallback if not isinstance(value, OIDCCallback): - raise ValueError("callback must be an OIDCCallback object") + raise ValueError(f"callback must be an OIDCCallback object, not {type(value)}") props[key] = value else: raise ValueError(f"Invalid type for auth mechanism property {key}, {type(value)}") @@ -476,7 +482,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni raise ValueError( f"{key} is not a supported auth " "mechanism property. Must be one of " - f"{tuple(_MECHANISM_PROPS)}." + f"{tuple(_MECHANISM_PROPS)}" ) if key == "CANONICALIZE_HOST_NAME": @@ -520,7 +526,7 @@ def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]: def validate_list(option: str, value: Any) -> list: """Validates that 'value' is a list.""" if not isinstance(value, list): - raise TypeError(f"{option} must be a list") + raise TypeError(f"{option} must be a list, not {type(value)}") return value @@ -587,7 +593,7 @@ def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]: if value is None: return value if not isinstance(value, ServerApi): - raise TypeError(f"{option} must be an instance of ServerApi") + raise TypeError(f"{option} must be an instance of ServerApi, not {type(value)}") return value @@ -596,7 +602,7 @@ def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]: if value is None: return value if not callable(value): - raise ValueError(f"{option} must be a callable") + raise ValueError(f"{option} must be a callable, not {type(value)}") return value @@ -651,7 +657,7 @@ def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[A from pymongo.encryption_options import AutoEncryptionOpts if not isinstance(value, AutoEncryptionOpts): - raise TypeError(f"{option} must be an instance of AutoEncryptionOpts") + raise TypeError(f"{option} must be an instance of AutoEncryptionOpts, not {type(value)}") return value @@ -668,7 +674,9 @@ def validate_datetime_conversion(option: Any, value: Any) -> Optional[DatetimeCo elif isinstance(value, int): return DatetimeConversion(value) - raise TypeError(f"{option} must be a str or int representing DatetimeConversion") + raise TypeError( + f"{option} must be a str or int representing DatetimeConversion, not {type(value)}" + ) def validate_server_monitoring_mode(option: str, value: str) -> str: @@ -928,12 +936,14 @@ class BaseObject: if not isinstance(write_concern, WriteConcern): raise TypeError( - "write_concern must be an instance of pymongo.write_concern.WriteConcern" + f"write_concern must be an instance of pymongo.write_concern.WriteConcern, not {type(write_concern)}" ) self._write_concern = write_concern if not isinstance(read_concern, ReadConcern): - raise TypeError("read_concern must be an instance of pymongo.read_concern.ReadConcern") + raise TypeError( + f"read_concern must be an instance of pymongo.read_concern.ReadConcern, not {type(read_concern)}" + ) self._read_concern = read_concern @property diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index f49b56cc9..748645173 100644 --- a/pymongo/compression_support.py +++ b/pymongo/compression_support.py @@ -91,7 +91,7 @@ def validate_zlib_compression_level(option: str, value: Any) -> int: try: level = int(value) except Exception: - raise TypeError(f"{option} must be an integer, not {value!r}.") from None + raise TypeError(f"{option} must be an integer, not {value!r}") from None if level < -1 or level > 9: raise ValueError("%s must be between -1 and 9, not %d." % (option, level)) return level diff --git a/pymongo/driver_info.py b/pymongo/driver_info.py index 5ca3f952c..724a6f20d 100644 --- a/pymongo/driver_info.py +++ b/pymongo/driver_info.py @@ -39,7 +39,7 @@ class DriverInfo(namedtuple("DriverInfo", ["name", "version", "platform"])): for key, value in self._asdict().items(): if value is not None and not isinstance(value, str): raise TypeError( - f"Wrong type for DriverInfo {key} option, value must be an instance of str" + f"Wrong type for DriverInfo {key} option, value must be an instance of str, not {type(value)}" ) return self diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index ee749e7ac..26dfbf5f0 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -225,7 +225,9 @@ class AutoEncryptionOpts: mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"] self._mongocryptd_spawn_args = mongocryptd_spawn_args if not isinstance(self._mongocryptd_spawn_args, list): - raise TypeError("mongocryptd_spawn_args must be a list") + raise TypeError( + f"mongocryptd_spawn_args must be a list, not {type(self._mongocryptd_spawn_args)}" + ) if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args): self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60") # Maps KMS provider name to a SSLContext. diff --git a/pymongo/helpers_shared.py b/pymongo/helpers_shared.py index 83ea2ddf7..c6b820c1c 100644 --- a/pymongo/helpers_shared.py +++ b/pymongo/helpers_shared.py @@ -122,7 +122,7 @@ def _index_list( """ if direction is not None: if not isinstance(key_or_list, str): - raise TypeError("Expected a string and a direction") + raise TypeError(f"Expected a string and a direction, not {type(key_or_list)}") return [(key_or_list, direction)] else: if isinstance(key_or_list, str): @@ -132,7 +132,9 @@ def _index_list( elif isinstance(key_or_list, abc.Mapping): return list(key_or_list.items()) elif not isinstance(key_or_list, (list, tuple)): - raise TypeError("if no direction is specified, key_or_list must be an instance of list") + raise TypeError( + f"if no direction is specified, key_or_list must be an instance of list, not {type(key_or_list)}" + ) values: list[tuple[str, int]] = [] for item in key_or_list: if isinstance(item, str): @@ -172,11 +174,12 @@ def _index_document(index_list: _IndexList) -> dict[str, Any]: def _validate_index_key_pair(key: Any, value: Any) -> None: if not isinstance(key, str): - raise TypeError("first item in each key pair must be an instance of str") + raise TypeError(f"first item in each key pair must be an instance of str, not {type(key)}") if not isinstance(value, (str, int, abc.Mapping)): raise TypeError( "second item in each key pair must be 1, -1, " "'2d', or another valid MongoDB index specifier." + f", not {type(value)}" ) diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index 96f88597d..38d6e3a22 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -472,14 +472,15 @@ def _validate_event_listeners( ) -> Sequence[_EventListeners]: """Validate event listeners""" if not isinstance(listeners, abc.Sequence): - raise TypeError(f"{option} must be a list or tuple") + raise TypeError(f"{option} must be a list or tuple, not {type(listeners)}") for listener in listeners: if not isinstance(listener, _EventListener): raise TypeError( f"Listeners for {option} must be either a " "CommandListener, ServerHeartbeatListener, " "ServerListener, TopologyListener, or " - "ConnectionPoolListener." + "ConnectionPoolListener," + f"not {type(listener)}" ) return listeners @@ -496,7 +497,8 @@ def register(listener: _EventListener) -> None: f"Listeners for {listener} must be either a " "CommandListener, ServerHeartbeatListener, " "ServerListener, TopologyListener, or " - "ConnectionPoolListener." + "ConnectionPoolListener," + f"not {type(listener)}" ) if isinstance(listener, CommandListener): _LISTENERS.command_listeners.append(listener) diff --git a/pymongo/read_concern.py b/pymongo/read_concern.py index fa2f4a318..17f3a46ed 100644 --- a/pymongo/read_concern.py +++ b/pymongo/read_concern.py @@ -38,7 +38,7 @@ class ReadConcern: if level is None or isinstance(level, str): self.__level = level else: - raise TypeError("level must be a string or None.") + raise TypeError(f"level must be a string or None, not {type(level)}") @property def level(self) -> Optional[str]: diff --git a/pymongo/ssl_support.py b/pymongo/ssl_support.py index 580d71f9b..0faf21ba8 100644 --- a/pymongo/ssl_support.py +++ b/pymongo/ssl_support.py @@ -115,4 +115,4 @@ else: def get_ssl_context(*dummy): # type: ignore """No ssl module, raise ConfigurationError.""" - raise ConfigurationError("The ssl module is not available.") + raise ConfigurationError("The ssl module is not available") diff --git a/pymongo/synchronous/auth.py b/pymongo/synchronous/auth.py index 56860eff3..6041ebdbe 100644 --- a/pymongo/synchronous/auth.py +++ b/pymongo/synchronous/auth.py @@ -158,7 +158,7 @@ def _password_digest(username: str, password: str) -> str: if len(password) == 0: raise ValueError("password can't be empty") if not isinstance(username, str): - raise TypeError("username must be an instance of str") + raise TypeError(f"username must be an instance of str, not {type(username)}") md5hash = hashlib.md5() # noqa: S324 data = f"{username}:mongo:{password}" diff --git a/pymongo/synchronous/auth_oidc.py b/pymongo/synchronous/auth_oidc.py index 5a8967d96..c5efdd5fc 100644 --- a/pymongo/synchronous/auth_oidc.py +++ b/pymongo/synchronous/auth_oidc.py @@ -213,7 +213,9 @@ class _OIDCAuthenticator: ) resp = cb.fetch(context) if not isinstance(resp, OIDCCallbackResult): - raise ValueError("Callback result must be of type OIDCCallbackResult") + raise ValueError( + f"Callback result must be of type OIDCCallbackResult, not {type(resp)}" + ) self.refresh_token = resp.refresh_token self.access_token = resp.access_token self.token_gen_id += 1 diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index f1d680fc0..298dd7b35 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -309,7 +309,9 @@ class TransactionOptions: ) if max_commit_time_ms is not None: if not isinstance(max_commit_time_ms, int): - raise TypeError("max_commit_time_ms must be an integer or None") + raise TypeError( + f"max_commit_time_ms must be an integer or None, not {type(max_commit_time_ms)}" + ) @property def read_concern(self) -> Optional[ReadConcern]: @@ -897,7 +899,9 @@ class ClientSession: another `ClientSession` instance. """ if not isinstance(cluster_time, _Mapping): - raise TypeError("cluster_time must be a subclass of collections.Mapping") + raise TypeError( + f"cluster_time must be a subclass of collections.Mapping, not {type(cluster_time)}" + ) if not isinstance(cluster_time.get("clusterTime"), Timestamp): raise ValueError("Invalid cluster_time") self._advance_cluster_time(cluster_time) @@ -918,7 +922,9 @@ class ClientSession: another `ClientSession` instance. """ if not isinstance(operation_time, Timestamp): - raise TypeError("operation_time must be an instance of bson.timestamp.Timestamp") + raise TypeError( + f"operation_time must be an instance of bson.timestamp.Timestamp, not {type(operation_time)}" + ) self._advance_operation_time(operation_time) def _process_response(self, reply: Mapping[str, Any]) -> None: diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 6edfddc9a..b956ac58a 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -231,7 +231,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): read_concern or database.read_concern, ) if not isinstance(name, str): - raise TypeError("name must be an instance of str") + raise TypeError(f"name must be an instance of str, not {type(name)}") from pymongo.synchronous.database import Database if not isinstance(database, Database): @@ -2472,7 +2472,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): name = helpers_shared._gen_index_name(index_or_name) if not isinstance(name, str): - raise TypeError("index_or_name must be an instance of str or list") + raise TypeError(f"index_or_name must be an instance of str or list, not {type(name)}") cmd = {"dropIndexes": self._name, "index": name} cmd.update(kwargs) @@ -3071,7 +3071,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): """ if not isinstance(new_name, str): - raise TypeError("new_name must be an instance of str") + raise TypeError(f"new_name must be an instance of str, not {type(new_name)}") if not new_name or ".." in new_name: raise InvalidName("collection names cannot be empty") @@ -3141,7 +3141,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): """ if not isinstance(key, str): - raise TypeError("key must be an instance of str") + raise TypeError(f"key must be an instance of str, not {type(key)}") cmd = {"distinct": self._name, "key": key} if filter is not None: if "query" in kwargs: @@ -3189,7 +3189,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): common.validate_is_mapping("filter", filter) if not isinstance(return_document, bool): raise ValueError( - "return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER" + f"return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER, not {type(return_document)}" ) collation = validate_collation_or_none(kwargs.pop("collation", None)) cmd = {"findAndModify": self._name, "query": filter, "new": return_document} diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py index 3a4372856..e23519d74 100644 --- a/pymongo/synchronous/command_cursor.py +++ b/pymongo/synchronous/command_cursor.py @@ -94,7 +94,9 @@ class CommandCursor(Generic[_DocumentType]): self.batch_size(batch_size) if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None: - raise TypeError("max_await_time_ms must be an integer or None") + raise TypeError( + f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}" + ) def __del__(self) -> None: self._die_no_lock() @@ -115,7 +117,7 @@ class CommandCursor(Generic[_DocumentType]): :param batch_size: The size of each batch of results requested. """ if not isinstance(batch_size, int): - raise TypeError("batch_size must be an integer") + raise TypeError(f"batch_size must be an integer, not {type(batch_size)}") if batch_size < 0: raise ValueError("batch_size must be >= 0") diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index b35098a32..cda093ee1 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -146,9 +146,9 @@ class Cursor(Generic[_DocumentType]): spec: Mapping[str, Any] = filter or {} validate_is_mapping("filter", spec) if not isinstance(skip, int): - raise TypeError("skip must be an instance of int") + raise TypeError(f"skip must be an instance of int, not {type(skip)}") if not isinstance(limit, int): - raise TypeError("limit must be an instance of int") + raise TypeError(f"limit must be an instance of int, not {type(limit)}") validate_boolean("no_cursor_timeout", no_cursor_timeout) if no_cursor_timeout and not self._explicit_session: warnings.warn( @@ -171,7 +171,7 @@ class Cursor(Generic[_DocumentType]): validate_boolean("allow_partial_results", allow_partial_results) validate_boolean("oplog_replay", oplog_replay) if not isinstance(batch_size, int): - raise TypeError("batch_size must be an integer") + raise TypeError(f"batch_size must be an integer, not {type(batch_size)}") if batch_size < 0: raise ValueError("batch_size must be >= 0") # Only set if allow_disk_use is provided by the user, else None. @@ -388,7 +388,7 @@ class Cursor(Generic[_DocumentType]): cursor.add_option(2) """ if not isinstance(mask, int): - raise TypeError("mask must be an int") + raise TypeError(f"mask must be an int, not {type(mask)}") self._check_okay_to_chain() if mask & _QUERY_OPTIONS["exhaust"]: @@ -408,7 +408,7 @@ class Cursor(Generic[_DocumentType]): cursor.remove_option(2) """ if not isinstance(mask, int): - raise TypeError("mask must be an int") + raise TypeError(f"mask must be an int, not {type(mask)}") self._check_okay_to_chain() if mask & _QUERY_OPTIONS["exhaust"]: @@ -432,7 +432,7 @@ class Cursor(Generic[_DocumentType]): .. versionadded:: 3.11 """ if not isinstance(allow_disk_use, bool): - raise TypeError("allow_disk_use must be a bool") + raise TypeError(f"allow_disk_use must be a bool, not {type(allow_disk_use)}") self._check_okay_to_chain() self._allow_disk_use = allow_disk_use @@ -451,7 +451,7 @@ class Cursor(Generic[_DocumentType]): .. seealso:: The MongoDB documentation on `limit `_. """ if not isinstance(limit, int): - raise TypeError("limit must be an integer") + raise TypeError(f"limit must be an integer, not {type(limit)}") if self._exhaust: raise InvalidOperation("Can't use limit and exhaust together.") self._check_okay_to_chain() @@ -479,7 +479,7 @@ class Cursor(Generic[_DocumentType]): :param batch_size: The size of each batch of results requested. """ if not isinstance(batch_size, int): - raise TypeError("batch_size must be an integer") + raise TypeError(f"batch_size must be an integer, not {type(batch_size)}") if batch_size < 0: raise ValueError("batch_size must be >= 0") self._check_okay_to_chain() @@ -499,7 +499,7 @@ class Cursor(Generic[_DocumentType]): :param skip: the number of results to skip """ if not isinstance(skip, int): - raise TypeError("skip must be an integer") + raise TypeError(f"skip must be an integer, not {type(skip)}") if skip < 0: raise ValueError("skip must be >= 0") self._check_okay_to_chain() @@ -520,7 +520,7 @@ class Cursor(Generic[_DocumentType]): :param max_time_ms: the time limit after which the operation is aborted """ if not isinstance(max_time_ms, int) and max_time_ms is not None: - raise TypeError("max_time_ms must be an integer or None") + raise TypeError(f"max_time_ms must be an integer or None, not {type(max_time_ms)}") self._check_okay_to_chain() self._max_time_ms = max_time_ms @@ -543,7 +543,9 @@ class Cursor(Generic[_DocumentType]): .. versionadded:: 3.2 """ if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None: - raise TypeError("max_await_time_ms must be an integer or None") + raise TypeError( + f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}" + ) self._check_okay_to_chain() # Ignore max_await_time_ms if not tailable or await_data is False. @@ -677,7 +679,7 @@ class Cursor(Generic[_DocumentType]): .. versionadded:: 2.7 """ if not isinstance(spec, (list, tuple)): - raise TypeError("spec must be an instance of list or tuple") + raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}") self._check_okay_to_chain() self._max = dict(spec) @@ -699,7 +701,7 @@ class Cursor(Generic[_DocumentType]): .. versionadded:: 2.7 """ if not isinstance(spec, (list, tuple)): - raise TypeError("spec must be an instance of list or tuple") + raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}") self._check_okay_to_chain() self._min = dict(spec) diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index a0bef5534..0dc03cb74 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -122,7 +122,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): from pymongo.synchronous.mongo_client import MongoClient if not isinstance(name, str): - raise TypeError("name must be an instance of str") + raise TypeError(f"name must be an instance of str, not {type(name)}") if not isinstance(client, MongoClient): # This is for compatibility with mocked and subclassed types, such as in Motor. @@ -1303,7 +1303,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): name = name.name if not isinstance(name, str): - raise TypeError("name_or_collection must be an instance of str") + raise TypeError(f"name_or_collection must be an instance of str, not {type(name)}") encrypted_fields = self._get_encrypted_fields( {"encryptedFields": encrypted_fields}, name, @@ -1367,7 +1367,9 @@ class Database(common.BaseObject, Generic[_DocumentType]): name = name.name if not isinstance(name, str): - raise TypeError("name_or_collection must be an instance of str or Collection") + raise TypeError( + f"name_or_collection must be an instance of str or Collection, not {type(name)}" + ) cmd = {"validate": name, "scandata": scandata, "full": full} if comment is not None: cmd["comment"] = comment diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index d41169861..59f38e191 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -320,7 +320,9 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc] raw_doc = RawBSONDocument(data_key, _KEY_VAULT_OPTS) data_key_id = raw_doc.get("_id") if not isinstance(data_key_id, Binary) or data_key_id.subtype != UUID_SUBTYPE: - raise TypeError("data_key _id must be Binary with a UUID subtype") + raise TypeError( + f"data_key _id must be Binary with a UUID subtype, not {type(data_key_id)}" + ) assert self.key_vault_coll is not None self.key_vault_coll.insert_one(raw_doc) @@ -642,7 +644,9 @@ class ClientEncryption(Generic[_DocumentType]): ) if not isinstance(codec_options, CodecOptions): - raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") + raise TypeError( + f"codec_options must be an instance of bson.codec_options.CodecOptions, not {type(codec_options)}" + ) if not isinstance(key_vault_client, MongoClient): # This is for compatibility with mocked and subclassed types, such as in Motor. diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index a694a58c1..706623c21 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -748,7 +748,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): if port is None: port = self.PORT if not isinstance(port, int): - raise TypeError("port must be an instance of int") + raise TypeError(f"port must be an instance of int, not {type(port)}") # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -1965,7 +1965,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): The cursor is closed synchronously on the current thread. """ if not isinstance(cursor_id, int): - raise TypeError("cursor_id must be an instance of int") + raise TypeError(f"cursor_id must be an instance of int, not {type(cursor_id)}") try: if conn_mgr: @@ -2087,7 +2087,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): """If provided session is None, lend a temporary session.""" if session is not None: if not isinstance(session, client_session.ClientSession): - raise ValueError("'session' argument must be a ClientSession or None.") + raise ValueError( + f"'session' argument must be a ClientSession or None, not {type(session)}" + ) # Don't call end_session. yield session return @@ -2235,7 +2237,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): name = name.name if not isinstance(name, str): - raise TypeError("name_or_database must be an instance of str or a Database") + raise TypeError( + f"name_or_database must be an instance of str or a Database, not {type(name)}" + ) with self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn: self[name]._command( diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index 7018dad7d..8f56ae409 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -91,7 +91,7 @@ def parse_userinfo(userinfo: str) -> tuple[str, str]: user, _, passwd = userinfo.partition(":") # No password is expected with GSSAPI authentication. if not user: - raise InvalidURI("The empty string is not valid username.") + raise InvalidURI("The empty string is not valid username") return unquote_plus(user), unquote_plus(passwd) @@ -347,7 +347,7 @@ def split_options( semi_idx = opts.find(";") try: if and_idx >= 0 and semi_idx >= 0: - raise InvalidURI("Can not mix '&' and ';' for option separators.") + raise InvalidURI("Can not mix '&' and ';' for option separators") elif and_idx >= 0: options = _parse_options(opts, "&") elif semi_idx >= 0: @@ -357,7 +357,7 @@ def split_options( else: raise ValueError except ValueError: - raise InvalidURI("MongoDB URI options are key=value pairs.") from None + raise InvalidURI("MongoDB URI options are key=value pairs") from None options = _handle_security_options(options) @@ -389,7 +389,7 @@ def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[ nodes = [] for entity in hosts.split(","): if not entity: - raise ConfigurationError("Empty host (or extra comma in host list).") + raise ConfigurationError("Empty host (or extra comma in host list)") port = default_port # Unix socket entities don't have ports if entity.endswith(".sock"): @@ -502,7 +502,7 @@ def parse_uri( raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") if not scheme_free: - raise InvalidURI("Must provide at least one hostname or IP.") + raise InvalidURI("Must provide at least one hostname or IP") user = None passwd = None diff --git a/pymongo/write_concern.py b/pymongo/write_concern.py index 67c954989..21faeebed 100644 --- a/pymongo/write_concern.py +++ b/pymongo/write_concern.py @@ -74,7 +74,7 @@ class WriteConcern: if wtimeout is not None: if not isinstance(wtimeout, int): - raise TypeError("wtimeout must be an integer") + raise TypeError(f"wtimeout must be an integer, not {type(wtimeout)}") if wtimeout < 0: raise ValueError("wtimeout cannot be less than 0") self.__document["wtimeout"] = wtimeout @@ -98,7 +98,7 @@ class WriteConcern: raise ValueError("w cannot be less than 0") self.__acknowledged = w > 0 elif not isinstance(w, str): - raise TypeError("w must be an integer or string") + raise TypeError(f"w must be an integer or string, not {type(w)}") self.__document["w"] = w self.__server_default = not self.__document From 3e783f5489c5ac899be5e1bbbe0d26f1fe4f1f73 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Fri, 31 Jan 2025 12:13:21 -0800 Subject: [PATCH 10/28] PYTHON-5088 Convert test.test_max_staleness to async (#2105) --- test/asynchronous/test_max_staleness.py | 149 ++++++++++++++++++++++++ test/test_max_staleness.py | 11 +- tools/synchro.py | 1 + 3 files changed, 159 insertions(+), 2 deletions(-) create mode 100644 test/asynchronous/test_max_staleness.py diff --git a/test/asynchronous/test_max_staleness.py b/test/asynchronous/test_max_staleness.py new file mode 100644 index 000000000..7dbf17021 --- /dev/null +++ b/test/asynchronous/test_max_staleness.py @@ -0,0 +1,149 @@ +# Copyright 2016 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test maxStalenessSeconds support.""" +from __future__ import annotations + +import asyncio +import os +import sys +import time +import warnings +from pathlib import Path + +from pymongo import AsyncMongoClient +from pymongo.operations import _Op + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncPyMongoTestCase, async_client_context, unittest +from test.utils_selection_tests import create_selection_tests + +from pymongo.errors import ConfigurationError +from pymongo.server_selectors import writable_server_selector + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "max_staleness") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "max_staleness") + + +class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore + pass + + +class TestMaxStaleness(AsyncPyMongoTestCase): + async def test_max_staleness(self): + client = self.simple_client() + self.assertEqual(-1, client.read_preference.max_staleness) + + client = self.simple_client("mongodb://a/?readPreference=secondary") + self.assertEqual(-1, client.read_preference.max_staleness) + + # These tests are specified in max-staleness-tests.rst. + with self.assertRaises(ConfigurationError): + # Default read pref "primary" can't be used with max staleness. + self.simple_client("mongodb://a/?maxStalenessSeconds=120") + + with self.assertRaises(ConfigurationError): + # Read pref "primary" can't be used with max staleness. + self.simple_client("mongodb://a/?readPreference=primary&maxStalenessSeconds=120") + + client = self.simple_client("mongodb://host/?maxStalenessSeconds=-1") + self.assertEqual(-1, client.read_preference.max_staleness) + + client = self.simple_client("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1") + self.assertEqual(-1, client.read_preference.max_staleness) + + client = self.simple_client( + "mongodb://host/?readPreference=secondary&maxStalenessSeconds=120" + ) + self.assertEqual(120, client.read_preference.max_staleness) + + client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1") + self.assertEqual(1, client.read_preference.max_staleness) + + client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1") + self.assertEqual(-1, client.read_preference.max_staleness) + + client = self.simple_client(maxStalenessSeconds=-1, readPreference="nearest") + self.assertEqual(-1, client.read_preference.max_staleness) + + with self.assertRaises(TypeError): + # Prohibit None. + self.simple_client(maxStalenessSeconds=None, readPreference="nearest") + + async def test_max_staleness_float(self): + with self.assertRaises(TypeError) as ctx: + await self.async_rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest") + + self.assertIn("must be an integer", str(ctx.exception)) + + with warnings.catch_warnings(record=True) as ctx: + warnings.simplefilter("always") + client = self.simple_client( + "mongodb://host/?maxStalenessSeconds=1.5&readPreference=nearest" + ) + + # Option was ignored. + self.assertEqual(-1, client.read_preference.max_staleness) + self.assertIn("must be an integer", str(ctx[0])) + + async def test_max_staleness_zero(self): + # Zero is too small. + with self.assertRaises(ValueError) as ctx: + await self.async_rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest") + + self.assertIn("must be a positive integer", str(ctx.exception)) + + with warnings.catch_warnings(record=True) as ctx: + warnings.simplefilter("always") + client = self.simple_client( + "mongodb://host/?maxStalenessSeconds=0&readPreference=nearest" + ) + + # Option was ignored. + self.assertEqual(-1, client.read_preference.max_staleness) + self.assertIn("must be a positive integer", str(ctx[0])) + + @async_client_context.require_replica_set + async def test_last_write_date(self): + # From max-staleness-tests.rst, "Parse lastWriteDate". + client = await self.async_rs_or_single_client(heartbeatFrequencyMS=500) + await client.pymongo_test.test.insert_one({}) + # Wait for the server description to be updated. + await asyncio.sleep(1) + server = await client._topology.select_server(writable_server_selector, _Op.TEST) + first = server.description.last_write_date + self.assertTrue(first) + # The first last_write_date may correspond to a internal server write, + # sleep so that the next write does not occur within the same second. + await asyncio.sleep(1) + await client.pymongo_test.test.insert_one({}) + # Wait for the server description to be updated. + await asyncio.sleep(1) + server = await client._topology.select_server(writable_server_selector, _Op.TEST) + second = server.description.last_write_date + assert first is not None + + assert second is not None + self.assertGreater(second, first) + self.assertLess(second, first + 10) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index 32d09ada9..56e047fd4 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -15,10 +15,12 @@ """Test maxStalenessSeconds support.""" from __future__ import annotations +import asyncio import os import sys import time import warnings +from pathlib import Path from pymongo import MongoClient from pymongo.operations import _Op @@ -31,11 +33,16 @@ from test.utils_selection_tests import create_selection_tests from pymongo.errors import ConfigurationError from pymongo.server_selectors import writable_server_selector +_IS_SYNC = True + # Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "max_staleness") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "max_staleness") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "max_staleness") -class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore +class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore pass diff --git a/tools/synchro.py b/tools/synchro.py index ef82db756..2a2969679 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -215,6 +215,7 @@ converted_tests = [ "test_json_util_integration.py", "test_gridfs_spec.py", "test_logger.py", + "test_max_staleness.py", "test_monitoring.py", "test_on_demand_csfle.py", "test_raw_bson.py", From 3b5788906ddeb326c0407f7b490aadde0f88c2ee Mon Sep 17 00:00:00 2001 From: Jib Date: Fri, 31 Jan 2025 15:16:17 -0500 Subject: [PATCH 11/28] Update ReadTheDocs to include django-mongodb-backend (#2084) --- doc/tools.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/doc/tools.rst b/doc/tools.rst index 6dd0df8a4..7ec3ddb44 100644 --- a/doc/tools.rst +++ b/doc/tools.rst @@ -67,6 +67,14 @@ uMongo mongomock. The source `is available on GitHub `_ +Django MongoDB Backend + `Django MongoDB Backend `_ is a + database backend library specifically made for Django. The integration takes + advantage of MongoDB's unique document model capabilities, which align + naturally with Django's philosophy of simplified data modeling and + reduced development complexity. The source is available + `on GitHub `_. + No longer maintained """""""""""""""""""" From acc437af57aca88fc2b333ca6ca5dead28819fe6 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 31 Jan 2025 15:50:46 -0500 Subject: [PATCH 12/28] PYTHON-5097 - Convert test.test_retryable_writes_unified to async (#2113) --- .../test_retryable_writes_unified.py | 39 +++++++++++++++++++ test/test_retryable_writes_unified.py | 8 +++- tools/synchro.py | 1 + 3 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 test/asynchronous/test_retryable_writes_unified.py diff --git a/test/asynchronous/test_retryable_writes_unified.py b/test/asynchronous/test_retryable_writes_unified.py new file mode 100644 index 000000000..bb493e601 --- /dev/null +++ b/test/asynchronous/test_retryable_writes_unified.py @@ -0,0 +1,39 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Retryable Writes unified spec tests.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_writes/unified") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_writes/unified") + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_retryable_writes_unified.py b/test/test_retryable_writes_unified.py index da16166ec..036c410e2 100644 --- a/test/test_retryable_writes_unified.py +++ b/test/test_retryable_writes_unified.py @@ -17,14 +17,20 @@ from __future__ import annotations import os import sys +from pathlib import Path sys.path[0:0] = [""] from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "retryable_writes", "unified") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_writes/unified") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_writes/unified") # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/tools/synchro.py b/tools/synchro.py index 2a2969679..e0b208377 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -222,6 +222,7 @@ converted_tests = [ "test_read_concern.py", "test_retryable_reads.py", "test_retryable_writes.py", + "test_retryable_writes_unified.py", "test_session.py", "test_transactions.py", "unified_format.py", From 6b141d1f5bb2fb9a301932ce45870b21ddd8ea21 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 31 Jan 2025 15:51:00 -0500 Subject: [PATCH 13/28] PYTHON-5096 - Convert test.test_retryable_reads_unified to async (#2112) --- .../test_retryable_reads_unified.py | 46 +++++++++++++++++++ test/test_retryable_reads_unified.py | 8 +++- tools/synchro.py | 1 + 3 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 test/asynchronous/test_retryable_reads_unified.py diff --git a/test/asynchronous/test_retryable_reads_unified.py b/test/asynchronous/test_retryable_reads_unified.py new file mode 100644 index 000000000..e62d60681 --- /dev/null +++ b/test/asynchronous/test_retryable_reads_unified.py @@ -0,0 +1,46 @@ +# Copyright 2022-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Retryable Reads unified spec tests.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_reads/unified") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_reads/unified") + +# Generate unified tests. +# PyMongo does not support MapReduce, ListDatabaseObjects or ListCollectionObjects. +globals().update( + generate_test_classes( + TEST_PATH, + module=__name__, + expected_failures=["ListDatabaseObjects .*", "ListCollectionObjects .*", "MapReduce .*"], + ) +) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_retryable_reads_unified.py b/test/test_retryable_reads_unified.py index 3f8740cf4..b1c6435c9 100644 --- a/test/test_retryable_reads_unified.py +++ b/test/test_retryable_reads_unified.py @@ -15,6 +15,7 @@ """Test the Retryable Reads unified spec tests.""" from __future__ import annotations +import os import sys from pathlib import Path @@ -23,8 +24,13 @@ sys.path[0:0] = [""] from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -TEST_PATH = Path(__file__).parent / "retryable_reads/unified" +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_reads/unified") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_reads/unified") # Generate unified tests. # PyMongo does not support MapReduce, ListDatabaseObjects or ListCollectionObjects. diff --git a/tools/synchro.py b/tools/synchro.py index e0b208377..eb44ef4ac 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -221,6 +221,7 @@ converted_tests = [ "test_raw_bson.py", "test_read_concern.py", "test_retryable_reads.py", + "test_retryable_reads_unified.py", "test_retryable_writes.py", "test_retryable_writes_unified.py", "test_session.py", From 702c86c02cbf989fc38df0ddc916260ccf43ac43 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 3 Feb 2025 08:52:54 -0500 Subject: [PATCH 14/28] PYTHON-5095 - Convert test_read_write_concern_spec to async (#2111) --- test/asynchronous/__init__.py | 8 +- .../test_read_write_concern_spec.py | 344 ++++++++++++++++++ test/test_read_write_concern_spec.py | 22 +- tools/synchro.py | 1 + 4 files changed, 364 insertions(+), 11 deletions(-) create mode 100644 test/asynchronous/test_read_write_concern_spec.py diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 76fae407d..a6ba29baa 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -1176,15 +1176,15 @@ class AsyncPyMongoTestCase(unittest.TestCase): async def disable_replication(self, client): """Disable replication on all secondaries.""" - for h, p in client.secondaries: + for h, p in await client.secondaries: secondary = await self.async_single_client(h, p) - secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn") + await secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn") async def enable_replication(self, client): """Enable replication on all secondaries.""" - for h, p in client.secondaries: + for h, p in await client.secondaries: secondary = await self.async_single_client(h, p) - secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off") + await secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off") class AsyncUnitTest(AsyncPyMongoTestCase): diff --git a/test/asynchronous/test_read_write_concern_spec.py b/test/asynchronous/test_read_write_concern_spec.py new file mode 100644 index 000000000..3fb13ba19 --- /dev/null +++ b/test/asynchronous/test_read_write_concern_spec.py @@ -0,0 +1,344 @@ +# Copyright 2018-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the read and write concern tests.""" +from __future__ import annotations + +import json +import os +import sys +import warnings +from pathlib import Path + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.unified_format import generate_test_classes +from test.utils import OvertCommandListener + +from pymongo import DESCENDING +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.errors import ( + BulkWriteError, + ConfigurationError, + WriteConcernError, + WriteError, + WTimeoutError, +) +from pymongo.operations import IndexModel, InsertOne +from pymongo.read_concern import ReadConcern +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern") + + +class TestReadWriteConcernSpec(AsyncIntegrationTest): + async def test_omit_default_read_write_concern(self): + listener = OvertCommandListener() + # Client with default readConcern and writeConcern + client = await self.async_rs_or_single_client(event_listeners=[listener]) + collection = client.pymongo_test.collection + # Prepare for tests of find() and aggregate(). + await collection.insert_many([{} for _ in range(10)]) + self.addAsyncCleanup(collection.drop) + self.addAsyncCleanup(client.pymongo_test.collection2.drop) + # Commands MUST NOT send the default read/write concern to the server. + + async def rename_and_drop(): + # Ensure collection exists. + await collection.insert_one({}) + await collection.rename("collection2") + await client.pymongo_test.collection2.drop() + + async def insert_command_default_write_concern(): + await collection.database.command( + "insert", "collection", documents=[{}], write_concern=WriteConcern() + ) + + async def aggregate_op(): + await (await collection.aggregate([])).to_list() + + ops = [ + ("aggregate", aggregate_op), + ("find", lambda: collection.find().to_list()), + ("insert_one", lambda: collection.insert_one({})), + ("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})), + ("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})), + ("delete_one", lambda: collection.delete_one({})), + ("delete_many", lambda: collection.delete_many({})), + ("bulk_write", lambda: collection.bulk_write([InsertOne({})])), + ("rename_and_drop", rename_and_drop), + ("command", insert_command_default_write_concern), + ] + + for name, f in ops: + listener.reset() + await f() + + self.assertGreaterEqual(len(listener.started_events), 1) + for _i, event in enumerate(listener.started_events): + self.assertNotIn( + "readConcern", + event.command, + f"{name} sent default readConcern with {event.command_name}", + ) + self.assertNotIn( + "writeConcern", + event.command, + f"{name} sent default writeConcern with {event.command_name}", + ) + + async def assertWriteOpsRaise(self, write_concern, expected_exception): + wc = write_concern.document + # Set socket timeout to avoid indefinite stalls + client = await self.async_rs_or_single_client( + w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000 + ) + db = client.get_database("pymongo_test") + coll = db.test + + async def insert_command(): + await coll.database.command( + "insert", + "new_collection", + documents=[{}], + writeConcern=write_concern.document, + parse_write_concern_error=True, + ) + + ops = [ + ("insert_one", lambda: coll.insert_one({})), + ("insert_many", lambda: coll.insert_many([{}, {}])), + ("update_one", lambda: coll.update_one({}, {"$set": {"x": 1}})), + ("update_many", lambda: coll.update_many({}, {"$set": {"x": 1}})), + ("delete_one", lambda: coll.delete_one({})), + ("delete_many", lambda: coll.delete_many({})), + ("bulk_write", lambda: coll.bulk_write([InsertOne({})])), + ("command", insert_command), + ("aggregate", lambda: coll.aggregate([{"$out": "out"}])), + # SERVER-46668 Delete all the documents in the collection to + # workaround a hang in createIndexes. + ("delete_many", lambda: coll.delete_many({})), + ("create_index", lambda: coll.create_index([("a", DESCENDING)])), + ("create_indexes", lambda: coll.create_indexes([IndexModel("b")])), + ("drop_index", lambda: coll.drop_index([("a", DESCENDING)])), + ("create", lambda: db.create_collection("new")), + ("rename", lambda: coll.rename("new")), + ("drop", lambda: db.new.drop()), + ] + # SERVER-47194: dropDatabase does not respect wtimeout in 3.6. + if async_client_context.version[:2] != (3, 6): + ops.append(("drop_database", lambda: client.drop_database(db))) + + for name, f in ops: + # Ensure insert_many and bulk_write still raise BulkWriteError. + if name in ("insert_many", "bulk_write"): + expected = BulkWriteError + else: + expected = expected_exception + with self.assertRaises(expected, msg=name) as cm: + await f() + if expected == BulkWriteError: + bulk_result = cm.exception.details + assert bulk_result is not None + wc_errors = bulk_result["writeConcernErrors"] + self.assertTrue(wc_errors) + + @async_client_context.require_replica_set + async def test_raise_write_concern_error(self): + self.addAsyncCleanup(async_client_context.client.drop_database, "pymongo_test") + assert async_client_context.w is not None + await self.assertWriteOpsRaise( + WriteConcern(w=async_client_context.w + 1, wtimeout=1), WriteConcernError + ) + + @async_client_context.require_secondaries_count(1) + @async_client_context.require_test_commands + async def test_raise_wtimeout(self): + self.addAsyncCleanup(async_client_context.client.drop_database, "pymongo_test") + self.addAsyncCleanup(self.enable_replication, async_client_context.client) + # Disable replication to guarantee a wtimeout error. + await self.disable_replication(async_client_context.client) + await self.assertWriteOpsRaise( + WriteConcern(w=async_client_context.w, wtimeout=1), WTimeoutError + ) + + @async_client_context.require_failCommand_fail_point + async def test_error_includes_errInfo(self): + expected_wce = { + "code": 100, + "codeName": "UnsatisfiableWriteConcern", + "errmsg": "Not enough data-bearing nodes", + "errInfo": {"writeConcern": {"w": 2, "wtimeout": 0, "provenance": "clientSupplied"}}, + } + cause_wce = { + "configureFailPoint": "failCommand", + "mode": {"times": 2}, + "data": {"failCommands": ["insert"], "writeConcernError": expected_wce}, + } + async with self.fail_point(cause_wce): + # Write concern error on insert includes errInfo. + with self.assertRaises(WriteConcernError) as ctx: + await self.db.test.insert_one({}) + self.assertEqual(ctx.exception.details, expected_wce) + + # Test bulk_write as well. + with self.assertRaises(BulkWriteError) as ctx: + await self.db.test.bulk_write([InsertOne({})]) + expected_details = { + "writeErrors": [], + "writeConcernErrors": [expected_wce], + "nInserted": 1, + "nUpserted": 0, + "nMatched": 0, + "nModified": 0, + "nRemoved": 0, + "upserted": [], + } + self.assertEqual(ctx.exception.details, expected_details) + + @async_client_context.require_version_min(4, 9) + async def test_write_error_details_exposes_errinfo(self): + listener = OvertCommandListener() + client = await self.async_rs_or_single_client(event_listeners=[listener]) + db = client.errinfotest + self.addAsyncCleanup(client.drop_database, "errinfotest") + validator = {"x": {"$type": "string"}} + await db.create_collection("test", validator=validator) + with self.assertRaises(WriteError) as ctx: + await db.test.insert_one({"x": 1}) + self.assertEqual(ctx.exception.code, 121) + self.assertIsNotNone(ctx.exception.details) + assert ctx.exception.details is not None + self.assertIsNotNone(ctx.exception.details.get("errInfo")) + for event in listener.succeeded_events: + if event.command_name == "insert": + self.assertEqual(event.reply["writeErrors"][0], ctx.exception.details) + break + else: + self.fail("Couldn't find insert event.") + + +def normalize_write_concern(concern): + result = {} + for key in concern: + if key.lower() == "wtimeoutms": + result["wtimeout"] = concern[key] + elif key == "journal": + result["j"] = concern[key] + else: + result[key] = concern[key] + return result + + +def create_connection_string_test(test_case): + def run_test(self): + uri = test_case["uri"] + valid = test_case["valid"] + warning = test_case["warning"] + + if not valid: + if warning is False: + self.assertRaises( + (ConfigurationError, ValueError), AsyncMongoClient, uri, connect=False + ) + else: + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + self.assertRaises(UserWarning, AsyncMongoClient, uri, connect=False) + else: + client = AsyncMongoClient(uri, connect=False) + if "writeConcern" in test_case: + document = client.write_concern.document + self.assertEqual(document, normalize_write_concern(test_case["writeConcern"])) + if "readConcern" in test_case: + document = client.read_concern.document + self.assertEqual(document, test_case["readConcern"]) + + return run_test + + +def create_document_test(test_case): + def run_test(self): + valid = test_case["valid"] + + if "writeConcern" in test_case: + normalized = normalize_write_concern(test_case["writeConcern"]) + if not valid: + self.assertRaises((ConfigurationError, ValueError), WriteConcern, **normalized) + else: + write_concern = WriteConcern(**normalized) + self.assertEqual(write_concern.document, test_case["writeConcernDocument"]) + self.assertEqual(write_concern.acknowledged, test_case["isAcknowledged"]) + self.assertEqual(write_concern.is_server_default, test_case["isServerDefault"]) + if "readConcern" in test_case: + # Any string for 'level' is equally valid + read_concern = ReadConcern(**test_case["readConcern"]) + self.assertEqual(read_concern.document, test_case["readConcernDocument"]) + self.assertEqual(not bool(read_concern.level), test_case["isServerDefault"]) + + return run_test + + +def create_tests(): + for dirpath, _, filenames in os.walk(TEST_PATH): + dirname = os.path.split(dirpath)[-1] + + if dirname == "operation": + # This directory is tested by TestOperations. + continue + elif dirname == "connection-string": + create_test = create_connection_string_test + else: + create_test = create_document_test + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as test_stream: + test_cases = json.load(test_stream)["tests"] + + fname = os.path.splitext(filename)[0] + for test_case in test_cases: + new_test = create_test(test_case) + test_name = "test_{}_{}_{}".format( + dirname.replace("-", "_"), + fname.replace("-", "_"), + str(test_case["description"].lower().replace(" ", "_")), + ) + + new_test.__name__ = test_name + setattr(TestReadWriteConcernSpec, new_test.__name__, new_test) + + +create_tests() + + +# Generate unified tests. +# PyMongo does not support MapReduce. +globals().update( + generate_test_classes( + os.path.join(TEST_PATH, "operation"), + module=__name__, + expected_failures=["MapReduce .*"], + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index db53b67ae..8543991f7 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -19,6 +19,7 @@ import json import os import sys import warnings +from pathlib import Path sys.path[0:0] = [""] @@ -39,7 +40,13 @@ from pymongo.read_concern import ReadConcern from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "read_write_concern") +_IS_SYNC = True + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern") class TestReadWriteConcernSpec(IntegrationTest): @@ -47,7 +54,6 @@ class TestReadWriteConcernSpec(IntegrationTest): listener = OvertCommandListener() # Client with default readConcern and writeConcern client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). collection.insert_many([{} for _ in range(10)]) @@ -66,9 +72,12 @@ class TestReadWriteConcernSpec(IntegrationTest): "insert", "collection", documents=[{}], write_concern=WriteConcern() ) + def aggregate_op(): + (collection.aggregate([])).to_list() + ops = [ - ("aggregate", lambda: list(collection.aggregate([]))), - ("find", lambda: list(collection.find())), + ("aggregate", aggregate_op), + ("find", lambda: collection.find().to_list()), ("insert_one", lambda: collection.insert_one({})), ("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})), ("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})), @@ -207,7 +216,6 @@ class TestReadWriteConcernSpec(IntegrationTest): def test_write_error_details_exposes_errinfo(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) db = client.errinfotest self.addCleanup(client.drop_database, "errinfotest") validator = {"x": {"$type": "string"}} @@ -286,7 +294,7 @@ def create_document_test(test_case): def create_tests(): - for dirpath, _, filenames in os.walk(_TEST_PATH): + for dirpath, _, filenames in os.walk(TEST_PATH): dirname = os.path.split(dirpath)[-1] if dirname == "operation": @@ -321,7 +329,7 @@ create_tests() # PyMongo does not support MapReduce. globals().update( generate_test_classes( - os.path.join(_TEST_PATH, "operation"), + os.path.join(TEST_PATH, "operation"), module=__name__, expected_failures=["MapReduce .*"], ) diff --git a/tools/synchro.py b/tools/synchro.py index eb44ef4ac..ba0a4712e 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -220,6 +220,7 @@ converted_tests = [ "test_on_demand_csfle.py", "test_raw_bson.py", "test_read_concern.py", + "test_read_write_concern_spec.py", "test_retryable_reads.py", "test_retryable_reads_unified.py", "test_retryable_writes.py", From 665eb9a4b83029b43266e0de45339a2ae8764dee Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 3 Feb 2025 14:37:37 -0500 Subject: [PATCH 15/28] PYTHON-5105 - Convert test.test_srv_polling to async (#2124) --- test/asynchronous/test_srv_polling.py | 361 ++++++++++++++++++++++++++ test/test_srv_polling.py | 23 +- tools/synchro.py | 1 + 3 files changed, 378 insertions(+), 7 deletions(-) create mode 100644 test/asynchronous/test_srv_polling.py diff --git a/test/asynchronous/test_srv_polling.py b/test/asynchronous/test_srv_polling.py new file mode 100644 index 000000000..763c80e66 --- /dev/null +++ b/test/asynchronous/test_srv_polling.py @@ -0,0 +1,361 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the SRV support tests.""" +from __future__ import annotations + +import asyncio +import sys +import time +from typing import Any + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncPyMongoTestCase, client_knobs, unittest +from test.utils import FunctionCallRecorder, async_wait_until + +import pymongo +from pymongo import common +from pymongo.errors import ConfigurationError +from pymongo.srv_resolver import _have_dnspython + +_IS_SYNC = False + +WAIT_TIME = 0.1 + + +class SrvPollingKnobs: + def __init__( + self, + ttl_time=None, + min_srv_rescan_interval=None, + nodelist_callback=None, + count_resolver_calls=False, + ): + self.ttl_time = ttl_time + self.min_srv_rescan_interval = min_srv_rescan_interval + self.nodelist_callback = nodelist_callback + self.count_resolver_calls = count_resolver_calls + + self.old_min_srv_rescan_interval = None + self.old_dns_resolver_response = None + + def enable(self): + self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL + self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl + + if self.min_srv_rescan_interval is not None: + common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval + + def mock_get_hosts_and_min_ttl(resolver, *args): + assert self.old_dns_resolver_response is not None + nodes, ttl = self.old_dns_resolver_response(resolver) + if self.nodelist_callback is not None: + nodes = self.nodelist_callback() + if self.ttl_time is not None: + ttl = self.ttl_time + return nodes, ttl + + patch_func: Any + if self.count_resolver_calls: + patch_func = FunctionCallRecorder(mock_get_hosts_and_min_ttl) + else: + patch_func = mock_get_hosts_and_min_ttl + + pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore + + def __enter__(self): + self.enable() + + def disable(self): + common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore + pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore + self.old_dns_resolver_response + ) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disable() + + +class TestSrvPolling(AsyncPyMongoTestCase): + BASE_SRV_RESPONSE = [ + ("localhost.test.build.10gen.cc", 27017), + ("localhost.test.build.10gen.cc", 27018), + ] + + CONNECTION_STRING = "mongodb+srv://test1.test.build.10gen.cc" + + async def asyncSetUp(self): + # Patch timeouts to ensure short rescan SRV interval. + self.client_knobs = client_knobs( + heartbeat_frequency=WAIT_TIME, + min_heartbeat_interval=WAIT_TIME, + events_queue_frequency=WAIT_TIME, + ) + self.client_knobs.enable() + + async def asyncTearDown(self): + self.client_knobs.disable() + + def get_nodelist(self, client): + return client._topology.description.server_descriptions().keys() + + async def assert_nodelist_change(self, expected_nodelist, client, timeout=(100 * WAIT_TIME)): + """Check if the client._topology eventually sees all nodes in the + expected_nodelist. + """ + + def predicate(): + nodelist = self.get_nodelist(client) + if set(expected_nodelist) == set(nodelist): + return True + return False + + await async_wait_until(predicate, "see expected nodelist", timeout=timeout) + + async def assert_nodelist_nochange(self, expected_nodelist, client, timeout=(100 * WAIT_TIME)): + """Check if the client._topology ever deviates from seeing all nodes + in the expected_nodelist. Consistency is checked after sleeping for + (WAIT_TIME * 10) seconds. Also check that the resolver is called at + least once. + """ + + def predicate(): + if set(expected_nodelist) == set(self.get_nodelist(client)): + return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1 + return False + + await async_wait_until(predicate, "Node list equals expected nodelist", timeout=timeout) + + nodelist = self.get_nodelist(client) + if set(expected_nodelist) != set(nodelist): + msg = "Client nodelist %s changed unexpectedly (expected %s)" + raise self.fail(msg % (nodelist, expected_nodelist)) + self.assertGreaterEqual( + pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore + 1, + "resolver was never called", + ) + return True + + async def run_scenario(self, dns_response, expect_change): + self.assertEqual(_have_dnspython(), True) + if callable(dns_response): + dns_resolver_response = dns_response + else: + + def dns_resolver_response(): + return dns_response + + if expect_change: + assertion_method = self.assert_nodelist_change + count_resolver_calls = False + expected_response = dns_response + else: + assertion_method = self.assert_nodelist_nochange + count_resolver_calls = True + expected_response = self.BASE_SRV_RESPONSE + + # Patch timeouts to ensure short test running times. + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + client = self.simple_client(self.CONNECTION_STRING) + await client.aconnect() + await self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client) + # Patch list of hosts returned by DNS query. + with SrvPollingKnobs( + nodelist_callback=dns_resolver_response, count_resolver_calls=count_resolver_calls + ): + await assertion_method(expected_response, client) + + async def test_addition(self): + response = self.BASE_SRV_RESPONSE[:] + response.append(("localhost.test.build.10gen.cc", 27019)) + await self.run_scenario(response, True) + + async def test_removal(self): + response = self.BASE_SRV_RESPONSE[:] + response.remove(("localhost.test.build.10gen.cc", 27018)) + await self.run_scenario(response, True) + + async def test_replace_one(self): + response = self.BASE_SRV_RESPONSE[:] + response.remove(("localhost.test.build.10gen.cc", 27018)) + response.append(("localhost.test.build.10gen.cc", 27019)) + await self.run_scenario(response, True) + + async def test_replace_both_with_one(self): + response = [("localhost.test.build.10gen.cc", 27019)] + await self.run_scenario(response, True) + + async def test_replace_both_with_two(self): + response = [ + ("localhost.test.build.10gen.cc", 27019), + ("localhost.test.build.10gen.cc", 27020), + ] + await self.run_scenario(response, True) + + async def test_dns_failures(self): + from dns import exception + + for exc in (exception.FormError, exception.TooBig, exception.Timeout): + + def response_callback(*args): + raise exc("DNS Failure!") + + await self.run_scenario(response_callback, False) + + async def test_dns_record_lookup_empty(self): + response: list = [] + await self.run_scenario(response, False) + + async def _test_recover_from_initial(self, initial_callback): + # Construct a valid final response callback distinct from base. + response_final = self.BASE_SRV_RESPONSE[:] + response_final.pop() + + def final_callback(): + return response_final + + with SrvPollingKnobs( + ttl_time=WAIT_TIME, + min_srv_rescan_interval=WAIT_TIME, + nodelist_callback=initial_callback, + count_resolver_calls=True, + ): + # Client uses unpatched method to get initial nodelist + client = self.simple_client(self.CONNECTION_STRING) + await client.aconnect() + # Invalid DNS resolver response should not change nodelist. + await self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client) + + with SrvPollingKnobs( + ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME, nodelist_callback=final_callback + ): + # Nodelist should reflect new valid DNS resolver response. + await self.assert_nodelist_change(response_final, client) + + async def test_recover_from_initially_empty_seedlist(self): + def empty_seedlist(): + return [] + + await self._test_recover_from_initial(empty_seedlist) + + async def test_recover_from_initially_erroring_seedlist(self): + def erroring_seedlist(): + raise ConfigurationError + + await self._test_recover_from_initial(erroring_seedlist) + + async def test_10_all_dns_selected(self): + response = [ + ("localhost.test.build.10gen.cc", 27017), + ("localhost.test.build.10gen.cc", 27019), + ("localhost.test.build.10gen.cc", 27020), + ] + + def nodelist_callback(): + return response + + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0) + await client.aconnect() + with SrvPollingKnobs(nodelist_callback=nodelist_callback): + await self.assert_nodelist_change(response, client) + + async def test_11_all_dns_selected(self): + response = [ + ("localhost.test.build.10gen.cc", 27019), + ("localhost.test.build.10gen.cc", 27020), + ] + + def nodelist_callback(): + return response + + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) + await client.aconnect() + with SrvPollingKnobs(nodelist_callback=nodelist_callback): + await self.assert_nodelist_change(response, client) + + async def test_12_new_dns_randomly_selected(self): + response = [ + ("localhost.test.build.10gen.cc", 27020), + ("localhost.test.build.10gen.cc", 27019), + ("localhost.test.build.10gen.cc", 27017), + ] + + def nodelist_callback(): + return response + + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) + await client.aconnect() + with SrvPollingKnobs(nodelist_callback=nodelist_callback): + await asyncio.sleep(2 * common.MIN_SRV_RESCAN_INTERVAL) + final_topology = set(client.topology_description.server_descriptions()) + self.assertIn(("localhost.test.build.10gen.cc", 27017), final_topology) + self.assertEqual(len(final_topology), 2) + + async def test_does_not_flipflop(self): + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1) + await client.aconnect() + old = set(client.topology_description.server_descriptions()) + await asyncio.sleep(4 * WAIT_TIME) + new = set(client.topology_description.server_descriptions()) + self.assertSetEqual(old, new) + + async def test_srv_service_name(self): + # Construct a valid final response callback distinct from base. + response = [ + ("localhost.test.build.10gen.cc.", 27019), + ("localhost.test.build.10gen.cc.", 27020), + ] + + def nodelist_callback(): + return response + + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + client = self.simple_client( + "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname" + ) + await client.aconnect() + with SrvPollingKnobs(nodelist_callback=nodelist_callback): + await self.assert_nodelist_change(response, client) + + async def test_srv_waits_to_poll(self): + modified = [("localhost.test.build.10gen.cc", 27019)] + + def resolver_response(): + return modified + + with SrvPollingKnobs( + ttl_time=WAIT_TIME, + min_srv_rescan_interval=WAIT_TIME, + nodelist_callback=resolver_response, + ): + client = self.simple_client(self.CONNECTION_STRING) + await client.aconnect() + with self.assertRaises(AssertionError): + await self.assert_nodelist_change(modified, client, timeout=WAIT_TIME / 2) + + def test_import_dns_resolver(self): + # Regression test for PYTHON-4407 + import dns.resolver + + self.assertTrue(hasattr(dns.resolver, "resolve")) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index e01552bf7..86fad6d90 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -15,8 +15,9 @@ """Run the SRV support tests.""" from __future__ import annotations +import asyncio import sys -from time import sleep +import time from typing import Any sys.path[0:0] = [""] @@ -28,7 +29,8 @@ import pymongo from pymongo import common from pymongo.errors import ConfigurationError from pymongo.srv_resolver import _have_dnspython -from pymongo.synchronous.mongo_client import MongoClient + +_IS_SYNC = True WAIT_TIME = 0.1 @@ -168,6 +170,7 @@ class TestSrvPolling(PyMongoTestCase): # Patch timeouts to ensure short test running times. with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = self.simple_client(self.CONNECTION_STRING) + client._connect() self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client) # Patch list of hosts returned by DNS query. with SrvPollingKnobs( @@ -232,6 +235,7 @@ class TestSrvPolling(PyMongoTestCase): ): # Client uses unpatched method to get initial nodelist client = self.simple_client(self.CONNECTION_STRING) + client._connect() # Invalid DNS resolver response should not change nodelist. self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client) @@ -265,6 +269,7 @@ class TestSrvPolling(PyMongoTestCase): with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0) + client._connect() with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) @@ -279,6 +284,7 @@ class TestSrvPolling(PyMongoTestCase): with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) + client._connect() with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) @@ -294,8 +300,9 @@ class TestSrvPolling(PyMongoTestCase): with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) + client._connect() with SrvPollingKnobs(nodelist_callback=nodelist_callback): - sleep(2 * common.MIN_SRV_RESCAN_INTERVAL) + time.sleep(2 * common.MIN_SRV_RESCAN_INTERVAL) final_topology = set(client.topology_description.server_descriptions()) self.assertIn(("localhost.test.build.10gen.cc", 27017), final_topology) self.assertEqual(len(final_topology), 2) @@ -303,8 +310,9 @@ class TestSrvPolling(PyMongoTestCase): def test_does_not_flipflop(self): with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1) + client._connect() old = set(client.topology_description.server_descriptions()) - sleep(4 * WAIT_TIME) + time.sleep(4 * WAIT_TIME) new = set(client.topology_description.server_descriptions()) self.assertSetEqual(old, new) @@ -322,6 +330,7 @@ class TestSrvPolling(PyMongoTestCase): client = self.simple_client( "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname" ) + client._connect() with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) @@ -337,9 +346,9 @@ class TestSrvPolling(PyMongoTestCase): nodelist_callback=resolver_response, ): client = self.simple_client(self.CONNECTION_STRING) - self.assertRaises( - AssertionError, self.assert_nodelist_change, modified, client, timeout=WAIT_TIME / 2 - ) + client._connect() + with self.assertRaises(AssertionError): + self.assert_nodelist_change(modified, client, timeout=WAIT_TIME / 2) def test_import_dns_resolver(self): # Regression test for PYTHON-4407 diff --git a/tools/synchro.py b/tools/synchro.py index ba0a4712e..6317cb84f 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -226,6 +226,7 @@ converted_tests = [ "test_retryable_writes.py", "test_retryable_writes_unified.py", "test_session.py", + "test_srv_polling.py", "test_transactions.py", "unified_format.py", ] From 1fda6a2310511a1e39e0476418f6477495441101 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 3 Feb 2025 15:48:04 -0500 Subject: [PATCH 16/28] PYTHON-5110 - Convert test.test_unified_format to async (#2130) --- test/asynchronous/test_unified_format.py | 99 ++++++++++++++++++++++++ test/asynchronous/unified_format.py | 2 +- test/test_unified_format.py | 17 ++-- test/unified_format.py | 2 +- tools/synchro.py | 1 + 5 files changed, 114 insertions(+), 7 deletions(-) create mode 100644 test/asynchronous/test_unified_format.py diff --git a/test/asynchronous/test_unified_format.py b/test/asynchronous/test_unified_format.py new file mode 100644 index 000000000..a005739e9 --- /dev/null +++ b/test/asynchronous/test_unified_format.py @@ -0,0 +1,99 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import sys +from pathlib import Path +from typing import Any + +sys.path[0:0] = [""] + +from test import UnitTest, unittest +from test.asynchronous.unified_format import MatchEvaluatorUtil, generate_test_classes + +from bson import ObjectId + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "unified-test-format") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "unified-test-format") + + +globals().update( + generate_test_classes( + os.path.join(TEST_PATH, "valid-pass"), + module=__name__, + class_name_prefix="UnifiedTestFormat", + expected_failures=[ + "Client side error in command starting transaction", # PYTHON-1894 + ], + RUN_ON_SERVERLESS=False, + ) +) + + +globals().update( + generate_test_classes( + os.path.join(TEST_PATH, "valid-fail"), + module=__name__, + class_name_prefix="UnifiedTestFormat", + bypass_test_generation_errors=True, + expected_failures=[ + ".*", # All tests expected to fail + ], + RUN_ON_SERVERLESS=False, + ) +) + + +class TestMatchEvaluatorUtil(UnitTest): + def setUp(self): + self.match_evaluator = MatchEvaluatorUtil(self) + + def test_unsetOrMatches(self): + spec: dict[str, Any] = {"$$unsetOrMatches": {"y": {"$$unsetOrMatches": 2}}} + for actual in [{}, {"y": 2}, None]: + self.match_evaluator.match_result(spec, actual) + + spec = {"x": {"$$unsetOrMatches": {"y": {"$$unsetOrMatches": 2}}}} + for actual in [{}, {"x": {}}, {"x": {"y": 2}}]: + self.match_evaluator.match_result(spec, actual) + + spec = {"y": {"$$unsetOrMatches": {"$$exists": True}}} + self.match_evaluator.match_result(spec, {}) + self.match_evaluator.match_result(spec, {"y": 2}) + self.match_evaluator.match_result(spec, {"x": 1}) + self.match_evaluator.match_result(spec, {"y": {}}) + + def test_type(self): + self.match_evaluator.match_result( + { + "operationType": "insert", + "ns": {"db": "change-stream-tests", "coll": "test"}, + "fullDocument": {"_id": {"$$type": "objectId"}, "x": 1}, + }, + { + "operationType": "insert", + "fullDocument": {"_id": ObjectId("5fc93511ac93941052098f0c"), "x": 1}, + "ns": {"db": "change-stream-tests", "coll": "test"}, + }, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 52d964eb3..37248e9ad 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -711,7 +711,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest): return await target.command(**kwargs) async def _databaseOperation_runCursorCommand(self, target, **kwargs): - return list(await self._databaseOperation_createCommandCursor(target, **kwargs)) + return await (await self._databaseOperation_createCommandCursor(target, **kwargs)).to_list() async def _databaseOperation_createCommandCursor(self, target, **kwargs): self.__raise_if_unsupported("createCommandCursor", target, AsyncDatabase) diff --git a/test/test_unified_format.py b/test/test_unified_format.py index 1b3a13423..05f58d5d0 100644 --- a/test/test_unified_format.py +++ b/test/test_unified_format.py @@ -15,21 +15,28 @@ from __future__ import annotations import os import sys +from pathlib import Path from typing import Any sys.path[0:0] = [""] -from test import unittest +from test import UnitTest, unittest from test.unified_format import MatchEvaluatorUtil, generate_test_classes from bson import ObjectId -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "unified-test-format") +_IS_SYNC = True + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "unified-test-format") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "unified-test-format") globals().update( generate_test_classes( - os.path.join(_TEST_PATH, "valid-pass"), + os.path.join(TEST_PATH, "valid-pass"), module=__name__, class_name_prefix="UnifiedTestFormat", expected_failures=[ @@ -42,7 +49,7 @@ globals().update( globals().update( generate_test_classes( - os.path.join(_TEST_PATH, "valid-fail"), + os.path.join(TEST_PATH, "valid-fail"), module=__name__, class_name_prefix="UnifiedTestFormat", bypass_test_generation_errors=True, @@ -54,7 +61,7 @@ globals().update( ) -class TestMatchEvaluatorUtil(unittest.TestCase): +class TestMatchEvaluatorUtil(UnitTest): def setUp(self): self.match_evaluator = MatchEvaluatorUtil(self) diff --git a/test/unified_format.py b/test/unified_format.py index 372eb8abb..e66b57f9d 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -708,7 +708,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): return target.command(**kwargs) def _databaseOperation_runCursorCommand(self, target, **kwargs): - return list(self._databaseOperation_createCommandCursor(target, **kwargs)) + return (self._databaseOperation_createCommandCursor(target, **kwargs)).to_list() def _databaseOperation_createCommandCursor(self, target, **kwargs): self.__raise_if_unsupported("createCommandCursor", target, Database) diff --git a/tools/synchro.py b/tools/synchro.py index 6317cb84f..f83c5d4ca 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -228,6 +228,7 @@ converted_tests = [ "test_session.py", "test_srv_polling.py", "test_transactions.py", + "test_unified_format.py", "unified_format.py", ] From b47143cd1047f388d2a76df09f52b19a728ac4ab Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 4 Feb 2025 07:42:13 -0500 Subject: [PATCH 17/28] PYTHON-4864 - Create async version of SpecRunnerThread (#2094) --- test/asynchronous/helpers.py | 37 ++++++++++++++++++++++++++ test/asynchronous/unified_format.py | 16 +++++------ test/asynchronous/utils_spec_runner.py | 29 ++++++++++---------- test/helpers.py | 37 ++++++++++++++++++++++++++ test/unified_format.py | 2 +- test/utils_spec_runner.py | 13 +++++---- tools/synchro.py | 1 + 7 files changed, 104 insertions(+), 31 deletions(-) diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index b5fc5d8ac..7758f281e 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -15,6 +15,7 @@ """Shared constants and helper methods for pymongo, bson, and gridfs test suites.""" from __future__ import annotations +import asyncio import base64 import gc import multiprocessing @@ -30,6 +31,8 @@ import unittest import warnings from asyncio import iscoroutinefunction +from pymongo._asyncio_task import create_task + try: import ipaddress @@ -369,3 +372,37 @@ class SystemCertsPatcher: os.environ.pop("SSL_CERT_FILE") else: os.environ["SSL_CERT_FILE"] = self.original_certs + + +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object + + +class ConcurrentRunner(PARENT): + def __init__(self, name, *args, **kwargs): + if _IS_SYNC: + super().__init__(*args, **kwargs) + self.name = name + self.stopped = False + self.task = None + if "target" in kwargs: + self.target = kwargs["target"] + + if not _IS_SYNC: + + async def start(self): + self.task = create_task(self.run(), name=self.name) + + async def join(self, timeout: float | None = 0): # type: ignore[override] + if self.task is not None: + await asyncio.wait([self.task], timeout=timeout) + + def is_alive(self): + return not self.stopped + + async def run(self): + if self.target: + await self.target() + self.stopped = True diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 37248e9ad..149aad978 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -35,6 +35,7 @@ from test.asynchronous import ( client_knobs, unittest, ) +from test.asynchronous.utils_spec_runner import SpecRunnerTask from test.unified_format_shared import ( KMS_TLS_OPTS, PLACEHOLDER_MAP, @@ -58,7 +59,6 @@ from test.utils import ( snake_to_camel, wait_until, ) -from test.utils_spec_runner import SpecRunnerThread from test.version import Version from typing import Any, Dict, List, Mapping, Optional @@ -382,8 +382,8 @@ class EntityMapUtil: return elif entity_type == "thread": name = spec["id"] - thread = SpecRunnerThread(name) - thread.start() + thread = SpecRunnerTask(name) + await thread.start() self[name] = thread return @@ -1177,16 +1177,16 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest): wait_until(primary_changed, "change primary", timeout=timeout) - def _testOperation_runOnThread(self, spec): + async def _testOperation_runOnThread(self, spec): """Run the 'runOnThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + await thread.schedule(functools.partial(self.run_entity_operation, spec["operation"])) - def _testOperation_waitForThread(self, spec): + async def _testOperation_waitForThread(self, spec): """Run the 'waitForThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.stop() - thread.join(10) + await thread.stop() + await thread.join(10) if thread.exc: raise thread.exc self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"])) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index b79e5258b..d10337431 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -18,11 +18,11 @@ from __future__ import annotations import asyncio import functools import os -import threading import unittest from asyncio import iscoroutinefunction from collections import abc from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs +from test.asynchronous.helpers import ConcurrentRunner from test.utils import ( CMAPListener, CompareType, @@ -47,6 +47,7 @@ from pymongo.asynchronous import client_session from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.cursor import AsyncCursor from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError +from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.results import BulkWriteResult, _WriteResult @@ -55,38 +56,36 @@ from pymongo.write_concern import WriteConcern _IS_SYNC = False -class SpecRunnerThread(threading.Thread): +class SpecRunnerTask(ConcurrentRunner): def __init__(self, name): - super().__init__() - self.name = name + super().__init__(name) self.exc = None self.daemon = True - self.cond = threading.Condition() + self.cond = _async_create_condition(_async_create_lock()) self.ops = [] - self.stopped = False - def schedule(self, work): + async def schedule(self, work): self.ops.append(work) - with self.cond: + async with self.cond: self.cond.notify() - def stop(self): + async def stop(self): self.stopped = True - with self.cond: + async with self.cond: self.cond.notify() - def run(self): + async def run(self): while not self.stopped or self.ops: if not self.ops: - with self.cond: - self.cond.wait(10) + async with self.cond: + await _async_cond_wait(self.cond, 10) if self.ops: try: work = self.ops.pop(0) - work() + await work() except Exception as exc: self.exc = exc - self.stop() + await self.stop() class AsyncSpecTestCreator: diff --git a/test/helpers.py b/test/helpers.py index 11d5ab037..bd9e23bba 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -15,6 +15,7 @@ """Shared constants and helper methods for pymongo, bson, and gridfs test suites.""" from __future__ import annotations +import asyncio import base64 import gc import multiprocessing @@ -30,6 +31,8 @@ import unittest import warnings from asyncio import iscoroutinefunction +from pymongo._asyncio_task import create_task + try: import ipaddress @@ -369,3 +372,37 @@ class SystemCertsPatcher: os.environ.pop("SSL_CERT_FILE") else: os.environ["SSL_CERT_FILE"] = self.original_certs + + +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object + + +class ConcurrentRunner(PARENT): + def __init__(self, name, *args, **kwargs): + if _IS_SYNC: + super().__init__(*args, **kwargs) + self.name = name + self.stopped = False + self.task = None + if "target" in kwargs: + self.target = kwargs["target"] + + if not _IS_SYNC: + + def start(self): + self.task = create_task(self.run(), name=self.name) + + def join(self, timeout: float | None = 0): # type: ignore[override] + if self.task is not None: + asyncio.wait([self.task], timeout=timeout) + + def is_alive(self): + return not self.stopped + + def run(self): + if self.target: + self.target() + self.stopped = True diff --git a/test/unified_format.py b/test/unified_format.py index e66b57f9d..b2e6ae1e8 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1167,7 +1167,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): def _testOperation_runOnThread(self, spec): """Run the 'runOnThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + thread.schedule(functools.partial(self.run_entity_operation, spec["operation"])) def _testOperation_waitForThread(self, spec): """Run the 'waitForThread' operation.""" diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 4508502cd..6a62112af 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -18,11 +18,11 @@ from __future__ import annotations import asyncio import functools import os -import threading import unittest from asyncio import iscoroutinefunction from collections import abc from test import IntegrationTest, client_context, client_knobs +from test.helpers import ConcurrentRunner from test.utils import ( CMAPListener, CompareType, @@ -44,6 +44,7 @@ from bson.son import SON from gridfs import GridFSBucket from gridfs.synchronous.grid_file import GridFSBucket from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError +from pymongo.lock import _cond_wait, _create_condition, _create_lock from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.results import BulkWriteResult, _WriteResult @@ -55,15 +56,13 @@ from pymongo.write_concern import WriteConcern _IS_SYNC = True -class SpecRunnerThread(threading.Thread): +class SpecRunnerThread(ConcurrentRunner): def __init__(self, name): - super().__init__() - self.name = name + super().__init__(name) self.exc = None self.daemon = True - self.cond = threading.Condition() + self.cond = _create_condition(_create_lock()) self.ops = [] - self.stopped = False def schedule(self, work): self.ops.append(work) @@ -79,7 +78,7 @@ class SpecRunnerThread(threading.Thread): while not self.stopped or self.ops: if not self.ops: with self.cond: - self.cond.wait(10) + _cond_wait(self.cond, 10) if self.ops: try: work = self.ops.pop(0) diff --git a/tools/synchro.py b/tools/synchro.py index f83c5d4ca..8c03a346e 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -119,6 +119,7 @@ replacements = { "_async_create_lock": "_create_lock", "_async_create_condition": "_create_condition", "_async_cond_wait": "_cond_wait", + "SpecRunnerTask": "SpecRunnerThread", "AsyncMockConnection": "MockConnection", "AsyncMockPool": "MockPool", } From 68237f78ecda3399518c3700b3593b028eb9eeef Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 4 Feb 2025 09:42:21 -0500 Subject: [PATCH 18/28] PYTHON-5098 - Convert test.test_run_command to async (#2114) --- test/asynchronous/test_run_command.py | 41 +++++++++++++++++++++++++++ test/test_run_command.py | 26 +++++++++++++++-- tools/synchro.py | 1 + 3 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 test/asynchronous/test_run_command.py diff --git a/test/asynchronous/test_run_command.py b/test/asynchronous/test_run_command.py new file mode 100644 index 000000000..3ac8c3270 --- /dev/null +++ b/test/asynchronous/test_run_command.py @@ -0,0 +1,41 @@ +# Copyright 2024-Present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run Command unified tests.""" +from __future__ import annotations + +import os +import unittest +from pathlib import Path +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "run_command") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "run_command") + + +globals().update( + generate_test_classes( + os.path.join(TEST_PATH, "unified"), + module=__name__, + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_run_command.py b/test/test_run_command.py index 486a4c7e3..d2ef43b97 100644 --- a/test/test_run_command.py +++ b/test/test_run_command.py @@ -1,15 +1,37 @@ +# Copyright 2024-Present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run Command unified tests.""" from __future__ import annotations import os import unittest +from pathlib import Path from test.unified_format import generate_test_classes -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "run_command") +_IS_SYNC = True + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "run_command") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "run_command") globals().update( generate_test_classes( - os.path.join(_TEST_PATH, "unified"), + os.path.join(TEST_PATH, "unified"), module=__name__, ) ) diff --git a/tools/synchro.py b/tools/synchro.py index 8c03a346e..4ac5604f2 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -226,6 +226,7 @@ converted_tests = [ "test_retryable_reads_unified.py", "test_retryable_writes.py", "test_retryable_writes_unified.py", + "test_run_command.py", "test_session.py", "test_srv_polling.py", "test_transactions.py", From 554e1fddb8ac83d19237975cdfa7682c6b0f491c Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 4 Feb 2025 09:43:35 -0500 Subject: [PATCH 19/28] PYTHON-5106 - Convert test.test_ssl to async (#2125) --- test/asynchronous/test_ssl.py | 662 ++++++++++++++++++++++++++++++++++ test/test_ssl.py | 30 +- tools/synchro.py | 1 + 3 files changed, 677 insertions(+), 16 deletions(-) create mode 100644 test/asynchronous/test_ssl.py diff --git a/test/asynchronous/test_ssl.py b/test/asynchronous/test_ssl.py new file mode 100644 index 000000000..d50bb220b --- /dev/null +++ b/test/asynchronous/test_ssl.py @@ -0,0 +1,662 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for SSL support.""" +from __future__ import annotations + +import os +import pathlib +import socket +import sys + +sys.path[0:0] = [""] + +from test.asynchronous import ( + HAVE_IPADDRESS, + AsyncIntegrationTest, + AsyncPyMongoTestCase, + SkipTest, + async_client_context, + connected, + remove_all_users, + unittest, +) +from test.utils import ( + EventListener, + OvertCommandListener, + cat_files, + ignore_deprecations, +) +from urllib.parse import quote_plus + +from pymongo import AsyncMongoClient, ssl_support +from pymongo.errors import ConfigurationError, ConnectionFailure, OperationFailure +from pymongo.hello import HelloCompat +from pymongo.ssl_support import HAVE_SSL, _ssl, get_ssl_context +from pymongo.write_concern import WriteConcern + +_HAVE_PYOPENSSL = False +try: + # All of these must be available to use PyOpenSSL + import OpenSSL + import requests + import service_identity + + # Ensure service_identity>=18.1 is installed + from service_identity.pyopenssl import verify_ip_address + + from pymongo.ocsp_support import _load_trusted_ca_certs + + _HAVE_PYOPENSSL = True +except ImportError: + _load_trusted_ca_certs = None # type: ignore + + +if HAVE_SSL: + import ssl + +_IS_SYNC = False + +if _IS_SYNC: + CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "certificates") +else: + CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "certificates") + +CLIENT_PEM = os.path.join(CERT_PATH, "client.pem") +CLIENT_ENCRYPTED_PEM = os.path.join(CERT_PATH, "password_protected.pem") +CA_PEM = os.path.join(CERT_PATH, "ca.pem") +CA_BUNDLE_PEM = os.path.join(CERT_PATH, "trusted-ca.pem") +CRL_PEM = os.path.join(CERT_PATH, "crl.pem") +MONGODB_X509_USERNAME = "C=US,ST=New York,L=New York City,O=MDB,OU=Drivers,CN=client" + +# To fully test this start a mongod instance (built with SSL support) like so: +# mongod --dbpath /path/to/data/directory --sslOnNormalPorts \ +# --sslPEMKeyFile /path/to/pymongo/test/certificates/server.pem \ +# --sslCAFile /path/to/pymongo/test/certificates/ca.pem \ +# --sslWeakCertificateValidation +# Also, make sure you have 'server' as an alias for localhost in /etc/hosts +# +# Note: For all replica set tests to pass, the replica set configuration must +# use 'localhost' for the hostname of all hosts. + + +class TestClientSSL(AsyncPyMongoTestCase): + @unittest.skipIf(HAVE_SSL, "The ssl module is available, can't test what happens without it.") + def test_no_ssl_module(self): + # Explicit + self.assertRaises(ConfigurationError, self.simple_client, ssl=True) + + # Implied + self.assertRaises(ConfigurationError, self.simple_client, tlsCertificateKeyFile=CLIENT_PEM) + + @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") + @ignore_deprecations + def test_config_ssl(self): + # Tests various ssl configurations + self.assertRaises(ValueError, self.simple_client, ssl="foo") + self.assertRaises( + ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM + ) + self.assertRaises(TypeError, self.simple_client, ssl=0) + self.assertRaises(TypeError, self.simple_client, ssl=5.5) + self.assertRaises(TypeError, self.simple_client, ssl=[]) + + self.assertRaises(IOError, self.simple_client, tlsCertificateKeyFile="NoSuchFile") + self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=True) + self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=[]) + + # Test invalid combinations + self.assertRaises( + ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM + ) + self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCAFile=CA_PEM) + self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCRLFile=CRL_PEM) + self.assertRaises( + ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidCertificates=False + ) + self.assertRaises( + ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidHostnames=False + ) + self.assertRaises( + ConfigurationError, self.simple_client, tls=False, tlsDisableOCSPEndpointCheck=False + ) + + @unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.") + def test_use_pyopenssl_when_available(self): + self.assertTrue(_ssl.IS_PYOPENSSL) + + @unittest.skipUnless(_HAVE_PYOPENSSL, "Cannot test without PyOpenSSL") + def test_load_trusted_ca_certs(self): + trusted_ca_certs = _load_trusted_ca_certs(CA_BUNDLE_PEM) + self.assertEqual(2, len(trusted_ca_certs)) + + +class TestSSL(AsyncIntegrationTest): + saved_port: int + + async def assertClientWorks(self, client): + coll = client.pymongo_test.ssl_test.with_options( + write_concern=WriteConcern(w=async_client_context.w) + ) + await coll.drop() + await coll.insert_one({"ssl": True}) + self.assertTrue((await coll.find_one())["ssl"]) + await coll.drop() + + @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") + async def asyncSetUp(self): + await super().asyncSetUp() + # MongoClient should connect to the primary by default. + self.saved_port = AsyncMongoClient.PORT + AsyncMongoClient.PORT = await async_client_context.port + + async def asyncTearDown(self): + AsyncMongoClient.PORT = self.saved_port + + @async_client_context.require_tls + async def test_simple_ssl(self): + # Expects the server to be running with ssl and with + # no --sslPEMKeyFile or with --sslWeakCertificateValidation + await self.assertClientWorks(self.client) + + @async_client_context.require_tlsCertificateKeyFile + @ignore_deprecations + async def test_tlsCertificateKeyFilePassword(self): + # Expects the server to be running with server.pem and ca.pem + # + # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem + # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem + if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL: + self.assertRaises( + ConfigurationError, + self.simple_client, + "localhost", + ssl=True, + tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM, + tlsCertificateKeyFilePassword="qwerty", + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=1000, + ) + else: + await connected( + self.simple_client( + "localhost", + ssl=True, + tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM, + tlsCertificateKeyFilePassword="qwerty", + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=5000, + **self.credentials, # type: ignore[arg-type] + ) + ) + + uri_fmt = ( + "mongodb://localhost/?ssl=true" + "&tlsCertificateKeyFile=%s&tlsCertificateKeyFilePassword=qwerty" + "&tlsCAFile=%s&serverSelectionTimeoutMS=5000" + ) + await connected( + self.simple_client(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] + ) + + @async_client_context.require_tlsCertificateKeyFile + @async_client_context.require_no_auth + @ignore_deprecations + async def test_cert_ssl_implicitly_set(self): + # Expects the server to be running with server.pem and ca.pem + # + # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem + # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem + # + + # test that setting tlsCertificateKeyFile causes ssl to be set to True + client = self.simple_client( + await async_client_context.host, + await async_client_context.port, + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CLIENT_PEM, + ) + response = await client.admin.command(HelloCompat.LEGACY_CMD) + if "setName" in response: + client = self.simple_client( + await async_client_context.pair, + replicaSet=response["setName"], + w=len(response["hosts"]), + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CLIENT_PEM, + ) + + await self.assertClientWorks(client) + + @async_client_context.require_tlsCertificateKeyFile + @async_client_context.require_no_auth + @ignore_deprecations + async def test_cert_ssl_validation(self): + # Expects the server to be running with server.pem and ca.pem + # + # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem + # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem + # + client = self.simple_client( + "localhost", + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + ) + response = await client.admin.command(HelloCompat.LEGACY_CMD) + if "setName" in response: + if response["primary"].split(":")[0] != "localhost": + raise SkipTest( + "No hosts in the replicaset for 'localhost'. " + "Cannot validate hostname in the certificate" + ) + + client = self.simple_client( + "localhost", + replicaSet=response["setName"], + w=len(response["hosts"]), + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + ) + + await self.assertClientWorks(client) + + if HAVE_IPADDRESS: + client = self.simple_client( + "127.0.0.1", + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + ) + await self.assertClientWorks(client) + + @async_client_context.require_tlsCertificateKeyFile + @async_client_context.require_no_auth + @ignore_deprecations + async def test_cert_ssl_uri_support(self): + # Expects the server to be running with server.pem and ca.pem + # + # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem + # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem + # + uri_fmt = ( + "mongodb://localhost/?ssl=true&tlsCertificateKeyFile=%s&tlsAllowInvalidCertificates" + "=%s&tlsCAFile=%s&tlsAllowInvalidHostnames=false" + ) + client = self.simple_client(uri_fmt % (CLIENT_PEM, "true", CA_PEM)) + await self.assertClientWorks(client) + + @async_client_context.require_tlsCertificateKeyFile + @async_client_context.require_server_resolvable + @ignore_deprecations + async def test_cert_ssl_validation_hostname_matching(self): + # Expects the server to be running with server.pem and ca.pem + # + # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem + # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem + ctx = get_ssl_context(None, None, None, None, True, True, False) + self.assertFalse(ctx.check_hostname) + ctx = get_ssl_context(None, None, None, None, True, False, False) + self.assertFalse(ctx.check_hostname) + ctx = get_ssl_context(None, None, None, None, False, True, False) + self.assertFalse(ctx.check_hostname) + ctx = get_ssl_context(None, None, None, None, False, False, False) + self.assertTrue(ctx.check_hostname) + + response = await self.client.admin.command(HelloCompat.LEGACY_CMD) + + with self.assertRaises(ConnectionFailure): + await connected( + self.simple_client( + "server", + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=500, + **self.credentials, # type: ignore[arg-type] + ) + ) + + await connected( + self.simple_client( + "server", + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + tlsAllowInvalidHostnames=True, + serverSelectionTimeoutMS=500, + **self.credentials, # type: ignore[arg-type] + ) + ) + + if "setName" in response: + with self.assertRaises(ConnectionFailure): + await connected( + self.simple_client( + "server", + replicaSet=response["setName"], + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=500, + **self.credentials, # type: ignore[arg-type] + ) + ) + + await connected( + self.simple_client( + "server", + replicaSet=response["setName"], + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + tlsAllowInvalidHostnames=True, + serverSelectionTimeoutMS=500, + **self.credentials, # type: ignore[arg-type] + ) + ) + + @async_client_context.require_tlsCertificateKeyFile + @ignore_deprecations + async def test_tlsCRLFile_support(self): + if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL: + self.assertRaises( + ConfigurationError, + self.simple_client, + "localhost", + ssl=True, + tlsCAFile=CA_PEM, + tlsCRLFile=CRL_PEM, + serverSelectionTimeoutMS=1000, + ) + else: + await connected( + self.simple_client( + "localhost", + ssl=True, + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=1000, + **self.credentials, # type: ignore[arg-type] + ) + ) + + with self.assertRaises(ConnectionFailure): + await connected( + self.simple_client( + "localhost", + ssl=True, + tlsCAFile=CA_PEM, + tlsCRLFile=CRL_PEM, + serverSelectionTimeoutMS=1000, + **self.credentials, # type: ignore[arg-type] + ) + ) + + uri_fmt = "mongodb://localhost/?ssl=true&tlsCAFile=%s&serverSelectionTimeoutMS=1000" + await connected(self.simple_client(uri_fmt % (CA_PEM,), **self.credentials)) # type: ignore + + uri_fmt = ( + "mongodb://localhost/?ssl=true&tlsCRLFile=%s" + "&tlsCAFile=%s&serverSelectionTimeoutMS=1000" + ) + with self.assertRaises(ConnectionFailure): + await connected( + self.simple_client(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] + ) + + @async_client_context.require_tlsCertificateKeyFile + @async_client_context.require_server_resolvable + @ignore_deprecations + async def test_validation_with_system_ca_certs(self): + # Expects the server to be running with server.pem and ca.pem. + # + # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem + # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem + # --sslWeakCertificateValidation + # + self.patch_system_certs(CA_PEM) + with self.assertRaises(ConnectionFailure): + # Server cert is verified but hostname matching fails + await connected( + self.simple_client( + "server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials + ) # type: ignore[arg-type] + ) + + # Server cert is verified. Disable hostname matching. + await connected( + self.simple_client( + "server", + ssl=True, + tlsAllowInvalidHostnames=True, + serverSelectionTimeoutMS=1000, + **self.credentials, # type: ignore[arg-type] + ) + ) + + # Server cert and hostname are verified. + await connected( + self.simple_client( + "localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials + ) # type: ignore[arg-type] + ) + + # Server cert and hostname are verified. + await connected( + self.simple_client( + "mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=1000", + **self.credentials, # type: ignore[arg-type] + ) + ) + + def test_system_certs_config_error(self): + ctx = get_ssl_context(None, None, None, None, True, True, False) + if (sys.platform != "win32" and hasattr(ctx, "set_default_verify_paths")) or hasattr( + ctx, "load_default_certs" + ): + raise SkipTest("Can't test when system CA certificates are loadable.") + + have_certifi = ssl_support.HAVE_CERTIFI + have_wincertstore = ssl_support.HAVE_WINCERTSTORE + # Force the test regardless of environment. + ssl_support.HAVE_CERTIFI = False + ssl_support.HAVE_WINCERTSTORE = False + try: + with self.assertRaises(ConfigurationError): + self.simple_client("mongodb://localhost/?ssl=true") + finally: + ssl_support.HAVE_CERTIFI = have_certifi + ssl_support.HAVE_WINCERTSTORE = have_wincertstore + + def test_certifi_support(self): + if hasattr(ssl, "SSLContext"): + # SSLSocket doesn't provide ca_certs attribute on pythons + # with SSLContext and SSLContext provides no information + # about ca_certs. + raise SkipTest("Can't test when SSLContext available.") + if not ssl_support.HAVE_CERTIFI: + raise SkipTest("Need certifi to test certifi support.") + + have_wincertstore = ssl_support.HAVE_WINCERTSTORE + # Force the test on Windows, regardless of environment. + ssl_support.HAVE_WINCERTSTORE = False + try: + ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False) + ssl_sock = ctx.wrap_socket(socket.socket()) + self.assertEqual(ssl_sock.ca_certs, CA_PEM) + + ctx = get_ssl_context(None, None, None, None, False, False, False) + ssl_sock = ctx.wrap_socket(socket.socket()) + self.assertEqual(ssl_sock.ca_certs, ssl_support.certifi.where()) + finally: + ssl_support.HAVE_WINCERTSTORE = have_wincertstore + + def test_wincertstore(self): + if sys.platform != "win32": + raise SkipTest("Only valid on Windows.") + if hasattr(ssl, "SSLContext"): + # SSLSocket doesn't provide ca_certs attribute on pythons + # with SSLContext and SSLContext provides no information + # about ca_certs. + raise SkipTest("Can't test when SSLContext available.") + if not ssl_support.HAVE_WINCERTSTORE: + raise SkipTest("Need wincertstore to test wincertstore.") + + ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False) + ssl_sock = ctx.wrap_socket(socket.socket()) + self.assertEqual(ssl_sock.ca_certs, CA_PEM) + + ctx = get_ssl_context(None, None, None, None, False, False, False) + ssl_sock = ctx.wrap_socket(socket.socket()) + self.assertEqual(ssl_sock.ca_certs, ssl_support._WINCERTS.name) + + @async_client_context.require_auth + @async_client_context.require_tlsCertificateKeyFile + @ignore_deprecations + async def test_mongodb_x509_auth(self): + host, port = await async_client_context.host, await async_client_context.port + self.addAsyncCleanup(remove_all_users, async_client_context.client["$external"]) + + # Give x509 user all necessary privileges. + await async_client_context.create_user( + "$external", + MONGODB_X509_USERNAME, + roles=[ + {"role": "readWriteAnyDatabase", "db": "admin"}, + {"role": "userAdminAnyDatabase", "db": "admin"}, + ], + ) + + noauth = self.simple_client( + await async_client_context.pair, + ssl=True, + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CLIENT_PEM, + ) + + with self.assertRaises(OperationFailure): + await noauth.pymongo_test.test.find_one() + + listener = EventListener() + auth = self.simple_client( + await async_client_context.pair, + authMechanism="MONGODB-X509", + ssl=True, + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CLIENT_PEM, + event_listeners=[listener], + ) + + # No error + await auth.pymongo_test.test.find_one() + names = listener.started_command_names() + if async_client_context.version.at_least(4, 4, -1): + # Speculative auth skips the authenticate command. + self.assertEqual(names, ["find"]) + else: + self.assertEqual(names, ["authenticate", "find"]) + + uri = "mongodb://%s@%s:%d/?authMechanism=MONGODB-X509" % ( + quote_plus(MONGODB_X509_USERNAME), + host, + port, + ) + client = self.simple_client( + uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM + ) + # No error + await client.pymongo_test.test.find_one() + + uri = "mongodb://%s:%d/?authMechanism=MONGODB-X509" % (host, port) + client = self.simple_client( + uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM + ) + # No error + await client.pymongo_test.test.find_one() + # Auth should fail if username and certificate do not match + uri = "mongodb://%s@%s:%d/?authMechanism=MONGODB-X509" % ( + quote_plus("not the username"), + host, + port, + ) + + bad_client = self.simple_client( + uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM + ) + + with self.assertRaises(OperationFailure): + await bad_client.pymongo_test.test.find_one() + + bad_client = self.simple_client( + await async_client_context.pair, + username="not the username", + authMechanism="MONGODB-X509", + ssl=True, + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CLIENT_PEM, + ) + + with self.assertRaises(OperationFailure): + await bad_client.pymongo_test.test.find_one() + + # Invalid certificate (using CA certificate as client certificate) + uri = "mongodb://%s@%s:%d/?authMechanism=MONGODB-X509" % ( + quote_plus(MONGODB_X509_USERNAME), + host, + port, + ) + try: + await connected( + self.simple_client( + uri, + ssl=True, + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CA_PEM, + serverSelectionTimeoutMS=1000, + ) + ) + except (ConnectionFailure, ConfigurationError): + pass + else: + self.fail("Invalid certificate accepted.") + + @async_client_context.require_tlsCertificateKeyFile + @ignore_deprecations + async def test_connect_with_ca_bundle(self): + def remove(path): + try: + os.remove(path) + except OSError: + pass + + temp_ca_bundle = os.path.join(CERT_PATH, "trusted-ca-bundle.pem") + self.addCleanup(remove, temp_ca_bundle) + # Add the CA cert file to the bundle. + cat_files(temp_ca_bundle, CA_BUNDLE_PEM, CA_PEM) + async with self.simple_client( + "localhost", tls=True, tlsCertificateKeyFile=CLIENT_PEM, tlsCAFile=temp_ca_bundle + ) as client: + self.assertTrue(await client.admin.command("ping")) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_ssl.py b/test/test_ssl.py index 04db9b61a..7d6c3f7cd 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import socket import sys @@ -65,7 +66,13 @@ except ImportError: if HAVE_SSL: import ssl -CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates") +_IS_SYNC = True + +if _IS_SYNC: + CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "certificates") +else: + CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "certificates") + CLIENT_PEM = os.path.join(CERT_PATH, "client.pem") CLIENT_ENCRYPTED_PEM = os.path.join(CERT_PATH, "password_protected.pem") CA_PEM = os.path.join(CERT_PATH, "ca.pem") @@ -144,21 +151,18 @@ class TestSSL(IntegrationTest): ) coll.drop() coll.insert_one({"ssl": True}) - self.assertTrue(coll.find_one()["ssl"]) + self.assertTrue((coll.find_one())["ssl"]) coll.drop() - @classmethod @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") - def setUpClass(cls): - super().setUpClass() + def setUp(self): + super().setUp() # MongoClient should connect to the primary by default. - cls.saved_port = MongoClient.PORT + self.saved_port = MongoClient.PORT MongoClient.PORT = client_context.port - @classmethod - def tearDownClass(cls): - MongoClient.PORT = cls.saved_port - super().tearDownClass() + def tearDown(self): + MongoClient.PORT = self.saved_port @client_context.require_tls def test_simple_ssl(self): @@ -548,7 +552,6 @@ class TestSSL(IntegrationTest): tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM, ) - self.addCleanup(noauth.close) with self.assertRaises(OperationFailure): noauth.pymongo_test.test.find_one() @@ -562,7 +565,6 @@ class TestSSL(IntegrationTest): tlsCertificateKeyFile=CLIENT_PEM, event_listeners=[listener], ) - self.addCleanup(auth.close) # No error auth.pymongo_test.test.find_one() @@ -581,7 +583,6 @@ class TestSSL(IntegrationTest): client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) - self.addCleanup(client.close) # No error client.pymongo_test.test.find_one() @@ -589,7 +590,6 @@ class TestSSL(IntegrationTest): client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) - self.addCleanup(client.close) # No error client.pymongo_test.test.find_one() # Auth should fail if username and certificate do not match @@ -602,7 +602,6 @@ class TestSSL(IntegrationTest): bad_client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) - self.addCleanup(bad_client.close) with self.assertRaises(OperationFailure): bad_client.pymongo_test.test.find_one() @@ -615,7 +614,6 @@ class TestSSL(IntegrationTest): tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM, ) - self.addCleanup(bad_client.close) with self.assertRaises(OperationFailure): bad_client.pymongo_test.test.find_one() diff --git a/tools/synchro.py b/tools/synchro.py index 4ac5604f2..9f59448bb 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -229,6 +229,7 @@ converted_tests = [ "test_run_command.py", "test_session.py", "test_srv_polling.py", + "test_ssl.py", "test_transactions.py", "test_unified_format.py", "unified_format.py", From 097a853805f986bf02e75b5a92e5b12d470d570d Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Tue, 4 Feb 2025 09:52:26 -0800 Subject: [PATCH 20/28] PYTHON 5104 - Convert test.test_sessions_unified to async (#2123) Co-authored-by: Noah Stapp --- test/asynchronous/test_sessions_unified.py | 40 ++++++++++++++++++++++ test/test_sessions_unified.py | 9 ++++- tools/synchro.py | 1 + 3 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 test/asynchronous/test_sessions_unified.py diff --git a/test/asynchronous/test_sessions_unified.py b/test/asynchronous/test_sessions_unified.py new file mode 100644 index 000000000..b4cbac570 --- /dev/null +++ b/test/asynchronous/test_sessions_unified.py @@ -0,0 +1,40 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Sessions unified spec tests.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sessions") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sessions") + + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_sessions_unified.py b/test/test_sessions_unified.py index c51b4642e..3c80c70d3 100644 --- a/test/test_sessions_unified.py +++ b/test/test_sessions_unified.py @@ -17,14 +17,21 @@ from __future__ import annotations import os import sys +from pathlib import Path sys.path[0:0] = [""] from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sessions") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sessions") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sessions") + # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/tools/synchro.py b/tools/synchro.py index 9f59448bb..7b5892f27 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -228,6 +228,7 @@ converted_tests = [ "test_retryable_writes_unified.py", "test_run_command.py", "test_session.py", + "test_sessions_unified.py", "test_srv_polling.py", "test_ssl.py", "test_transactions.py", From 2c492155a6b2284746e53ededf65ac30800b5536 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Tue, 4 Feb 2025 10:30:35 -0800 Subject: [PATCH 21/28] PYTHON-5103 Convert test.test_server_selection_rtt to async (#2122) Co-authored-by: Noah Stapp --- .../asynchronous/test_server_selection_rtt.py | 77 +++++++++++++++++++ test/test_server_selection_rtt.py | 14 +++- tools/synchro.py | 1 + 3 files changed, 88 insertions(+), 4 deletions(-) create mode 100644 test/asynchronous/test_server_selection_rtt.py diff --git a/test/asynchronous/test_server_selection_rtt.py b/test/asynchronous/test_server_selection_rtt.py new file mode 100644 index 000000000..1f8f6bc7d --- /dev/null +++ b/test/asynchronous/test_server_selection_rtt.py @@ -0,0 +1,77 @@ +# Copyright 2015 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the topology module.""" +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous import AsyncPyMongoTestCase + +from pymongo.read_preferences import MovingAverage + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection/rtt") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection/rtt") + + +class TestAllScenarios(AsyncPyMongoTestCase): + pass + + +def create_test(scenario_def): + def run_scenario(self): + moving_average = MovingAverage() + + if scenario_def["avg_rtt_ms"] != "NULL": + moving_average.add_sample(scenario_def["avg_rtt_ms"]) + + if scenario_def["new_rtt_ms"] != "NULL": + moving_average.add_sample(scenario_def["new_rtt_ms"]) + + self.assertAlmostEqual(moving_average.get(), scenario_def["new_avg_rtt"]) + + return run_scenario + + +def create_tests(): + for dirpath, _, filenames in os.walk(TEST_PATH): + dirname = os.path.split(dirpath)[-1] + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as scenario_stream: + scenario_def = json.load(scenario_stream) + + # Construct test from scenario. + new_test = create_test(scenario_def) + test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}" + + new_test.__name__ = test_name + setattr(TestAllScenarios, new_test.__name__, new_test) + + +create_tests() + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_server_selection_rtt.py b/test/test_server_selection_rtt.py index a129af458..2aef36a58 100644 --- a/test/test_server_selection_rtt.py +++ b/test/test_server_selection_rtt.py @@ -18,18 +18,24 @@ from __future__ import annotations import json import os import sys +from pathlib import Path sys.path[0:0] = [""] -from test import unittest +from test import PyMongoTestCase, unittest from pymongo.read_preferences import MovingAverage +_IS_SYNC = True + # Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "server_selection/rtt") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection/rtt") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection/rtt") -class TestAllScenarios(unittest.TestCase): +class TestAllScenarios(PyMongoTestCase): pass @@ -49,7 +55,7 @@ def create_test(scenario_def): def create_tests(): - for dirpath, _, filenames in os.walk(_TEST_PATH): + for dirpath, _, filenames in os.walk(TEST_PATH): dirname = os.path.split(dirpath)[-1] for filename in filenames: diff --git a/tools/synchro.py b/tools/synchro.py index 7b5892f27..f9a4c5208 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -228,6 +228,7 @@ converted_tests = [ "test_retryable_writes_unified.py", "test_run_command.py", "test_session.py", + "test_server_selection_rtt.py", "test_sessions_unified.py", "test_srv_polling.py", "test_ssl.py", From 8ae9a0432a867e8d0d9dbb81d15830323cc3c7ae Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Tue, 4 Feb 2025 10:31:11 -0800 Subject: [PATCH 22/28] PYTHON-5102 Convert test.test_server_selection_logging to async (#2121) Co-authored-by: Noah Stapp --- .../test_server_selection_logging.py | 45 +++++++++++++++++++ test/test_server_selection_logging.py | 10 ++++- tools/synchro.py | 1 + 3 files changed, 54 insertions(+), 2 deletions(-) create mode 100644 test/asynchronous/test_server_selection_logging.py diff --git a/test/asynchronous/test_server_selection_logging.py b/test/asynchronous/test_server_selection_logging.py new file mode 100644 index 000000000..6b0975318 --- /dev/null +++ b/test/asynchronous/test_server_selection_logging.py @@ -0,0 +1,45 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the server selection logging unified format spec tests.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection_logging") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection_logging") + + +globals().update( + generate_test_classes( + TEST_PATH, + module=__name__, + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_server_selection_logging.py b/test/test_server_selection_logging.py index 2df749cb1..d53d8dc84 100644 --- a/test/test_server_selection_logging.py +++ b/test/test_server_selection_logging.py @@ -17,19 +17,25 @@ from __future__ import annotations import os import sys +from pathlib import Path sys.path[0:0] = [""] from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "server_selection_logging") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection_logging") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection_logging") globals().update( generate_test_classes( - _TEST_PATH, + TEST_PATH, module=__name__, ) ) diff --git a/tools/synchro.py b/tools/synchro.py index f9a4c5208..06dc708e0 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -227,6 +227,7 @@ converted_tests = [ "test_retryable_writes.py", "test_retryable_writes_unified.py", "test_run_command.py", + "test_server_selection_logging.py", "test_session.py", "test_server_selection_rtt.py", "test_sessions_unified.py", From 7a4150ac17859444eea38dd98682672c0d5935bb Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 5 Feb 2025 08:48:54 -0500 Subject: [PATCH 23/28] PYTHON-5080 - Convert test.test_examples to async (#2097) --- test/asynchronous/helpers.py | 17 +- test/asynchronous/test_examples.py | 1461 ++++++++++++++++++++++++ test/asynchronous/utils_spec_runner.py | 2 +- test/helpers.py | 17 +- test/test_examples.py | 161 +-- test/utils_spec_runner.py | 2 +- tools/synchro.py | 1 + 7 files changed, 1579 insertions(+), 82 deletions(-) create mode 100644 test/asynchronous/test_examples.py diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index 7758f281e..a35c71b10 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -381,14 +381,14 @@ else: class ConcurrentRunner(PARENT): - def __init__(self, name, *args, **kwargs): + def __init__(self, **kwargs): if _IS_SYNC: - super().__init__(*args, **kwargs) - self.name = name + super().__init__(**kwargs) + self.name = kwargs.get("name", "ConcurrentRunner") self.stopped = False self.task = None - if "target" in kwargs: - self.target = kwargs["target"] + self.target = kwargs.get("target", None) + self.args = kwargs.get("args", []) if not _IS_SYNC: @@ -403,6 +403,7 @@ class ConcurrentRunner(PARENT): return not self.stopped async def run(self): - if self.target: - await self.target() - self.stopped = True + try: + await self.target(*self.args) + finally: + self.stopped = True diff --git a/test/asynchronous/test_examples.py b/test/asynchronous/test_examples.py new file mode 100644 index 000000000..7fea9d41a --- /dev/null +++ b/test/asynchronous/test_examples.py @@ -0,0 +1,1461 @@ +# Copyright 2017 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MongoDB documentation examples in Python.""" +from __future__ import annotations + +import asyncio +import datetime +import functools +import sys +import threading +import time +from test.asynchronous.helpers import ConcurrentRunner + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils import async_wait_until + +import pymongo +from pymongo.asynchronous.helpers import anext +from pymongo.errors import ConnectionFailure, OperationFailure +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference +from pymongo.server_api import ServerApi +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class TestSampleShellCommands(AsyncIntegrationTest): + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.inventory.drop() + + async def asyncTearDown(self): + # Run after every test. + await self.db.inventory.drop() + await self.client.drop_database("pymongo_test") + + async def test_first_three_examples(self): + db = self.db + + # Start Example 1 + await db.inventory.insert_one( + { + "item": "canvas", + "qty": 100, + "tags": ["cotton"], + "size": {"h": 28, "w": 35.5, "uom": "cm"}, + } + ) + # End Example 1 + + self.assertEqual(await db.inventory.count_documents({}), 1) + + # Start Example 2 + cursor = db.inventory.find({"item": "canvas"}) + # End Example 2 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 3 + await db.inventory.insert_many( + [ + { + "item": "journal", + "qty": 25, + "tags": ["blank", "red"], + "size": {"h": 14, "w": 21, "uom": "cm"}, + }, + { + "item": "mat", + "qty": 85, + "tags": ["gray"], + "size": {"h": 27.9, "w": 35.5, "uom": "cm"}, + }, + { + "item": "mousepad", + "qty": 25, + "tags": ["gel", "blue"], + "size": {"h": 19, "w": 22.85, "uom": "cm"}, + }, + ] + ) + # End Example 3 + + self.assertEqual(await db.inventory.count_documents({}), 4) + + async def test_query_top_level_fields(self): + db = self.db + + # Start Example 6 + await db.inventory.insert_many( + [ + { + "item": "journal", + "qty": 25, + "size": {"h": 14, "w": 21, "uom": "cm"}, + "status": "A", + }, + { + "item": "notebook", + "qty": 50, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "A", + }, + { + "item": "paper", + "qty": 100, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "D", + }, + { + "item": "planner", + "qty": 75, + "size": {"h": 22.85, "w": 30, "uom": "cm"}, + "status": "D", + }, + { + "item": "postcard", + "qty": 45, + "size": {"h": 10, "w": 15.25, "uom": "cm"}, + "status": "A", + }, + ] + ) + # End Example 6 + + self.assertEqual(await db.inventory.count_documents({}), 5) + + # Start Example 7 + cursor = db.inventory.find({}) + # End Example 7 + + self.assertEqual(len(await cursor.to_list()), 5) + + # Start Example 9 + cursor = db.inventory.find({"status": "D"}) + # End Example 9 + + self.assertEqual(len(await cursor.to_list()), 2) + + # Start Example 10 + cursor = db.inventory.find({"status": {"$in": ["A", "D"]}}) + # End Example 10 + + self.assertEqual(len(await cursor.to_list()), 5) + + # Start Example 11 + cursor = db.inventory.find({"status": "A", "qty": {"$lt": 30}}) + # End Example 11 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 12 + cursor = db.inventory.find({"$or": [{"status": "A"}, {"qty": {"$lt": 30}}]}) + # End Example 12 + + self.assertEqual(len(await cursor.to_list()), 3) + + # Start Example 13 + cursor = db.inventory.find( + {"status": "A", "$or": [{"qty": {"$lt": 30}}, {"item": {"$regex": "^p"}}]} + ) + # End Example 13 + + self.assertEqual(len(await cursor.to_list()), 2) + + async def test_query_embedded_documents(self): + db = self.db + + # Start Example 14 + # Subdocument key order matters in a few of these examples so we have + # to use bson.son.SON instead of a Python dict. + from bson.son import SON + + await db.inventory.insert_many( + [ + { + "item": "journal", + "qty": 25, + "size": SON([("h", 14), ("w", 21), ("uom", "cm")]), + "status": "A", + }, + { + "item": "notebook", + "qty": 50, + "size": SON([("h", 8.5), ("w", 11), ("uom", "in")]), + "status": "A", + }, + { + "item": "paper", + "qty": 100, + "size": SON([("h", 8.5), ("w", 11), ("uom", "in")]), + "status": "D", + }, + { + "item": "planner", + "qty": 75, + "size": SON([("h", 22.85), ("w", 30), ("uom", "cm")]), + "status": "D", + }, + { + "item": "postcard", + "qty": 45, + "size": SON([("h", 10), ("w", 15.25), ("uom", "cm")]), + "status": "A", + }, + ] + ) + # End Example 14 + + # Start Example 15 + cursor = db.inventory.find({"size": SON([("h", 14), ("w", 21), ("uom", "cm")])}) + # End Example 15 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 16 + cursor = db.inventory.find({"size": SON([("w", 21), ("h", 14), ("uom", "cm")])}) + # End Example 16 + + self.assertEqual(len(await cursor.to_list()), 0) + + # Start Example 17 + cursor = db.inventory.find({"size.uom": "in"}) + # End Example 17 + + self.assertEqual(len(await cursor.to_list()), 2) + + # Start Example 18 + cursor = db.inventory.find({"size.h": {"$lt": 15}}) + # End Example 18 + + self.assertEqual(len(await cursor.to_list()), 4) + + # Start Example 19 + cursor = db.inventory.find({"size.h": {"$lt": 15}, "size.uom": "in", "status": "D"}) + # End Example 19 + + self.assertEqual(len(await cursor.to_list()), 1) + + async def test_query_arrays(self): + db = self.db + + # Start Example 20 + await db.inventory.insert_many( + [ + {"item": "journal", "qty": 25, "tags": ["blank", "red"], "dim_cm": [14, 21]}, + {"item": "notebook", "qty": 50, "tags": ["red", "blank"], "dim_cm": [14, 21]}, + { + "item": "paper", + "qty": 100, + "tags": ["red", "blank", "plain"], + "dim_cm": [14, 21], + }, + {"item": "planner", "qty": 75, "tags": ["blank", "red"], "dim_cm": [22.85, 30]}, + {"item": "postcard", "qty": 45, "tags": ["blue"], "dim_cm": [10, 15.25]}, + ] + ) + # End Example 20 + + # Start Example 21 + cursor = db.inventory.find({"tags": ["red", "blank"]}) + # End Example 21 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 22 + cursor = db.inventory.find({"tags": {"$all": ["red", "blank"]}}) + # End Example 22 + + self.assertEqual(len(await cursor.to_list()), 4) + + # Start Example 23 + cursor = db.inventory.find({"tags": "red"}) + # End Example 23 + + self.assertEqual(len(await cursor.to_list()), 4) + + # Start Example 24 + cursor = db.inventory.find({"dim_cm": {"$gt": 25}}) + # End Example 24 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 25 + cursor = db.inventory.find({"dim_cm": {"$gt": 15, "$lt": 20}}) + # End Example 25 + + self.assertEqual(len(await cursor.to_list()), 4) + + # Start Example 26 + cursor = db.inventory.find({"dim_cm": {"$elemMatch": {"$gt": 22, "$lt": 30}}}) + # End Example 26 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 27 + cursor = db.inventory.find({"dim_cm.1": {"$gt": 25}}) + # End Example 27 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 28 + cursor = db.inventory.find({"tags": {"$size": 3}}) + # End Example 28 + + self.assertEqual(len(await cursor.to_list()), 1) + + async def test_query_array_of_documents(self): + db = self.db + + # Start Example 29 + # Subdocument key order matters in a few of these examples so we have + # to use bson.son.SON instead of a Python dict. + from bson.son import SON + + await db.inventory.insert_many( + [ + { + "item": "journal", + "instock": [ + SON([("warehouse", "A"), ("qty", 5)]), + SON([("warehouse", "C"), ("qty", 15)]), + ], + }, + {"item": "notebook", "instock": [SON([("warehouse", "C"), ("qty", 5)])]}, + { + "item": "paper", + "instock": [ + SON([("warehouse", "A"), ("qty", 60)]), + SON([("warehouse", "B"), ("qty", 15)]), + ], + }, + { + "item": "planner", + "instock": [ + SON([("warehouse", "A"), ("qty", 40)]), + SON([("warehouse", "B"), ("qty", 5)]), + ], + }, + { + "item": "postcard", + "instock": [ + SON([("warehouse", "B"), ("qty", 15)]), + SON([("warehouse", "C"), ("qty", 35)]), + ], + }, + ] + ) + # End Example 29 + + # Start Example 30 + cursor = db.inventory.find({"instock": SON([("warehouse", "A"), ("qty", 5)])}) + # End Example 30 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 31 + cursor = db.inventory.find({"instock": SON([("qty", 5), ("warehouse", "A")])}) + # End Example 31 + + self.assertEqual(len(await cursor.to_list()), 0) + + # Start Example 32 + cursor = db.inventory.find({"instock.0.qty": {"$lte": 20}}) + # End Example 32 + + self.assertEqual(len(await cursor.to_list()), 3) + + # Start Example 33 + cursor = db.inventory.find({"instock.qty": {"$lte": 20}}) + # End Example 33 + + self.assertEqual(len(await cursor.to_list()), 5) + + # Start Example 34 + cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": 5, "warehouse": "A"}}}) + # End Example 34 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 35 + cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": {"$gt": 10, "$lte": 20}}}}) + # End Example 35 + + self.assertEqual(len(await cursor.to_list()), 3) + + # Start Example 36 + cursor = db.inventory.find({"instock.qty": {"$gt": 10, "$lte": 20}}) + # End Example 36 + + self.assertEqual(len(await cursor.to_list()), 4) + + # Start Example 37 + cursor = db.inventory.find({"instock.qty": 5, "instock.warehouse": "A"}) + # End Example 37 + + self.assertEqual(len(await cursor.to_list()), 2) + + async def test_query_null(self): + db = self.db + + # Start Example 38 + await db.inventory.insert_many([{"_id": 1, "item": None}, {"_id": 2}]) + # End Example 38 + + # Start Example 39 + cursor = db.inventory.find({"item": None}) + # End Example 39 + + self.assertEqual(len(await cursor.to_list()), 2) + + # Start Example 40 + cursor = db.inventory.find({"item": {"$type": 10}}) + # End Example 40 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 41 + cursor = db.inventory.find({"item": {"$exists": False}}) + # End Example 41 + + self.assertEqual(len(await cursor.to_list()), 1) + + async def test_projection(self): + db = self.db + + # Start Example 42 + await db.inventory.insert_many( + [ + { + "item": "journal", + "status": "A", + "size": {"h": 14, "w": 21, "uom": "cm"}, + "instock": [{"warehouse": "A", "qty": 5}], + }, + { + "item": "notebook", + "status": "A", + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "instock": [{"warehouse": "C", "qty": 5}], + }, + { + "item": "paper", + "status": "D", + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "instock": [{"warehouse": "A", "qty": 60}], + }, + { + "item": "planner", + "status": "D", + "size": {"h": 22.85, "w": 30, "uom": "cm"}, + "instock": [{"warehouse": "A", "qty": 40}], + }, + { + "item": "postcard", + "status": "A", + "size": {"h": 10, "w": 15.25, "uom": "cm"}, + "instock": [{"warehouse": "B", "qty": 15}, {"warehouse": "C", "qty": 35}], + }, + ] + ) + # End Example 42 + + # Start Example 43 + cursor = db.inventory.find({"status": "A"}) + # End Example 43 + + self.assertEqual(len(await cursor.to_list()), 3) + + # Start Example 44 + cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1}) + # End Example 44 + + async for doc in cursor: + self.assertTrue("_id" in doc) + self.assertTrue("item" in doc) + self.assertTrue("status" in doc) + self.assertFalse("size" in doc) + self.assertFalse("instock" in doc) + + # Start Example 45 + cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1, "_id": 0}) + # End Example 45 + + async for doc in cursor: + self.assertFalse("_id" in doc) + self.assertTrue("item" in doc) + self.assertTrue("status" in doc) + self.assertFalse("size" in doc) + self.assertFalse("instock" in doc) + + # Start Example 46 + cursor = db.inventory.find({"status": "A"}, {"status": 0, "instock": 0}) + # End Example 46 + + async for doc in cursor: + self.assertTrue("_id" in doc) + self.assertTrue("item" in doc) + self.assertFalse("status" in doc) + self.assertTrue("size" in doc) + self.assertFalse("instock" in doc) + + # Start Example 47 + cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1, "size.uom": 1}) + # End Example 47 + + async for doc in cursor: + self.assertTrue("_id" in doc) + self.assertTrue("item" in doc) + self.assertTrue("status" in doc) + self.assertTrue("size" in doc) + self.assertFalse("instock" in doc) + size = doc["size"] + self.assertTrue("uom" in size) + self.assertFalse("h" in size) + self.assertFalse("w" in size) + + # Start Example 48 + cursor = db.inventory.find({"status": "A"}, {"size.uom": 0}) + # End Example 48 + + async for doc in cursor: + self.assertTrue("_id" in doc) + self.assertTrue("item" in doc) + self.assertTrue("status" in doc) + self.assertTrue("size" in doc) + self.assertTrue("instock" in doc) + size = doc["size"] + self.assertFalse("uom" in size) + self.assertTrue("h" in size) + self.assertTrue("w" in size) + + # Start Example 49 + cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1, "instock.qty": 1}) + # End Example 49 + + async for doc in cursor: + self.assertTrue("_id" in doc) + self.assertTrue("item" in doc) + self.assertTrue("status" in doc) + self.assertFalse("size" in doc) + self.assertTrue("instock" in doc) + for subdoc in doc["instock"]: + self.assertFalse("warehouse" in subdoc) + self.assertTrue("qty" in subdoc) + + # Start Example 50 + cursor = db.inventory.find( + {"status": "A"}, {"item": 1, "status": 1, "instock": {"$slice": -1}} + ) + # End Example 50 + + async for doc in cursor: + self.assertTrue("_id" in doc) + self.assertTrue("item" in doc) + self.assertTrue("status" in doc) + self.assertFalse("size" in doc) + self.assertTrue("instock" in doc) + self.assertEqual(len(doc["instock"]), 1) + + async def test_update_and_replace(self): + db = self.db + + # Start Example 51 + await db.inventory.insert_many( + [ + { + "item": "canvas", + "qty": 100, + "size": {"h": 28, "w": 35.5, "uom": "cm"}, + "status": "A", + }, + { + "item": "journal", + "qty": 25, + "size": {"h": 14, "w": 21, "uom": "cm"}, + "status": "A", + }, + { + "item": "mat", + "qty": 85, + "size": {"h": 27.9, "w": 35.5, "uom": "cm"}, + "status": "A", + }, + { + "item": "mousepad", + "qty": 25, + "size": {"h": 19, "w": 22.85, "uom": "cm"}, + "status": "P", + }, + { + "item": "notebook", + "qty": 50, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "P", + }, + { + "item": "paper", + "qty": 100, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "D", + }, + { + "item": "planner", + "qty": 75, + "size": {"h": 22.85, "w": 30, "uom": "cm"}, + "status": "D", + }, + { + "item": "postcard", + "qty": 45, + "size": {"h": 10, "w": 15.25, "uom": "cm"}, + "status": "A", + }, + { + "item": "sketchbook", + "qty": 80, + "size": {"h": 14, "w": 21, "uom": "cm"}, + "status": "A", + }, + { + "item": "sketch pad", + "qty": 95, + "size": {"h": 22.85, "w": 30.5, "uom": "cm"}, + "status": "A", + }, + ] + ) + # End Example 51 + + # Start Example 52 + await db.inventory.update_one( + {"item": "paper"}, + {"$set": {"size.uom": "cm", "status": "P"}, "$currentDate": {"lastModified": True}}, + ) + # End Example 52 + + async for doc in db.inventory.find({"item": "paper"}): + self.assertEqual(doc["size"]["uom"], "cm") + self.assertEqual(doc["status"], "P") + self.assertTrue("lastModified" in doc) + + # Start Example 53 + await db.inventory.update_many( + {"qty": {"$lt": 50}}, + {"$set": {"size.uom": "in", "status": "P"}, "$currentDate": {"lastModified": True}}, + ) + # End Example 53 + + async for doc in db.inventory.find({"qty": {"$lt": 50}}): + self.assertEqual(doc["size"]["uom"], "in") + self.assertEqual(doc["status"], "P") + self.assertTrue("lastModified" in doc) + + # Start Example 54 + await db.inventory.replace_one( + {"item": "paper"}, + { + "item": "paper", + "instock": [{"warehouse": "A", "qty": 60}, {"warehouse": "B", "qty": 40}], + }, + ) + # End Example 54 + + async for doc in db.inventory.find({"item": "paper"}, {"_id": 0}): + self.assertEqual(len(doc.keys()), 2) + self.assertTrue("item" in doc) + self.assertTrue("instock" in doc) + self.assertEqual(len(doc["instock"]), 2) + + async def test_delete(self): + db = self.db + + # Start Example 55 + await db.inventory.insert_many( + [ + { + "item": "journal", + "qty": 25, + "size": {"h": 14, "w": 21, "uom": "cm"}, + "status": "A", + }, + { + "item": "notebook", + "qty": 50, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "P", + }, + { + "item": "paper", + "qty": 100, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "D", + }, + { + "item": "planner", + "qty": 75, + "size": {"h": 22.85, "w": 30, "uom": "cm"}, + "status": "D", + }, + { + "item": "postcard", + "qty": 45, + "size": {"h": 10, "w": 15.25, "uom": "cm"}, + "status": "A", + }, + ] + ) + # End Example 55 + + self.assertEqual(await db.inventory.count_documents({}), 5) + + # Start Example 57 + await db.inventory.delete_many({"status": "A"}) + # End Example 57 + + self.assertEqual(await db.inventory.count_documents({}), 3) + + # Start Example 58 + await db.inventory.delete_one({"status": "D"}) + # End Example 58 + + self.assertEqual(await db.inventory.count_documents({}), 2) + + # Start Example 56 + await db.inventory.delete_many({}) + # End Example 56 + + self.assertEqual(await db.inventory.count_documents({}), 0) + + @async_client_context.require_change_streams + async def test_change_streams(self): + db = self.db + done = False + + async def insert_docs(): + nonlocal done + while not done: + await db.inventory.insert_one({"username": "alice"}) + await db.inventory.delete_one({"username": "alice"}) + await asyncio.sleep(0.005) + + t = ConcurrentRunner(target=insert_docs) + await t.start() + + try: + # 1. The database for reactive, real-time applications + # Start Changestream Example 1 + cursor = await db.inventory.watch() + await anext(cursor) + # End Changestream Example 1 + await cursor.close() + + # Start Changestream Example 2 + cursor = await db.inventory.watch(full_document="updateLookup") + await anext(cursor) + # End Changestream Example 2 + await cursor.close() + + # Start Changestream Example 3 + resume_token = cursor.resume_token + cursor = await db.inventory.watch(resume_after=resume_token) + await anext(cursor) + # End Changestream Example 3 + await cursor.close() + + # Start Changestream Example 4 + pipeline = [ + {"$match": {"fullDocument.username": "alice"}}, + {"$addFields": {"newField": "this is an added field!"}}, + ] + cursor = await db.inventory.watch(pipeline=pipeline) + await anext(cursor) + # End Changestream Example 4 + await cursor.close() + finally: + done = True + await t.join() + + async def test_aggregate_examples(self): + db = self.db + + # Start Aggregation Example 1 + await db.sales.aggregate([{"$match": {"items.fruit": "banana"}}, {"$sort": {"date": 1}}]) + # End Aggregation Example 1 + + # Start Aggregation Example 2 + await db.sales.aggregate( + [ + {"$unwind": "$items"}, + {"$match": {"items.fruit": "banana"}}, + { + "$group": { + "_id": {"day": {"$dayOfWeek": "$date"}}, + "count": {"$sum": "$items.quantity"}, + } + }, + {"$project": {"dayOfWeek": "$_id.day", "numberSold": "$count", "_id": 0}}, + {"$sort": {"numberSold": 1}}, + ] + ) + # End Aggregation Example 2 + + # Start Aggregation Example 3 + await db.sales.aggregate( + [ + {"$unwind": "$items"}, + { + "$group": { + "_id": {"day": {"$dayOfWeek": "$date"}}, + "items_sold": {"$sum": "$items.quantity"}, + "revenue": {"$sum": {"$multiply": ["$items.quantity", "$items.price"]}}, + } + }, + { + "$project": { + "day": "$_id.day", + "revenue": 1, + "items_sold": 1, + "discount": { + "$cond": {"if": {"$lte": ["$revenue", 250]}, "then": 25, "else": 0} + }, + } + }, + ] + ) + # End Aggregation Example 3 + + # Start Aggregation Example 4 + await db.air_alliances.aggregate( + [ + { + "$lookup": { + "from": "air_airlines", + "let": {"constituents": "$airlines"}, + "pipeline": [{"$match": {"$expr": {"$in": ["$name", "$$constituents"]}}}], + "as": "airlines", + } + }, + { + "$project": { + "_id": 0, + "name": 1, + "airlines": { + "$filter": { + "input": "$airlines", + "as": "airline", + "cond": {"$eq": ["$$airline.country", "Canada"]}, + } + }, + } + }, + ] + ) + # End Aggregation Example 4 + + @async_client_context.require_version_min(4, 4) + async def test_aggregate_projection_example(self): + db = self.db + + # Start Aggregation Projection Example 1 + db.inventory.find( + {}, + { + "_id": 0, + "item": 1, + "status": { + "$switch": { + "branches": [ + {"case": {"$eq": ["$status", "A"]}, "then": "Available"}, + {"case": {"$eq": ["$status", "D"]}, "then": "Discontinued"}, + ], + "default": "No status found", + } + }, + "area": { + "$concat": [ + {"$toString": {"$multiply": ["$size.h", "$size.w"]}}, + " ", + "$size.uom", + ] + }, + "reportNumber": {"$literal": 1}, + }, + ) + + # End Aggregation Projection Example 1 + + async def test_commands(self): + db = self.db + await db.restaurants.insert_one({}) + + # Start runCommand Example 1 + await db.command("buildInfo") + # End runCommand Example 1 + + # Start runCommand Example 2 + await db.command("count", "restaurants") + # End runCommand Example 2 + + async def test_index_management(self): + db = self.db + + # Start Index Example 1 + await db.records.create_index("score") + # End Index Example 1 + + # Start Index Example 1 + await db.restaurants.create_index( + [("cuisine", pymongo.ASCENDING), ("name", pymongo.ASCENDING)], + partialFilterExpression={"rating": {"$gt": 5}}, + ) + # End Index Example 1 + + @async_client_context.require_replica_set + async def test_misc(self): + # Marketing examples + client = self.client + self.addAsyncCleanup(client.drop_database, "test") + self.addAsyncCleanup(client.drop_database, "my_database") + + # 2. Tunable consistency controls + collection = client.my_database.my_collection + async with client.start_session() as session: + await collection.insert_one({"_id": 1}, session=session) + await collection.update_one({"_id": 1}, {"$set": {"a": 1}}, session=session) + async for _doc in collection.find({}, session=session): + pass + + # 3. Exploiting the power of arrays + collection = client.test.array_updates_test + await collection.update_one( + {"_id": 1}, {"$set": {"a.$[i].b": 2}}, array_filters=[{"i.b": 0}] + ) + + +class TestTransactionExamples(AsyncIntegrationTest): + @async_client_context.require_transactions + async def test_transactions(self): + # Transaction examples + client = self.client + self.addAsyncCleanup(client.drop_database, "hr") + self.addAsyncCleanup(client.drop_database, "reporting") + + employees = client.hr.employees + events = client.reporting.events + await employees.insert_one({"employee": 3, "status": "Active"}) + await events.insert_one({"employee": 3, "status": {"new": "Active", "old": None}}) + + # Start Transactions Intro Example 1 + + async def update_employee_info(session): + employees_coll = session.client.hr.employees + events_coll = session.client.reporting.events + + async with await session.start_transaction( + read_concern=ReadConcern("snapshot"), write_concern=WriteConcern(w="majority") + ): + await employees_coll.update_one( + {"employee": 3}, {"$set": {"status": "Inactive"}}, session=session + ) + await events_coll.insert_one( + {"employee": 3, "status": {"new": "Inactive", "old": "Active"}}, session=session + ) + + while True: + try: + # Commit uses write concern set at transaction start. + await session.commit_transaction() + print("Transaction committed.") + break + except (ConnectionFailure, OperationFailure) as exc: + # Can retry commit + if exc.has_error_label("UnknownTransactionCommitResult"): + print("UnknownTransactionCommitResult, retrying commit operation ...") + continue + else: + print("Error during commit ...") + raise + + # End Transactions Intro Example 1 + + async with client.start_session() as session: + await update_employee_info(session) + + employee = await employees.find_one({"employee": 3}) + assert employee is not None + self.assertIsNotNone(employee) + self.assertEqual(employee["status"], "Inactive") + + # Start Transactions Retry Example 1 + async def run_transaction_with_retry(txn_func, session): + while True: + try: + await txn_func(session) # performs transaction + break + except (ConnectionFailure, OperationFailure) as exc: + print("Transaction aborted. Caught exception during transaction.") + + # If transient error, retry the whole transaction + if exc.has_error_label("TransientTransactionError"): + print("TransientTransactionError, retrying transaction ...") + continue + else: + raise + + # End Transactions Retry Example 1 + + async with client.start_session() as session: + await run_transaction_with_retry(update_employee_info, session) + + employee = await employees.find_one({"employee": 3}) + assert employee is not None + self.assertIsNotNone(employee) + self.assertEqual(employee["status"], "Inactive") + + # Start Transactions Retry Example 2 + async def commit_with_retry(session): + while True: + try: + # Commit uses write concern set at transaction start. + await session.commit_transaction() + print("Transaction committed.") + break + except (ConnectionFailure, OperationFailure) as exc: + # Can retry commit + if exc.has_error_label("UnknownTransactionCommitResult"): + print("UnknownTransactionCommitResult, retrying commit operation ...") + continue + else: + print("Error during commit ...") + raise + + # End Transactions Retry Example 2 + + # Test commit_with_retry from the previous examples + async def _insert_employee_retry_commit(session): + async with await session.start_transaction(): + await employees.insert_one({"employee": 4, "status": "Active"}, session=session) + await events.insert_one( + {"employee": 4, "status": {"new": "Active", "old": None}}, session=session + ) + + await commit_with_retry(session) + + async with client.start_session() as session: + await run_transaction_with_retry(_insert_employee_retry_commit, session) + + employee = await employees.find_one({"employee": 4}) + assert employee is not None + self.assertIsNotNone(employee) + self.assertEqual(employee["status"], "Active") + + # Start Transactions Retry Example 3 + + async def run_transaction_with_retry(txn_func, session): + while True: + try: + await txn_func(session) # performs transaction + break + except (ConnectionFailure, OperationFailure) as exc: + # If transient error, retry the whole transaction + if exc.has_error_label("TransientTransactionError"): + print("TransientTransactionError, retrying transaction ...") + continue + else: + raise + + async def commit_with_retry(session): + while True: + try: + # Commit uses write concern set at transaction start. + await session.commit_transaction() + print("Transaction committed.") + break + except (ConnectionFailure, OperationFailure) as exc: + # Can retry commit + if exc.has_error_label("UnknownTransactionCommitResult"): + print("UnknownTransactionCommitResult, retrying commit operation ...") + continue + else: + print("Error during commit ...") + raise + + # Updates two collections in a transactions + + async def update_employee_info(session): + employees_coll = session.client.hr.employees + events_coll = session.client.reporting.events + + async with await session.start_transaction( + read_concern=ReadConcern("snapshot"), + write_concern=WriteConcern(w="majority"), + read_preference=ReadPreference.PRIMARY, + ): + await employees_coll.update_one( + {"employee": 3}, {"$set": {"status": "Inactive"}}, session=session + ) + await events_coll.insert_one( + {"employee": 3, "status": {"new": "Inactive", "old": "Active"}}, session=session + ) + + await commit_with_retry(session) + + # Start a session. + async with client.start_session() as session: + try: + await run_transaction_with_retry(update_employee_info, session) + except Exception: + # Do something with error. + raise + + # End Transactions Retry Example 3 + + employee = await employees.find_one({"employee": 3}) + assert employee is not None + self.assertIsNotNone(employee) + self.assertEqual(employee["status"], "Inactive") + + async def MongoClient(_): + return await self.async_rs_client() + + uriString = None + + # Start Transactions withTxn API Example 1 + + # For a replica set, include the replica set name and a seedlist of the members in the URI string; e.g. + # uriString = 'mongodb://mongodb0.example.com:27017,mongodb1.example.com:27017/?replicaSet=myRepl' + # For a sharded cluster, connect to the mongos instances; e.g. + # uriString = 'mongodb://mongos0.example.com:27017,mongos1.example.com:27017/' + + client = await MongoClient(uriString) + wc_majority = WriteConcern("majority", wtimeout=1000) + + # Prereq: Create collections. + await client.get_database("mydb1", write_concern=wc_majority).foo.insert_one({"abc": 0}) + await client.get_database("mydb2", write_concern=wc_majority).bar.insert_one({"xyz": 0}) + + # Step 1: Define the callback that specifies the sequence of operations to perform inside the transactions. + async def callback(session): + collection_one = session.client.mydb1.foo + collection_two = session.client.mydb2.bar + + # Important:: You must pass the session to the operations. + await collection_one.insert_one({"abc": 1}, session=session) + await collection_two.insert_one({"xyz": 999}, session=session) + + # Step 2: Start a client session. + async with client.start_session() as session: + # Step 3: Use with_transaction to start a transaction, execute the callback, and commit (or abort on error). + await session.with_transaction( + callback, + read_concern=ReadConcern("local"), + write_concern=wc_majority, + read_preference=ReadPreference.PRIMARY, + ) + + # End Transactions withTxn API Example 1 + + +class TestCausalConsistencyExamples(AsyncIntegrationTest): + @async_client_context.require_secondaries_count(1) + @async_client_context.require_no_mmap + async def test_causal_consistency(self): + # Causal consistency examples + client = self.client + self.addAsyncCleanup(client.drop_database, "test") + await client.test.drop_collection("items") + await client.test.items.insert_one( + {"sku": "111", "name": "Peanuts", "start": datetime.datetime.today()} + ) + + # Start Causal Consistency Example 1 + async with client.start_session(causal_consistency=True) as s1: + current_date = datetime.datetime.today() + items = client.get_database( + "test", + read_concern=ReadConcern("majority"), + write_concern=WriteConcern("majority", wtimeout=1000), + ).items + await items.update_one( + {"sku": "111", "end": None}, {"$set": {"end": current_date}}, session=s1 + ) + await items.insert_one( + {"sku": "nuts-111", "name": "Pecans", "start": current_date}, session=s1 + ) + # End Causal Consistency Example 1 + + assert s1.cluster_time is not None + assert s1.operation_time is not None + + # Start Causal Consistency Example 2 + async with client.start_session(causal_consistency=True) as s2: + s2.advance_cluster_time(s1.cluster_time) + s2.advance_operation_time(s1.operation_time) + + items = client.get_database( + "test", + read_preference=ReadPreference.SECONDARY, + read_concern=ReadConcern("majority"), + write_concern=WriteConcern("majority", wtimeout=1000), + ).items + async for item in items.find({"end": None}, session=s2): + print(item) + # End Causal Consistency Example 2 + + +class TestVersionedApiExamples(AsyncIntegrationTest): + @async_client_context.require_version_min(4, 7) + async def test_versioned_api(self): + # Versioned API examples + async def MongoClient(_, server_api): + return await self.async_rs_client(server_api=server_api, connect=False) + + uri = None + + # Start Versioned API Example 1 + from pymongo.server_api import ServerApi + + await MongoClient(uri, server_api=ServerApi("1")) + # End Versioned API Example 1 + + # Start Versioned API Example 2 + await MongoClient(uri, server_api=ServerApi("1", strict=True)) + # End Versioned API Example 2 + + # Start Versioned API Example 3 + await MongoClient(uri, server_api=ServerApi("1", strict=False)) + # End Versioned API Example 3 + + # Start Versioned API Example 4 + await MongoClient(uri, server_api=ServerApi("1", deprecation_errors=True)) + # End Versioned API Example 4 + + @unittest.skip("PYTHON-3167 count has been added to API version 1") + @async_client_context.require_version_min(4, 7) + async def test_versioned_api_migration(self): + # SERVER-58785 + if await async_client_context.is_topology_type( + ["sharded"] + ) and not async_client_context.version.at_least(5, 0, 2): + self.skipTest("This test needs MongoDB 5.0.2 or newer") + + client = await self.async_rs_client(server_api=ServerApi("1", strict=True)) + await client.db.sales.drop() + + # Start Versioned API Example 5 + def strptime(s): + return datetime.datetime.strptime(s, "%Y-%m-%dT%H:%M:%SZ") + + await client.db.sales.insert_many( + [ + { + "_id": 1, + "item": "abc", + "price": 10, + "quantity": 2, + "date": strptime("2021-01-01T08:00:00Z"), + }, + { + "_id": 2, + "item": "jkl", + "price": 20, + "quantity": 1, + "date": strptime("2021-02-03T09:00:00Z"), + }, + { + "_id": 3, + "item": "xyz", + "price": 5, + "quantity": 5, + "date": strptime("2021-02-03T09:05:00Z"), + }, + { + "_id": 4, + "item": "abc", + "price": 10, + "quantity": 10, + "date": strptime("2021-02-15T08:00:00Z"), + }, + { + "_id": 5, + "item": "xyz", + "price": 5, + "quantity": 10, + "date": strptime("2021-02-15T09:05:00Z"), + }, + { + "_id": 6, + "item": "xyz", + "price": 5, + "quantity": 5, + "date": strptime("2021-02-15T12:05:10Z"), + }, + { + "_id": 7, + "item": "xyz", + "price": 5, + "quantity": 10, + "date": strptime("2021-02-15T14:12:12Z"), + }, + { + "_id": 8, + "item": "abc", + "price": 10, + "quantity": 5, + "date": strptime("2021-03-16T20:20:13Z"), + }, + ] + ) + # End Versioned API Example 5 + + with self.assertRaisesRegex( + OperationFailure, + "Provided apiStrict:true, but the command count is not in API Version 1", + ): + await client.db.command("count", "sales", query={}) + # Start Versioned API Example 6 + # pymongo.errors.OperationFailure: Provided apiStrict:true, but the command count is not in API Version 1, full error: {'ok': 0.0, 'errmsg': 'Provided apiStrict:true, but the command count is not in API Version 1', 'code': 323, 'codeName': 'APIStrictError'} + # End Versioned API Example 6 + + # Start Versioned API Example 7 + await client.db.sales.count_documents({}) + # End Versioned API Example 7 + + # Start Versioned API Example 8 + # 8 + # End Versioned API Example 8 + + +class TestSnapshotQueryExamples(AsyncIntegrationTest): + @async_client_context.require_version_min(5, 0) + async def test_snapshot_query(self): + client = self.client + + if not await async_client_context.is_topology_type(["replicaset", "sharded"]): + self.skipTest("Must be a sharded or replicaset") + + self.addAsyncCleanup(client.drop_database, "pets") + db = client.pets + await db.drop_collection("cats") + await db.drop_collection("dogs") + await db.cats.insert_one( + {"name": "Whiskers", "color": "white", "age": 10, "adoptable": True} + ) + await db.dogs.insert_one( + {"name": "Pebbles", "color": "Brown", "age": 10, "adoptable": True} + ) + + async def predicate_one(): + return await self.check_for_snapshot(db.cats) + + async def predicate_two(): + return await self.check_for_snapshot(db.dogs) + + await async_wait_until(predicate_two, "success") + await async_wait_until(predicate_one, "success") + + # Start Snapshot Query Example 1 + + db = client.pets + async with client.start_session(snapshot=True) as s: + adoptablePetsCount = ( + await ( + await db.cats.aggregate( + [{"$match": {"adoptable": True}}, {"$count": "adoptableCatsCount"}], + session=s, + ) + ).next() + )["adoptableCatsCount"] + + adoptablePetsCount += ( + await ( + await db.dogs.aggregate( + [{"$match": {"adoptable": True}}, {"$count": "adoptableDogsCount"}], + session=s, + ) + ).next() + )["adoptableDogsCount"] + + print(adoptablePetsCount) + + # End Snapshot Query Example 1 + db = client.retail + self.addAsyncCleanup(client.drop_database, "retail") + await db.drop_collection("sales") + + saleDate = datetime.datetime.now() + await db.sales.insert_one({"shoeType": "boot", "price": 30, "saleDate": saleDate}) + + async def predicate_three(): + return await self.check_for_snapshot(db.sales) + + await async_wait_until(predicate_three, "success") + + # Start Snapshot Query Example 2 + db = client.retail + async with client.start_session(snapshot=True) as s: + _ = ( + await ( + await db.sales.aggregate( + [ + { + "$match": { + "$expr": { + "$gt": [ + "$saleDate", + { + "$dateSubtract": { + "startDate": "$$NOW", + "unit": "day", + "amount": 1, + } + }, + ] + } + } + }, + {"$count": "totalDailySales"}, + ], + session=s, + ) + ).next() + )["totalDailySales"] + + # End Snapshot Query Example 2 + + async def check_for_snapshot(self, collection): + """Wait for snapshot reads to become available to prevent this error: + [246:SnapshotUnavailable]: Unable to read from a snapshot due to pending collection catalog changes; please retry the operation. Snapshot timestamp is Timestamp(1646666892, 4). Collection minimum is Timestamp(1646666892, 5) (on localhost:27017, modern retry, attempt 1) + From https://github.com/mongodb/mongo-ruby-driver/commit/7c4117b58e3d12e237f7536f7521e18fc15f79ac + """ + async with self.client.start_session(snapshot=True) as s: + try: + if await collection.find_one(session=s): + return True + return False + except OperationFailure as e: + # Retry them as the server demands... + if e.code == 246: # SnapshotUnavailable + return False + raise + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index d10337431..d433f1a7e 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -58,7 +58,7 @@ _IS_SYNC = False class SpecRunnerTask(ConcurrentRunner): def __init__(self, name): - super().__init__(name) + super().__init__(name=name) self.exc = None self.daemon = True self.cond = _async_create_condition(_async_create_lock()) diff --git a/test/helpers.py b/test/helpers.py index bd9e23bba..705843efc 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -381,14 +381,14 @@ else: class ConcurrentRunner(PARENT): - def __init__(self, name, *args, **kwargs): + def __init__(self, **kwargs): if _IS_SYNC: - super().__init__(*args, **kwargs) - self.name = name + super().__init__(**kwargs) + self.name = kwargs.get("name", "ConcurrentRunner") self.stopped = False self.task = None - if "target" in kwargs: - self.target = kwargs["target"] + self.target = kwargs.get("target", None) + self.args = kwargs.get("args", []) if not _IS_SYNC: @@ -403,6 +403,7 @@ class ConcurrentRunner(PARENT): return not self.stopped def run(self): - if self.target: - self.target() - self.stopped = True + try: + self.target(*self.args) + finally: + self.stopped = True diff --git a/test/test_examples.py b/test/test_examples.py index 7f98226e7..9bcc27624 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -15,9 +15,13 @@ """MongoDB documentation examples in Python.""" from __future__ import annotations +import asyncio import datetime +import functools import sys import threading +import time +from test.helpers import ConcurrentRunner sys.path[0:0] = [""] @@ -29,8 +33,11 @@ from pymongo.errors import ConnectionFailure, OperationFailure from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.server_api import ServerApi +from pymongo.synchronous.helpers import next from pymongo.write_concern import WriteConcern +_IS_SYNC = True + class TestSampleShellCommands(IntegrationTest): def setUp(self): @@ -62,7 +69,7 @@ class TestSampleShellCommands(IntegrationTest): cursor = db.inventory.find({"item": "canvas"}) # End Example 2 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 3 db.inventory.insert_many( @@ -137,31 +144,31 @@ class TestSampleShellCommands(IntegrationTest): cursor = db.inventory.find({}) # End Example 7 - self.assertEqual(len(list(cursor)), 5) + self.assertEqual(len(cursor.to_list()), 5) # Start Example 9 cursor = db.inventory.find({"status": "D"}) # End Example 9 - self.assertEqual(len(list(cursor)), 2) + self.assertEqual(len(cursor.to_list()), 2) # Start Example 10 cursor = db.inventory.find({"status": {"$in": ["A", "D"]}}) # End Example 10 - self.assertEqual(len(list(cursor)), 5) + self.assertEqual(len(cursor.to_list()), 5) # Start Example 11 cursor = db.inventory.find({"status": "A", "qty": {"$lt": 30}}) # End Example 11 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 12 cursor = db.inventory.find({"$or": [{"status": "A"}, {"qty": {"$lt": 30}}]}) # End Example 12 - self.assertEqual(len(list(cursor)), 3) + self.assertEqual(len(cursor.to_list()), 3) # Start Example 13 cursor = db.inventory.find( @@ -169,7 +176,7 @@ class TestSampleShellCommands(IntegrationTest): ) # End Example 13 - self.assertEqual(len(list(cursor)), 2) + self.assertEqual(len(cursor.to_list()), 2) def test_query_embedded_documents(self): db = self.db @@ -219,31 +226,31 @@ class TestSampleShellCommands(IntegrationTest): cursor = db.inventory.find({"size": SON([("h", 14), ("w", 21), ("uom", "cm")])}) # End Example 15 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 16 cursor = db.inventory.find({"size": SON([("w", 21), ("h", 14), ("uom", "cm")])}) # End Example 16 - self.assertEqual(len(list(cursor)), 0) + self.assertEqual(len(cursor.to_list()), 0) # Start Example 17 cursor = db.inventory.find({"size.uom": "in"}) # End Example 17 - self.assertEqual(len(list(cursor)), 2) + self.assertEqual(len(cursor.to_list()), 2) # Start Example 18 cursor = db.inventory.find({"size.h": {"$lt": 15}}) # End Example 18 - self.assertEqual(len(list(cursor)), 4) + self.assertEqual(len(cursor.to_list()), 4) # Start Example 19 cursor = db.inventory.find({"size.h": {"$lt": 15}, "size.uom": "in", "status": "D"}) # End Example 19 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) def test_query_arrays(self): db = self.db @@ -269,49 +276,49 @@ class TestSampleShellCommands(IntegrationTest): cursor = db.inventory.find({"tags": ["red", "blank"]}) # End Example 21 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 22 cursor = db.inventory.find({"tags": {"$all": ["red", "blank"]}}) # End Example 22 - self.assertEqual(len(list(cursor)), 4) + self.assertEqual(len(cursor.to_list()), 4) # Start Example 23 cursor = db.inventory.find({"tags": "red"}) # End Example 23 - self.assertEqual(len(list(cursor)), 4) + self.assertEqual(len(cursor.to_list()), 4) # Start Example 24 cursor = db.inventory.find({"dim_cm": {"$gt": 25}}) # End Example 24 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 25 cursor = db.inventory.find({"dim_cm": {"$gt": 15, "$lt": 20}}) # End Example 25 - self.assertEqual(len(list(cursor)), 4) + self.assertEqual(len(cursor.to_list()), 4) # Start Example 26 cursor = db.inventory.find({"dim_cm": {"$elemMatch": {"$gt": 22, "$lt": 30}}}) # End Example 26 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 27 cursor = db.inventory.find({"dim_cm.1": {"$gt": 25}}) # End Example 27 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 28 cursor = db.inventory.find({"tags": {"$size": 3}}) # End Example 28 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) def test_query_array_of_documents(self): db = self.db @@ -360,49 +367,49 @@ class TestSampleShellCommands(IntegrationTest): cursor = db.inventory.find({"instock": SON([("warehouse", "A"), ("qty", 5)])}) # End Example 30 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 31 cursor = db.inventory.find({"instock": SON([("qty", 5), ("warehouse", "A")])}) # End Example 31 - self.assertEqual(len(list(cursor)), 0) + self.assertEqual(len(cursor.to_list()), 0) # Start Example 32 cursor = db.inventory.find({"instock.0.qty": {"$lte": 20}}) # End Example 32 - self.assertEqual(len(list(cursor)), 3) + self.assertEqual(len(cursor.to_list()), 3) # Start Example 33 cursor = db.inventory.find({"instock.qty": {"$lte": 20}}) # End Example 33 - self.assertEqual(len(list(cursor)), 5) + self.assertEqual(len(cursor.to_list()), 5) # Start Example 34 cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": 5, "warehouse": "A"}}}) # End Example 34 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 35 cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": {"$gt": 10, "$lte": 20}}}}) # End Example 35 - self.assertEqual(len(list(cursor)), 3) + self.assertEqual(len(cursor.to_list()), 3) # Start Example 36 cursor = db.inventory.find({"instock.qty": {"$gt": 10, "$lte": 20}}) # End Example 36 - self.assertEqual(len(list(cursor)), 4) + self.assertEqual(len(cursor.to_list()), 4) # Start Example 37 cursor = db.inventory.find({"instock.qty": 5, "instock.warehouse": "A"}) # End Example 37 - self.assertEqual(len(list(cursor)), 2) + self.assertEqual(len(cursor.to_list()), 2) def test_query_null(self): db = self.db @@ -415,19 +422,19 @@ class TestSampleShellCommands(IntegrationTest): cursor = db.inventory.find({"item": None}) # End Example 39 - self.assertEqual(len(list(cursor)), 2) + self.assertEqual(len(cursor.to_list()), 2) # Start Example 40 cursor = db.inventory.find({"item": {"$type": 10}}) # End Example 40 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 41 cursor = db.inventory.find({"item": {"$exists": False}}) # End Example 41 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) def test_projection(self): db = self.db @@ -473,7 +480,7 @@ class TestSampleShellCommands(IntegrationTest): cursor = db.inventory.find({"status": "A"}) # End Example 43 - self.assertEqual(len(list(cursor)), 3) + self.assertEqual(len(cursor.to_list()), 3) # Start Example 44 cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1}) @@ -746,8 +753,9 @@ class TestSampleShellCommands(IntegrationTest): while not done: db.inventory.insert_one({"username": "alice"}) db.inventory.delete_one({"username": "alice"}) + time.sleep(0.005) - t = threading.Thread(target=insert_docs) + t = ConcurrentRunner(target=insert_docs) t.start() try: @@ -1347,20 +1355,37 @@ class TestSnapshotQueryExamples(IntegrationTest): db.drop_collection("dogs") db.cats.insert_one({"name": "Whiskers", "color": "white", "age": 10, "adoptable": True}) db.dogs.insert_one({"name": "Pebbles", "color": "Brown", "age": 10, "adoptable": True}) - wait_until(lambda: self.check_for_snapshot(db.cats), "success") - wait_until(lambda: self.check_for_snapshot(db.dogs), "success") + + def predicate_one(): + return self.check_for_snapshot(db.cats) + + def predicate_two(): + return self.check_for_snapshot(db.dogs) + + wait_until(predicate_two, "success") + wait_until(predicate_one, "success") # Start Snapshot Query Example 1 db = client.pets with client.start_session(snapshot=True) as s: - adoptablePetsCount = db.cats.aggregate( - [{"$match": {"adoptable": True}}, {"$count": "adoptableCatsCount"}], session=s - ).next()["adoptableCatsCount"] + adoptablePetsCount = ( + ( + db.cats.aggregate( + [{"$match": {"adoptable": True}}, {"$count": "adoptableCatsCount"}], + session=s, + ) + ).next() + )["adoptableCatsCount"] - adoptablePetsCount += db.dogs.aggregate( - [{"$match": {"adoptable": True}}, {"$count": "adoptableDogsCount"}], session=s - ).next()["adoptableDogsCount"] + adoptablePetsCount += ( + ( + db.dogs.aggregate( + [{"$match": {"adoptable": True}}, {"$count": "adoptableDogsCount"}], + session=s, + ) + ).next() + )["adoptableDogsCount"] print(adoptablePetsCount) @@ -1371,33 +1396,41 @@ class TestSnapshotQueryExamples(IntegrationTest): saleDate = datetime.datetime.now() db.sales.insert_one({"shoeType": "boot", "price": 30, "saleDate": saleDate}) - wait_until(lambda: self.check_for_snapshot(db.sales), "success") + + def predicate_three(): + return self.check_for_snapshot(db.sales) + + wait_until(predicate_three, "success") # Start Snapshot Query Example 2 db = client.retail with client.start_session(snapshot=True) as s: - db.sales.aggregate( - [ - { - "$match": { - "$expr": { - "$gt": [ - "$saleDate", - { - "$dateSubtract": { - "startDate": "$$NOW", - "unit": "day", - "amount": 1, - } - }, - ] - } - } - }, - {"$count": "totalDailySales"}, - ], - session=s, - ).next()["totalDailySales"] + _ = ( + ( + db.sales.aggregate( + [ + { + "$match": { + "$expr": { + "$gt": [ + "$saleDate", + { + "$dateSubtract": { + "startDate": "$$NOW", + "unit": "day", + "amount": 1, + } + }, + ] + } + } + }, + {"$count": "totalDailySales"}, + ], + session=s, + ) + ).next() + )["totalDailySales"] # End Snapshot Query Example 2 diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 6a62112af..98949431d 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -58,7 +58,7 @@ _IS_SYNC = True class SpecRunnerThread(ConcurrentRunner): def __init__(self, name): - super().__init__(name) + super().__init__(name=name) self.exc = None self.daemon = True self.cond = _create_condition(_create_lock()) diff --git a/tools/synchro.py b/tools/synchro.py index 06dc708e0..ffbea4e53 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -210,6 +210,7 @@ converted_tests = [ "test_data_lake.py", "test_dns.py", "test_encryption.py", + "test_examples.py", "test_heartbeat_monitoring.py", "test_index_management.py", "test_grid_file.py", From 02d6cc9cfdcac8f52a140c8549b19d5edc34d8f1 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 5 Feb 2025 12:10:44 -0500 Subject: [PATCH 24/28] PYTHON-5107 - Convert test.test_streaming_protocol to async (#2126) --- test/asynchronous/test_streaming_protocol.py | 228 +++++++++++++++++++ test/test_streaming_protocol.py | 8 +- tools/synchro.py | 1 + 3 files changed, 232 insertions(+), 5 deletions(-) create mode 100644 test/asynchronous/test_streaming_protocol.py diff --git a/test/asynchronous/test_streaming_protocol.py b/test/asynchronous/test_streaming_protocol.py new file mode 100644 index 000000000..fd890d29f --- /dev/null +++ b/test/asynchronous/test_streaming_protocol.py @@ -0,0 +1,228 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the database module.""" +from __future__ import annotations + +import sys +import time + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils import ( + HeartbeatEventListener, + ServerEventListener, + async_wait_until, +) + +from pymongo import monitoring +from pymongo.hello import HelloCompat + +_IS_SYNC = False + + +class TestStreamingProtocol(AsyncIntegrationTest): + @async_client_context.require_failCommand_appName + async def test_failCommand_streaming(self): + listener = ServerEventListener() + hb_listener = HeartbeatEventListener() + client = await self.async_rs_or_single_client( + event_listeners=[listener, hb_listener], + heartbeatFrequencyMS=500, + appName="failingHeartbeatTest", + ) + # Force a connection. + await client.admin.command("ping") + address = await client.address + listener.reset() + + fail_hello = { + "configureFailPoint": "failCommand", + "mode": {"times": 4}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "closeConnection": False, + "errorCode": 10107, + "appName": "failingHeartbeatTest", + }, + } + async with self.fail_point(fail_hello): + + def _marked_unknown(event): + return ( + event.server_address == address + and not event.new_description.is_server_type_known + ) + + def _discovered_node(event): + return ( + event.server_address == address + and not event.previous_description.is_server_type_known + and event.new_description.is_server_type_known + ) + + def marked_unknown(): + return len(listener.matching(_marked_unknown)) >= 1 + + def rediscovered(): + return len(listener.matching(_discovered_node)) >= 1 + + # Topology events are not published synchronously + await async_wait_until(marked_unknown, "mark node unknown") + await async_wait_until(rediscovered, "rediscover node") + + # Server should be selectable. + await client.admin.command("ping") + + @async_client_context.require_failCommand_appName + async def test_streaming_rtt(self): + listener = ServerEventListener() + hb_listener = HeartbeatEventListener() + # On Windows, RTT can actually be 0.0 because time.time() only has + # 1-15 millisecond resolution. We need to delay the initial hello + # to ensure that RTT is never zero. + name = "streamingRttTest" + delay_hello: dict = { + "configureFailPoint": "failCommand", + "mode": {"times": 1000}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "blockConnection": True, + "blockTimeMS": 20, + # This can be uncommented after SERVER-49220 is fixed. + # 'appName': name, + }, + } + async with self.fail_point(delay_hello): + client = await self.async_rs_or_single_client( + event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name + ) + # Force a connection. + await client.admin.command("ping") + address = await client.address + + delay_hello["data"]["blockTimeMS"] = 500 + delay_hello["data"]["appName"] = name + async with self.fail_point(delay_hello): + + def rtt_exceeds_250_ms(): + # XXX: Add a public TopologyDescription getter to MongoClient? + topology = client._topology + sd = topology.description.server_descriptions()[address] + assert sd.round_trip_time is not None + return sd.round_trip_time > 0.250 + + await async_wait_until(rtt_exceeds_250_ms, "exceed 250ms RTT") + + # Server should be selectable. + await client.admin.command("ping") + + def changed_event(event): + return event.server_address == address and isinstance( + event, monitoring.ServerDescriptionChangedEvent + ) + + # There should only be one event published, for the initial discovery. + events = listener.matching(changed_event) + self.assertEqual(1, len(events)) + self.assertGreater(events[0].new_description.round_trip_time, 0) + + @async_client_context.require_failCommand_appName + async def test_monitor_waits_after_server_check_error(self): + # This test implements: + # https://github.com/mongodb/specifications/blob/master/source/server-discovery-and-monitoring/server-discovery-and-monitoring-tests.md#monitors-sleep-at-least-minheartbeatfreqencyms-between-checks + fail_hello = { + "mode": {"times": 5}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "errorCode": 1234, + "appName": "SDAMMinHeartbeatFrequencyTest", + }, + } + async with self.fail_point(fail_hello): + start = time.time() + client = await self.async_single_client( + appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000 + ) + # Force a connection. + await client.admin.command("ping") + duration = time.time() - start + # Explanation of the expected events: + # 0ms: run configureFailPoint + # 1ms: create MongoClient + # 2ms: failed monitor handshake, 1 + # 502ms: failed monitor handshake, 2 + # 1002ms: failed monitor handshake, 3 + # 1502ms: failed monitor handshake, 4 + # 2002ms: failed monitor handshake, 5 + # 2502ms: monitor handshake succeeds + # 2503ms: run awaitable hello + # 2504ms: application handshake succeeds + # 2505ms: ping command succeeds + self.assertGreaterEqual(duration, 2) + self.assertLessEqual(duration, 3.5) + + @async_client_context.require_failCommand_appName + async def test_heartbeat_awaited_flag(self): + hb_listener = HeartbeatEventListener() + client = await self.async_single_client( + event_listeners=[hb_listener], + heartbeatFrequencyMS=500, + appName="heartbeatEventAwaitedFlag", + ) + # Force a connection. + await client.admin.command("ping") + + def hb_succeeded(event): + return isinstance(event, monitoring.ServerHeartbeatSucceededEvent) + + def hb_failed(event): + return isinstance(event, monitoring.ServerHeartbeatFailedEvent) + + fail_heartbeat = { + "mode": {"times": 2}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "closeConnection": True, + "appName": "heartbeatEventAwaitedFlag", + }, + } + async with self.fail_point(fail_heartbeat): + await async_wait_until( + lambda: hb_listener.matching(hb_failed), "published failed event" + ) + # Reconnect. + await client.admin.command("ping") + + hb_succeeded_events = hb_listener.matching(hb_succeeded) + hb_failed_events = hb_listener.matching(hb_failed) + self.assertFalse(hb_succeeded_events[0].awaited) + self.assertTrue(hb_failed_events[0].awaited) + # Depending on thread scheduling, the failed heartbeat could occur on + # the second or third check. + events = [type(e) for e in hb_listener.events[:4]] + if events == [ + monitoring.ServerHeartbeatStartedEvent, + monitoring.ServerHeartbeatSucceededEvent, + monitoring.ServerHeartbeatStartedEvent, + monitoring.ServerHeartbeatFailedEvent, + ]: + self.assertFalse(hb_succeeded_events[1].awaited) + else: + self.assertTrue(hb_succeeded_events[1].awaited) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_streaming_protocol.py b/test/test_streaming_protocol.py index d782aa1dd..894e89e20 100644 --- a/test/test_streaming_protocol.py +++ b/test/test_streaming_protocol.py @@ -30,6 +30,8 @@ from test.utils import ( from pymongo import monitoring from pymongo.hello import HelloCompat +_IS_SYNC = True + class TestStreamingProtocol(IntegrationTest): @client_context.require_failCommand_appName @@ -41,7 +43,6 @@ class TestStreamingProtocol(IntegrationTest): heartbeatFrequencyMS=500, appName="failingHeartbeatTest", ) - self.addCleanup(client.close) # Force a connection. client.admin.command("ping") address = client.address @@ -78,7 +79,7 @@ class TestStreamingProtocol(IntegrationTest): def rediscovered(): return len(listener.matching(_discovered_node)) >= 1 - # Topology events are published asynchronously + # Topology events are not published synchronously wait_until(marked_unknown, "mark node unknown") wait_until(rediscovered, "rediscover node") @@ -108,7 +109,6 @@ class TestStreamingProtocol(IntegrationTest): client = self.rs_or_single_client( event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name ) - self.addCleanup(client.close) # Force a connection. client.admin.command("ping") address = client.address @@ -156,7 +156,6 @@ class TestStreamingProtocol(IntegrationTest): client = self.single_client( appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000 ) - self.addCleanup(client.close) # Force a connection. client.admin.command("ping") duration = time.time() - start @@ -183,7 +182,6 @@ class TestStreamingProtocol(IntegrationTest): heartbeatFrequencyMS=500, appName="heartbeatEventAwaitedFlag", ) - self.addCleanup(client.close) # Force a connection. client.admin.command("ping") diff --git a/tools/synchro.py b/tools/synchro.py index ffbea4e53..d1fc032eb 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -234,6 +234,7 @@ converted_tests = [ "test_sessions_unified.py", "test_srv_polling.py", "test_ssl.py", + "test_streaming_protocol.py", "test_transactions.py", "test_unified_format.py", "unified_format.py", From 7108c2199d1b61b3132d89cd1cceaa3928792b8b Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 5 Feb 2025 12:11:03 -0500 Subject: [PATCH 25/28] PYTHON-5108 - Convert test.test_transactions_unified to async (#2128) --- .../asynchronous/test_transactions_unified.py | 56 +++++++++++++++++++ test/test_transactions_unified.py | 17 ++++-- tools/synchro.py | 1 + 3 files changed, 70 insertions(+), 4 deletions(-) create mode 100644 test/asynchronous/test_transactions_unified.py diff --git a/test/asynchronous/test_transactions_unified.py b/test/asynchronous/test_transactions_unified.py new file mode 100644 index 000000000..4519a0e39 --- /dev/null +++ b/test/asynchronous/test_transactions_unified.py @@ -0,0 +1,56 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Transactions unified spec tests.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import client_context, unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + + +@client_context.require_no_mmap +def setUpModule(): + pass + + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions/unified") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "transactions/unified") + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + +# Location of JSON test specifications for transactions-convenient-api. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions-convenient-api/unified") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "transactions-convenient-api/unified" + ) + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_transactions_unified.py b/test/test_transactions_unified.py index 81137bf65..641e05108 100644 --- a/test/test_transactions_unified.py +++ b/test/test_transactions_unified.py @@ -17,12 +17,15 @@ from __future__ import annotations import os import sys +from pathlib import Path sys.path[0:0] = [""] from test import client_context, unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + @client_context.require_no_mmap def setUpModule(): @@ -30,15 +33,21 @@ def setUpModule(): # Location of JSON test specifications. -TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "transactions", "unified") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions/unified") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "transactions/unified") # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) # Location of JSON test specifications for transactions-convenient-api. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "transactions-convenient-api", "unified" -) +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions-convenient-api/unified") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "transactions-convenient-api/unified" + ) # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/tools/synchro.py b/tools/synchro.py index d1fc032eb..fc6b16082 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -236,6 +236,7 @@ converted_tests = [ "test_ssl.py", "test_streaming_protocol.py", "test_transactions.py", + "test_transactions_unified.py", "test_unified_format.py", "unified_format.py", ] From ac8fa2d645eaa22dd6346320bfee039294139dbe Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 5 Feb 2025 12:13:56 -0500 Subject: [PATCH 26/28] PYTHON-5094 - Convert test.test_read_preferences to async (#2110) --- test/__init__.py | 2 +- test/asynchronous/__init__.py | 4 +- test/asynchronous/test_read_preferences.py | 730 +++++++++++++++++++++ test/test_read_preferences.py | 108 +-- tools/synchro.py | 1 + 5 files changed, 795 insertions(+), 50 deletions(-) create mode 100644 test/asynchronous/test_read_preferences.py diff --git a/test/__init__.py b/test/__init__.py index b49eee99a..6eda00bde 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -593,7 +593,7 @@ class ClientContext: if self.has_secondaries: return True if self.is_mongos: - shard = self.client.config.shards.find_one()["host"] # type:ignore[index] + shard = (self.client.config.shards.find_one())["host"] # type:ignore[index] num_members = shard.count(",") + 1 return num_members > 1 return False diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index a6ba29baa..b3b0ca93e 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -592,10 +592,10 @@ class AsyncClientContext: @property async def supports_secondary_read_pref(self): - if self.has_secondaries: + if await self.has_secondaries: return True if self.is_mongos: - shard = await self.client.config.shards.find_one()["host"] # type:ignore[index] + shard = (await self.client.config.shards.find_one())["host"] # type:ignore[index] num_members = shard.count(",") + 1 return num_members > 1 return False diff --git a/test/asynchronous/test_read_preferences.py b/test/asynchronous/test_read_preferences.py new file mode 100644 index 000000000..077bc21ea --- /dev/null +++ b/test/asynchronous/test_read_preferences.py @@ -0,0 +1,730 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the replica_set_connection module.""" +from __future__ import annotations + +import contextlib +import copy +import pickle +import random +import sys +from typing import Any + +from pymongo.operations import _Op + +sys.path[0:0] = [""] + +from test.asynchronous import ( + AsyncIntegrationTest, + SkipTest, + async_client_context, + connected, + unittest, +) +from test.utils import ( + OvertCommandListener, + async_wait_until, + one, +) +from test.version import Version + +from bson.son import SON +from pymongo.asynchronous.helpers import anext +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.errors import ConfigurationError, OperationFailure +from pymongo.message import _maybe_add_read_preference +from pymongo.read_preferences import ( + MovingAverage, + Nearest, + Primary, + PrimaryPreferred, + ReadPreference, + Secondary, + SecondaryPreferred, +) +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import Selection, readable_server_selector +from pymongo.server_type import SERVER_TYPE +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class TestSelections(AsyncIntegrationTest): + @async_client_context.require_connection + async def test_bool(self): + client = await self.async_single_client() + + async def predicate(): + return await client.address + + await async_wait_until(predicate, "discover primary") + selection = Selection.from_topology_description(client._topology.description) + + self.assertTrue(selection) + self.assertFalse(selection.with_server_descriptions([])) + + +class TestReadPreferenceObjects(unittest.TestCase): + prefs = [ + Primary(), + PrimaryPreferred(), + Secondary(), + Nearest(tag_sets=[{"a": 1}, {"b": 2}]), + SecondaryPreferred(max_staleness=30), + ] + + def test_pickle(self): + for pref in self.prefs: + self.assertEqual(pref, pickle.loads(pickle.dumps(pref))) + + def test_copy(self): + for pref in self.prefs: + self.assertEqual(pref, copy.copy(pref)) + + def test_deepcopy(self): + for pref in self.prefs: + self.assertEqual(pref, copy.deepcopy(pref)) + + +class TestReadPreferencesBase(AsyncIntegrationTest): + @async_client_context.require_secondaries_count(1) + async def asyncSetUp(self): + await super().asyncSetUp() + # Insert some data so we can use cursors in read_from_which_host + await self.client.pymongo_test.test.drop() + await self.client.get_database( + "pymongo_test", write_concern=WriteConcern(w=async_client_context.w) + ).test.insert_many([{"_id": i} for i in range(10)]) + + self.addAsyncCleanup(self.client.pymongo_test.test.drop) + + async def read_from_which_host(self, client): + """Do a find() on the client and return which host was used""" + cursor = client.pymongo_test.test.find() + await anext(cursor) + return cursor.address + + async def read_from_which_kind(self, client): + """Do a find() on the client and return 'primary' or 'secondary' + depending on which the client used. + """ + address = await self.read_from_which_host(client) + if address == await client.primary: + return "primary" + elif address in await client.secondaries: + return "secondary" + else: + self.fail( + f"Cursor used address {address}, expected either primary " + f"{client.primary} or secondaries {client.secondaries}" + ) + + async def assertReadsFrom(self, expected, **kwargs): + c = await self.async_rs_client(**kwargs) + + async def predicate(): + return len(c.nodes - await c.arbiters) == async_client_context.w + + await async_wait_until(predicate, "discovered all nodes") + + used = await self.read_from_which_kind(c) + self.assertEqual(expected, used, f"Cursor used {used}, expected {expected}") + + +class TestSingleSecondaryOk(TestReadPreferencesBase): + async def test_reads_from_secondary(self): + host, port = next(iter(await self.client.secondaries)) + # Direct connection to a secondary. + client = await self.async_single_client(host, port) + self.assertFalse(await client.is_primary) + + # Regardless of read preference, we should be able to do + # "reads" with a direct connection to a secondary. + # See server-selection.rst#topology-type-single. + self.assertEqual(client.read_preference, ReadPreference.PRIMARY) + + db = client.pymongo_test + coll = db.test + + # Test find and find_one. + self.assertIsNotNone(await coll.find_one()) + self.assertEqual(10, len(await coll.find().to_list())) + + # Test some database helpers. + self.assertIsNotNone(await db.list_collection_names()) + self.assertIsNotNone(await db.validate_collection("test")) + self.assertIsNotNone(await db.command("ping")) + + # Test some collection helpers. + self.assertEqual(10, await coll.count_documents({})) + self.assertEqual(10, len(await coll.distinct("_id"))) + self.assertIsNotNone(await coll.aggregate([])) + self.assertIsNotNone(await coll.index_information()) + + +class TestReadPreferences(TestReadPreferencesBase): + async def test_mode_validation(self): + for mode in ( + ReadPreference.PRIMARY, + ReadPreference.PRIMARY_PREFERRED, + ReadPreference.SECONDARY, + ReadPreference.SECONDARY_PREFERRED, + ReadPreference.NEAREST, + ): + self.assertEqual( + mode, (await self.async_rs_client(read_preference=mode)).read_preference + ) + + with self.assertRaises(TypeError): + await self.async_rs_client(read_preference="foo") + + async def test_tag_sets_validation(self): + S = Secondary(tag_sets=[{}]) + self.assertEqual( + [{}], (await self.async_rs_client(read_preference=S)).read_preference.tag_sets + ) + + S = Secondary(tag_sets=[{"k": "v"}]) + self.assertEqual( + [{"k": "v"}], (await self.async_rs_client(read_preference=S)).read_preference.tag_sets + ) + + S = Secondary(tag_sets=[{"k": "v"}, {}]) + self.assertEqual( + [{"k": "v"}, {}], + (await self.async_rs_client(read_preference=S)).read_preference.tag_sets, + ) + + self.assertRaises(ValueError, Secondary, tag_sets=[]) + + # One dict not ok, must be a list of dicts + self.assertRaises(TypeError, Secondary, tag_sets={"k": "v"}) + + self.assertRaises(TypeError, Secondary, tag_sets="foo") + + self.assertRaises(TypeError, Secondary, tag_sets=["foo"]) + + async def test_threshold_validation(self): + self.assertEqual( + 17, + ( + await self.async_rs_client(localThresholdMS=17, connect=False) + ).options.local_threshold_ms, + ) + + self.assertEqual( + 42, + ( + await self.async_rs_client(localThresholdMS=42, connect=False) + ).options.local_threshold_ms, + ) + + self.assertEqual( + 666, + ( + await self.async_rs_client(localThresholdMS=666, connect=False) + ).options.local_threshold_ms, + ) + + self.assertEqual( + 0, + ( + await self.async_rs_client(localThresholdMS=0, connect=False) + ).options.local_threshold_ms, + ) + + with self.assertRaises(ValueError): + await self.async_rs_client(localthresholdms=-1) + + async def test_zero_latency(self): + ping_times: set = set() + # Generate unique ping times. + while len(ping_times) < len(self.client.nodes): + ping_times.add(random.random()) + for ping_time, host in zip(ping_times, self.client.nodes): + ServerDescription._host_to_round_trip_time[host] = ping_time + try: + client = await connected( + await self.async_rs_client(readPreference="nearest", localThresholdMS=0) + ) + await async_wait_until( + lambda: client.nodes == self.client.nodes, "discovered all nodes" + ) + host = await self.read_from_which_host(client) + for _ in range(5): + self.assertEqual(host, await self.read_from_which_host(client)) + finally: + ServerDescription._host_to_round_trip_time.clear() + + async def test_primary(self): + await self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY) + + async def test_primary_with_tags(self): + # Tags not allowed with PRIMARY + with self.assertRaises(ConfigurationError): + await self.async_rs_client(tag_sets=[{"dc": "ny"}]) + + async def test_primary_preferred(self): + await self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED) + + async def test_secondary(self): + await self.assertReadsFrom("secondary", read_preference=ReadPreference.SECONDARY) + + async def test_secondary_preferred(self): + await self.assertReadsFrom("secondary", read_preference=ReadPreference.SECONDARY_PREFERRED) + + async def test_nearest(self): + # With high localThresholdMS, expect to read from any + # member + c = await self.async_rs_client( + read_preference=ReadPreference.NEAREST, localThresholdMS=10000 + ) # 10 seconds + + data_members = {await self.client.primary} | await self.client.secondaries + + # This is a probabilistic test; track which members we've read from so + # far, and keep reading until we've used all the members or give up. + # Chance of using only 2 of 3 members 10k times if there's no bug = + # 3 * (2/3)**10000, very low. + used: set = set() + i = 0 + while data_members.difference(used) and i < 10000: + address = await self.read_from_which_host(c) + used.add(address) + i += 1 + + not_used = data_members.difference(used) + latencies = ", ".join( + "%s: %sms" % (server.description.address, server.description.round_trip_time) + for server in await (await c._get_topology()).select_servers( + readable_server_selector, _Op.TEST + ) + ) + + self.assertFalse( + not_used, + "Expected to use primary and all secondaries for mode NEAREST," + f" but didn't use {not_used}\nlatencies: {latencies}", + ) + + +class ReadPrefTester(AsyncMongoClient): + def __init__(self, *args, **kwargs): + self.has_read_from = set() + client_options = async_client_context.client_options + client_options.update(kwargs) + super().__init__(*args, **client_options) + + async def _conn_for_reads(self, read_preference, session, operation): + context = await super()._conn_for_reads(read_preference, session, operation) + return context + + @contextlib.asynccontextmanager + async def _conn_from_server(self, read_preference, server, session): + context = super()._conn_from_server(read_preference, server, session) + async with context as (conn, read_preference): + await self.record_a_read(conn.address) + yield conn, read_preference + + async def record_a_read(self, address): + server = await (await self._get_topology()).select_server_by_address(address, _Op.TEST, 0) + self.has_read_from.add(server) + + +_PREF_MAP = [ + (Primary, SERVER_TYPE.RSPrimary), + (PrimaryPreferred, SERVER_TYPE.RSPrimary), + (Secondary, SERVER_TYPE.RSSecondary), + (SecondaryPreferred, SERVER_TYPE.RSSecondary), + (Nearest, "any"), +] + + +class TestCommandAndReadPreference(AsyncIntegrationTest): + c: ReadPrefTester + client_version: Version + + @async_client_context.require_secondaries_count(1) + async def asyncSetUp(self): + await super().asyncSetUp() + self.c = ReadPrefTester( + # Ignore round trip times, to test ReadPreference modes only. + localThresholdMS=1000 * 1000, + ) + self.client_version = await Version.async_from_client(self.c) + # mapReduce fails if the collection does not exist. + coll = self.c.pymongo_test.get_collection( + "test", write_concern=WriteConcern(w=async_client_context.w) + ) + await coll.insert_one({}) + + async def asyncTearDown(self): + await self.c.drop_database("pymongo_test") + await self.c.close() + + async def executed_on_which_server(self, client, fn, *args, **kwargs): + """Execute fn(*args, **kwargs) and return the Server instance used.""" + client.has_read_from.clear() + await fn(*args, **kwargs) + self.assertEqual(1, len(client.has_read_from)) + return one(client.has_read_from) + + async def assertExecutedOn(self, server_type, client, fn, *args, **kwargs): + server = await self.executed_on_which_server(client, fn, *args, **kwargs) + self.assertEqual( + SERVER_TYPE._fields[server_type], SERVER_TYPE._fields[server.description.server_type] + ) + + async def _test_fn(self, server_type, fn): + for _ in range(10): + if server_type == "any": + used = set() + for _ in range(1000): + server = await self.executed_on_which_server(self.c, fn) + used.add(server.description.address) + if len(used) == len(await self.c.secondaries) + 1: + # Success + break + + assert await self.c.primary is not None + unused = (await self.c.secondaries).union({await self.c.primary}).difference(used) + if unused: + self.fail("Some members not used for NEAREST: %s" % (unused)) + else: + await self.assertExecutedOn(server_type, self.c, fn) + + async def _test_primary_helper(self, func): + # Helpers that ignore read preference. + await self._test_fn(SERVER_TYPE.RSPrimary, func) + + async def _test_coll_helper(self, secondary_ok, coll, meth, *args, **kwargs): + for mode, server_type in _PREF_MAP: + new_coll = coll.with_options(read_preference=mode()) + + async def func(): + return await getattr(new_coll, meth)(*args, **kwargs) + + if secondary_ok: + await self._test_fn(server_type, func) + else: + await self._test_fn(SERVER_TYPE.RSPrimary, func) + + async def test_command(self): + # Test that the generic command helper obeys the read preference + # passed to it. + for mode, server_type in _PREF_MAP: + + async def func(): + return await self.c.pymongo_test.command("dbStats", read_preference=mode()) + + await self._test_fn(server_type, func) + + async def test_create_collection(self): + # create_collection runs listCollections on the primary to check if + # the collection already exists. + async def func(): + return await self.c.pymongo_test.create_collection( + "some_collection%s" % random.randint(0, sys.maxsize) + ) + + await self._test_primary_helper(func) + + async def test_count_documents(self): + await self._test_coll_helper(True, self.c.pymongo_test.test, "count_documents", {}) + + async def test_estimated_document_count(self): + await self._test_coll_helper(True, self.c.pymongo_test.test, "estimated_document_count") + + async def test_distinct(self): + await self._test_coll_helper(True, self.c.pymongo_test.test, "distinct", "a") + + async def test_aggregate(self): + await self._test_coll_helper( + True, self.c.pymongo_test.test, "aggregate", [{"$project": {"_id": 1}}] + ) + + async def test_aggregate_write(self): + # 5.0 servers support $out on secondaries. + secondary_ok = async_client_context.version.at_least(5, 0) + await self._test_coll_helper( + secondary_ok, + self.c.pymongo_test.test, + "aggregate", + [{"$project": {"_id": 1}}, {"$out": "agg_write_test"}], + ) + + +class TestMovingAverage(unittest.TestCase): + def test_moving_average(self): + avg = MovingAverage() + self.assertIsNone(avg.get()) + avg.add_sample(10) + self.assertAlmostEqual(10, avg.get()) # type: ignore + avg.add_sample(20) + self.assertAlmostEqual(12, avg.get()) # type: ignore + avg.add_sample(30) + self.assertAlmostEqual(15.6, avg.get()) # type: ignore + + +class TestMongosAndReadPreference(AsyncIntegrationTest): + def test_read_preference_document(self): + pref = Primary() + self.assertEqual(pref.document, {"mode": "primary"}) + + pref = PrimaryPreferred() + self.assertEqual(pref.document, {"mode": "primaryPreferred"}) + pref = PrimaryPreferred(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "primaryPreferred", "tags": [{"dc": "sf"}]}) + pref = PrimaryPreferred(tag_sets=[{"dc": "sf"}], max_staleness=30) + self.assertEqual( + pref.document, + {"mode": "primaryPreferred", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30}, + ) + + pref = Secondary() + self.assertEqual(pref.document, {"mode": "secondary"}) + pref = Secondary(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "secondary", "tags": [{"dc": "sf"}]}) + pref = Secondary(tag_sets=[{"dc": "sf"}], max_staleness=30) + self.assertEqual( + pref.document, {"mode": "secondary", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30} + ) + + pref = SecondaryPreferred() + self.assertEqual(pref.document, {"mode": "secondaryPreferred"}) + pref = SecondaryPreferred(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "secondaryPreferred", "tags": [{"dc": "sf"}]}) + pref = SecondaryPreferred(tag_sets=[{"dc": "sf"}], max_staleness=30) + self.assertEqual( + pref.document, + {"mode": "secondaryPreferred", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30}, + ) + + pref = Nearest() + self.assertEqual(pref.document, {"mode": "nearest"}) + pref = Nearest(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "nearest", "tags": [{"dc": "sf"}]}) + pref = Nearest(tag_sets=[{"dc": "sf"}], max_staleness=30) + self.assertEqual( + pref.document, {"mode": "nearest", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30} + ) + + with self.assertRaises(TypeError): + # Float is prohibited. + Nearest(max_staleness=1.5) # type: ignore + + with self.assertRaises(ValueError): + Nearest(max_staleness=0) + + with self.assertRaises(ValueError): + Nearest(max_staleness=-2) + + def test_read_preference_document_hedge(self): + cases = { + "primaryPreferred": PrimaryPreferred, + "secondary": Secondary, + "secondaryPreferred": SecondaryPreferred, + "nearest": Nearest, + } + for mode, cls in cases.items(): + with self.assertRaises(TypeError): + cls(hedge=[]) # type: ignore + + pref = cls(hedge={}) + self.assertEqual(pref.document, {"mode": mode}) + out = _maybe_add_read_preference({}, pref) + if cls == SecondaryPreferred: + # SecondaryPreferred without hedge doesn't add $readPreference. + self.assertEqual(out, {}) + else: + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + hedge: dict[str, Any] = {"enabled": True} + pref = cls(hedge=hedge) + self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + hedge = {"enabled": False} + pref = cls(hedge=hedge) + self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + hedge = {"enabled": False, "extra": "option"} + pref = cls(hedge=hedge) + self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + async def test_send_hedge(self): + cases = { + "primaryPreferred": PrimaryPreferred, + "secondaryPreferred": SecondaryPreferred, + "nearest": Nearest, + } + if await async_client_context.supports_secondary_read_pref: + cases["secondary"] = Secondary + listener = OvertCommandListener() + client = await self.async_rs_client(event_listeners=[listener]) + await client.admin.command("ping") + for _mode, cls in cases.items(): + pref = cls(hedge={"enabled": True}) + coll = client.test.get_collection("test", read_preference=pref) + listener.reset() + await coll.find_one() + started = listener.started_events + self.assertEqual(len(started), 1, started) + cmd = started[0].command + if async_client_context.is_rs or async_client_context.is_mongos: + self.assertIn("$readPreference", cmd) + self.assertEqual(cmd["$readPreference"], pref.document) + else: + self.assertNotIn("$readPreference", cmd) + + def test_maybe_add_read_preference(self): + # Primary doesn't add $readPreference + out = _maybe_add_read_preference({}, Primary()) + self.assertEqual(out, {}) + + pref = PrimaryPreferred() + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + pref = PrimaryPreferred(tag_sets=[{"dc": "nyc"}]) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + pref = Secondary() + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + pref = Secondary(tag_sets=[{"dc": "nyc"}]) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + # SecondaryPreferred without tag_sets or max_staleness doesn't add + # $readPreference + pref = SecondaryPreferred() + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, {}) + pref = SecondaryPreferred(tag_sets=[{"dc": "nyc"}]) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + pref = SecondaryPreferred(max_staleness=120) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + pref = Nearest() + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + pref = Nearest(tag_sets=[{"dc": "nyc"}]) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + criteria = SON([("$query", {}), ("$orderby", SON([("_id", 1)]))]) + pref = Nearest() + out = _maybe_add_read_preference(criteria, pref) + self.assertEqual( + out, + SON( + [ + ("$query", {}), + ("$orderby", SON([("_id", 1)])), + ("$readPreference", pref.document), + ] + ), + ) + pref = Nearest(tag_sets=[{"dc": "nyc"}]) + out = _maybe_add_read_preference(criteria, pref) + self.assertEqual( + out, + SON( + [ + ("$query", {}), + ("$orderby", SON([("_id", 1)])), + ("$readPreference", pref.document), + ] + ), + ) + + @async_client_context.require_mongos + async def test_mongos(self): + res = await async_client_context.client.config.shards.find_one() + assert res is not None + shard = res["host"] + num_members = shard.count(",") + 1 + if num_members == 1: + raise SkipTest("Need a replica set shard to test.") + coll = async_client_context.client.pymongo_test.get_collection( + "test", write_concern=WriteConcern(w=num_members) + ) + await coll.drop() + res = await coll.insert_many([{} for _ in range(5)]) + first_id = res.inserted_ids[0] + last_id = res.inserted_ids[-1] + + # Note - this isn't a perfect test since there's no way to + # tell what shard member a query ran on. + for pref in (Primary(), PrimaryPreferred(), Secondary(), SecondaryPreferred(), Nearest()): + qcoll = coll.with_options(read_preference=pref) + results = await qcoll.find().sort([("_id", 1)]).to_list() + self.assertEqual(first_id, results[0]["_id"]) + self.assertEqual(last_id, results[-1]["_id"]) + results = await qcoll.find().sort([("_id", -1)]).to_list() + self.assertEqual(first_id, results[-1]["_id"]) + self.assertEqual(last_id, results[0]["_id"]) + + @async_client_context.require_mongos + async def test_mongos_max_staleness(self): + # Sanity check that we're sending maxStalenessSeconds + coll = async_client_context.client.pymongo_test.get_collection( + "test", read_preference=SecondaryPreferred(max_staleness=120) + ) + # No error + await coll.find_one() + + coll = async_client_context.client.pymongo_test.get_collection( + "test", read_preference=SecondaryPreferred(max_staleness=10) + ) + try: + await coll.find_one() + except OperationFailure as exc: + self.assertEqual(160, exc.code) + else: + self.fail("mongos accepted invalid staleness") + + coll = ( + await self.async_single_client( + readPreference="secondaryPreferred", maxStalenessSeconds=120 + ) + ).pymongo_test.test + # No error + await coll.find_one() + + coll = ( + await self.async_single_client( + readPreference="secondaryPreferred", maxStalenessSeconds=10 + ) + ).pymongo_test.test + try: + await coll.find_one() + except OperationFailure as exc: + self.assertEqual(160, exc.code) + else: + self.fail("mongos accepted invalid staleness") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 32883399e..0d38f3f00 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -26,7 +26,13 @@ from pymongo.operations import _Op sys.path[0:0] = [""] -from test import IntegrationTest, SkipTest, client_context, connected, unittest +from test import ( + IntegrationTest, + SkipTest, + client_context, + connected, + unittest, +) from test.utils import ( OvertCommandListener, one, @@ -49,16 +55,22 @@ from pymongo.read_preferences import ( from pymongo.server_description import ServerDescription from pymongo.server_selectors import Selection, readable_server_selector from pymongo.server_type import SERVER_TYPE +from pymongo.synchronous.helpers import next from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern +_IS_SYNC = True + class TestSelections(IntegrationTest): @client_context.require_connection def test_bool(self): client = self.single_client() - wait_until(lambda: client.address, "discover primary") + def predicate(): + return client.address + + wait_until(predicate, "discover primary") selection = Selection.from_topology_description(client._topology.description) self.assertTrue(selection) @@ -88,11 +100,7 @@ class TestReadPreferenceObjects(unittest.TestCase): class TestReadPreferencesBase(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() - def setUp(self): super().setUp() # Insert some data so we can use cursors in read_from_which_host @@ -123,11 +131,14 @@ class TestReadPreferencesBase(IntegrationTest): f"Cursor used address {address}, expected either primary " f"{client.primary} or secondaries {client.secondaries}" ) - return None def assertReadsFrom(self, expected, **kwargs): c = self.rs_client(**kwargs) - wait_until(lambda: len(c.nodes - c.arbiters) == client_context.w, "discovered all nodes") + + def predicate(): + return len(c.nodes - c.arbiters) == client_context.w + + wait_until(predicate, "discovered all nodes") used = self.read_from_which_kind(c) self.assertEqual(expected, used, f"Cursor used {used}, expected {expected}") @@ -150,7 +161,7 @@ class TestSingleSecondaryOk(TestReadPreferencesBase): # Test find and find_one. self.assertIsNotNone(coll.find_one()) - self.assertEqual(10, len(list(coll.find()))) + self.assertEqual(10, len(coll.find().to_list())) # Test some database helpers. self.assertIsNotNone(db.list_collection_names()) @@ -173,20 +184,22 @@ class TestReadPreferences(TestReadPreferencesBase): ReadPreference.SECONDARY_PREFERRED, ReadPreference.NEAREST, ): - self.assertEqual(mode, self.rs_client(read_preference=mode).read_preference) + self.assertEqual(mode, (self.rs_client(read_preference=mode)).read_preference) - self.assertRaises(TypeError, self.rs_client, read_preference="foo") + with self.assertRaises(TypeError): + self.rs_client(read_preference="foo") def test_tag_sets_validation(self): S = Secondary(tag_sets=[{}]) - self.assertEqual([{}], self.rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual([{}], (self.rs_client(read_preference=S)).read_preference.tag_sets) S = Secondary(tag_sets=[{"k": "v"}]) - self.assertEqual([{"k": "v"}], self.rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual([{"k": "v"}], (self.rs_client(read_preference=S)).read_preference.tag_sets) S = Secondary(tag_sets=[{"k": "v"}, {}]) self.assertEqual( - [{"k": "v"}, {}], self.rs_client(read_preference=S).read_preference.tag_sets + [{"k": "v"}, {}], + (self.rs_client(read_preference=S)).read_preference.tag_sets, ) self.assertRaises(ValueError, Secondary, tag_sets=[]) @@ -200,22 +213,27 @@ class TestReadPreferences(TestReadPreferencesBase): def test_threshold_validation(self): self.assertEqual( - 17, self.rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms + 17, + (self.rs_client(localThresholdMS=17, connect=False)).options.local_threshold_ms, ) self.assertEqual( - 42, self.rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms + 42, + (self.rs_client(localThresholdMS=42, connect=False)).options.local_threshold_ms, ) self.assertEqual( - 666, self.rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms + 666, + (self.rs_client(localThresholdMS=666, connect=False)).options.local_threshold_ms, ) self.assertEqual( - 0, self.rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms + 0, + (self.rs_client(localThresholdMS=0, connect=False)).options.local_threshold_ms, ) - self.assertRaises(ValueError, self.rs_client, localthresholdms=-1) + with self.assertRaises(ValueError): + self.rs_client(localthresholdms=-1) def test_zero_latency(self): ping_times: set = set() @@ -238,7 +256,8 @@ class TestReadPreferences(TestReadPreferencesBase): def test_primary_with_tags(self): # Tags not allowed with PRIMARY - self.assertRaises(ConfigurationError, self.rs_client, tag_sets=[{"dc": "ny"}]) + with self.assertRaises(ConfigurationError): + self.rs_client(tag_sets=[{"dc": "ny"}]) def test_primary_preferred(self): self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED) @@ -272,7 +291,7 @@ class TestReadPreferences(TestReadPreferencesBase): not_used = data_members.difference(used) latencies = ", ".join( "%s: %sms" % (server.description.address, server.description.round_trip_time) - for server in c._get_topology().select_servers(readable_server_selector, _Op.TEST) + for server in (c._get_topology()).select_servers(readable_server_selector, _Op.TEST) ) self.assertFalse( @@ -289,12 +308,9 @@ class ReadPrefTester(MongoClient): client_options.update(kwargs) super().__init__(*args, **client_options) - @contextlib.contextmanager def _conn_for_reads(self, read_preference, session, operation): context = super()._conn_for_reads(read_preference, session, operation) - with context as (conn, read_preference): - self.record_a_read(conn.address) - yield conn, read_preference + return context @contextlib.contextmanager def _conn_from_server(self, read_preference, server, session): @@ -304,7 +320,7 @@ class ReadPrefTester(MongoClient): yield conn, read_preference def record_a_read(self, address): - server = self._get_topology().select_server_by_address(address, _Op.TEST, 0) + server = (self._get_topology()).select_server_by_address(address, _Op.TEST, 0) self.has_read_from.add(server) @@ -321,25 +337,23 @@ class TestCommandAndReadPreference(IntegrationTest): c: ReadPrefTester client_version: Version - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() - cls.c = ReadPrefTester( + def setUp(self): + super().setUp() + self.c = ReadPrefTester( # Ignore round trip times, to test ReadPreference modes only. localThresholdMS=1000 * 1000, ) - cls.client_version = Version.from_client(cls.c) + self.client_version = Version.from_client(self.c) # mapReduce fails if the collection does not exist. - coll = cls.c.pymongo_test.get_collection( + coll = self.c.pymongo_test.get_collection( "test", write_concern=WriteConcern(w=client_context.w) ) coll.insert_one({}) - @classmethod - def tearDownClass(cls): - cls.c.drop_database("pymongo_test") - cls.c.close() + def tearDown(self): + self.c.drop_database("pymongo_test") + self.c.close() def executed_on_which_server(self, client, fn, *args, **kwargs): """Execute fn(*args, **kwargs) and return the Server instance used.""" @@ -366,7 +380,7 @@ class TestCommandAndReadPreference(IntegrationTest): break assert self.c.primary is not None - unused = self.c.secondaries.union({self.c.primary}).difference(used) + unused = (self.c.secondaries).union({self.c.primary}).difference(used) if unused: self.fail("Some members not used for NEAREST: %s" % (unused)) else: @@ -401,11 +415,12 @@ class TestCommandAndReadPreference(IntegrationTest): def test_create_collection(self): # create_collection runs listCollections on the primary to check if # the collection already exists. - self._test_primary_helper( - lambda: self.c.pymongo_test.create_collection( + def func(): + return self.c.pymongo_test.create_collection( "some_collection%s" % random.randint(0, sys.maxsize) ) - ) + + self._test_primary_helper(func) def test_count_documents(self): self._test_coll_helper(True, self.c.pymongo_test.test, "count_documents", {}) @@ -545,7 +560,6 @@ class TestMongosAndReadPreference(IntegrationTest): cases["secondary"] = Secondary listener = OvertCommandListener() client = self.rs_client(event_listeners=[listener]) - self.addCleanup(client.close) client.admin.command("ping") for _mode, cls in cases.items(): pref = cls(hedge={"enabled": True}) @@ -645,10 +659,10 @@ class TestMongosAndReadPreference(IntegrationTest): # tell what shard member a query ran on. for pref in (Primary(), PrimaryPreferred(), Secondary(), SecondaryPreferred(), Nearest()): qcoll = coll.with_options(read_preference=pref) - results = list(qcoll.find().sort([("_id", 1)])) + results = qcoll.find().sort([("_id", 1)]).to_list() self.assertEqual(first_id, results[0]["_id"]) self.assertEqual(last_id, results[-1]["_id"]) - results = list(qcoll.find().sort([("_id", -1)])) + results = qcoll.find().sort([("_id", -1)]).to_list() self.assertEqual(first_id, results[-1]["_id"]) self.assertEqual(last_id, results[0]["_id"]) @@ -671,14 +685,14 @@ class TestMongosAndReadPreference(IntegrationTest): else: self.fail("mongos accepted invalid staleness") - coll = self.single_client( - readPreference="secondaryPreferred", maxStalenessSeconds=120 + coll = ( + self.single_client(readPreference="secondaryPreferred", maxStalenessSeconds=120) ).pymongo_test.test # No error coll.find_one() - coll = self.single_client( - readPreference="secondaryPreferred", maxStalenessSeconds=10 + coll = ( + self.single_client(readPreference="secondaryPreferred", maxStalenessSeconds=10) ).pymongo_test.test try: coll.find_one() diff --git a/tools/synchro.py b/tools/synchro.py index fc6b16082..443d57d41 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -222,6 +222,7 @@ converted_tests = [ "test_on_demand_csfle.py", "test_raw_bson.py", "test_read_concern.py", + "test_read_preferences.py", "test_read_write_concern_spec.py", "test_retryable_reads.py", "test_retryable_reads_unified.py", From f344eb7965a95d0d949c58f65d71bff9a07f6adb Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 5 Feb 2025 12:14:12 -0500 Subject: [PATCH 27/28] PYTHON-5109 - Convert test.test_versioned_api to async (#2129) --- .../test_versioned_api_integration.py | 86 +++++++++++++++++++ test/test_versioned_api.py | 47 +--------- test/test_versioned_api_integration.py | 82 ++++++++++++++++++ tools/synchro.py | 1 + 4 files changed, 173 insertions(+), 43 deletions(-) create mode 100644 test/asynchronous/test_versioned_api_integration.py create mode 100644 test/test_versioned_api_integration.py diff --git a/test/asynchronous/test_versioned_api_integration.py b/test/asynchronous/test_versioned_api_integration.py new file mode 100644 index 000000000..7e9a79da9 --- /dev/null +++ b/test/asynchronous/test_versioned_api_integration.py @@ -0,0 +1,86 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import sys +from pathlib import Path +from test.asynchronous.unified_format import generate_test_classes + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils import OvertCommandListener + +from pymongo.server_api import ServerApi + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "versioned-api") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "versioned-api") + + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + + +class TestServerApiIntegration(AsyncIntegrationTest): + RUN_ON_LOAD_BALANCER = True + RUN_ON_SERVERLESS = True + + def assertServerApi(self, event): + self.assertIn("apiVersion", event.command) + self.assertEqual(event.command["apiVersion"], "1") + + def assertServerApiInAllCommands(self, events): + for event in events: + self.assertServerApi(event) + + @async_client_context.require_version_min(4, 7) + async def test_command_options(self): + listener = OvertCommandListener() + client = await self.async_rs_or_single_client( + server_api=ServerApi("1"), event_listeners=[listener] + ) + coll = client.test.test + await coll.insert_many([{} for _ in range(100)]) + self.addAsyncCleanup(coll.delete_many, {}) + await coll.find(batch_size=25).to_list() + await client.admin.command("ping") + self.assertServerApiInAllCommands(listener.started_events) + + @async_client_context.require_version_min(4, 7) + @async_client_context.require_transactions + async def test_command_options_txn(self): + listener = OvertCommandListener() + client = await self.async_rs_or_single_client( + server_api=ServerApi("1"), event_listeners=[listener] + ) + coll = client.test.test + await coll.insert_many([{} for _ in range(100)]) + self.addAsyncCleanup(coll.delete_many, {}) + + listener.reset() + async with client.start_session() as s, await s.start_transaction(): + await coll.insert_many([{} for _ in range(100)], session=s) + await coll.find(batch_size=25, session=s).to_list() + await client.test.command("find", "test", session=s) + self.assertServerApiInAllCommands(listener.started_events) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_versioned_api.py b/test/test_versioned_api.py index 7a25a507d..19b125770 100644 --- a/test/test_versioned_api.py +++ b/test/test_versioned_api.py @@ -13,28 +13,18 @@ # limitations under the License. from __future__ import annotations -import os import sys +from test import UnitTest sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest -from test.unified_format import generate_test_classes -from test.utils import OvertCommandListener +from test import unittest +from pymongo.mongo_client import MongoClient from pymongo.server_api import ServerApi, ServerApiVersion -from pymongo.synchronous.mongo_client import MongoClient - -TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "versioned-api") - -# Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) -class TestServerApi(IntegrationTest): - RUN_ON_LOAD_BALANCER = True - RUN_ON_SERVERLESS = True - +class TestServerApi(UnitTest): def test_server_api_defaults(self): api = ServerApi(ServerApiVersion.V1) self.assertEqual(api.version, "1") @@ -74,35 +64,6 @@ class TestServerApi(IntegrationTest): for event in events: self.assertServerApi(event) - @client_context.require_version_min(4, 7) - def test_command_options(self): - listener = OvertCommandListener() - client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) - self.addCleanup(client.close) - coll = client.test.test - coll.insert_many([{} for _ in range(100)]) - self.addCleanup(coll.delete_many, {}) - list(coll.find(batch_size=25)) - client.admin.command("ping") - self.assertServerApiInAllCommands(listener.started_events) - - @client_context.require_version_min(4, 7) - @client_context.require_transactions - def test_command_options_txn(self): - listener = OvertCommandListener() - client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) - self.addCleanup(client.close) - coll = client.test.test - coll.insert_many([{} for _ in range(100)]) - self.addCleanup(coll.delete_many, {}) - - listener.reset() - with client.start_session() as s, s.start_transaction(): - coll.insert_many([{} for _ in range(100)], session=s) - list(coll.find(batch_size=25, session=s)) - client.test.command("find", "test", session=s) - self.assertServerApiInAllCommands(listener.started_events) - if __name__ == "__main__": unittest.main() diff --git a/test/test_versioned_api_integration.py b/test/test_versioned_api_integration.py new file mode 100644 index 000000000..502198576 --- /dev/null +++ b/test/test_versioned_api_integration.py @@ -0,0 +1,82 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import sys +from pathlib import Path +from test.unified_format import generate_test_classes + +sys.path[0:0] = [""] + +from test import IntegrationTest, client_context, unittest +from test.utils import OvertCommandListener + +from pymongo.server_api import ServerApi + +_IS_SYNC = True + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "versioned-api") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "versioned-api") + + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + + +class TestServerApiIntegration(IntegrationTest): + RUN_ON_LOAD_BALANCER = True + RUN_ON_SERVERLESS = True + + def assertServerApi(self, event): + self.assertIn("apiVersion", event.command) + self.assertEqual(event.command["apiVersion"], "1") + + def assertServerApiInAllCommands(self, events): + for event in events: + self.assertServerApi(event) + + @client_context.require_version_min(4, 7) + def test_command_options(self): + listener = OvertCommandListener() + client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) + coll = client.test.test + coll.insert_many([{} for _ in range(100)]) + self.addCleanup(coll.delete_many, {}) + coll.find(batch_size=25).to_list() + client.admin.command("ping") + self.assertServerApiInAllCommands(listener.started_events) + + @client_context.require_version_min(4, 7) + @client_context.require_transactions + def test_command_options_txn(self): + listener = OvertCommandListener() + client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) + coll = client.test.test + coll.insert_many([{} for _ in range(100)]) + self.addCleanup(coll.delete_many, {}) + + listener.reset() + with client.start_session() as s, s.start_transaction(): + coll.insert_many([{} for _ in range(100)], session=s) + coll.find(batch_size=25, session=s).to_list() + client.test.command("find", "test", session=s) + self.assertServerApiInAllCommands(listener.started_events) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/synchro.py b/tools/synchro.py index 443d57d41..4b6326a49 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -239,6 +239,7 @@ converted_tests = [ "test_transactions.py", "test_transactions_unified.py", "test_unified_format.py", + "test_versioned_api_integration.py", "unified_format.py", ] From 1b818470fcb14fd8307f33127262bbaeeafab3f9 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 5 Feb 2025 15:05:41 -0500 Subject: [PATCH 28/28] PYTHON-5053 - AsyncMongoClient.close() should await all background tasks (#2127) --- pymongo/asynchronous/mongo_client.py | 6 ++++++ pymongo/asynchronous/monitor.py | 9 +++++++-- pymongo/asynchronous/topology.py | 29 +++++++++++++++++++++++++++- pymongo/periodic_executor.py | 2 ++ pymongo/synchronous/mongo_client.py | 6 ++++++ pymongo/synchronous/monitor.py | 7 +++++-- pymongo/synchronous/topology.py | 29 +++++++++++++++++++++++++++- 7 files changed, 82 insertions(+), 6 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index cf7de19c2..365fc6210 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1565,6 +1565,12 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. await self._encrypter.close() self._closed = True + if not _IS_SYNC: + await asyncio.gather( + self._topology.cleanup_monitors(), # type: ignore[func-returns-value] + self._kill_cursors_executor.join(), # type: ignore[func-returns-value] + return_exceptions=True, + ) if not _IS_SYNC: # Add support for contextlib.aclosing. diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index ad1bc70ab..abde7a905 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -112,9 +112,9 @@ class MonitorBase: """ self.gc_safe_close() - async def join(self, timeout: Optional[int] = None) -> None: + async def join(self) -> None: """Wait for the monitor to stop.""" - await self._executor.join(timeout) + await self._executor.join() def request_check(self) -> None: """If the monitor is sleeping, wake it soon.""" @@ -189,6 +189,11 @@ class Monitor(MonitorBase): self._rtt_monitor.gc_safe_close() self.cancel_check() + async def join(self) -> None: + await asyncio.gather( + self._executor.join(), self._rtt_monitor.join(), return_exceptions=True + ) # type: ignore[func-returns-value] + async def close(self) -> None: self.gc_safe_close() await self._rtt_monitor.close() diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 6d67710a7..3033377de 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -16,6 +16,7 @@ from __future__ import annotations +import asyncio import logging import os import queue @@ -29,7 +30,7 @@ from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool -from pymongo.asynchronous.monitor import SrvMonitor +from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor from pymongo.asynchronous.pool import Pool from pymongo.asynchronous.server import Server from pymongo.errors import ( @@ -207,6 +208,9 @@ class Topology: if self._settings.fqdn is not None and not self._settings.load_balanced: self._srv_monitor = SrvMonitor(self, self._settings) + # Stores all monitor tasks that need to be joined on close or server selection + self._monitor_tasks: list[MonitorBase] = [] + async def open(self) -> None: """Start monitoring, or restart after a fork. @@ -241,6 +245,8 @@ class Topology: # Close servers and clear the pools. for server in self._servers.values(): await server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Reset the session pool to avoid duplicate sessions in # the child process. self._session_pool.reset() @@ -283,6 +289,10 @@ class Topology: else: server_timeout = server_selection_timeout + # Cleanup any completed monitor tasks safely + if not _IS_SYNC and self._monitor_tasks: + await self.cleanup_monitors() + async with self._lock: server_descriptions = await self._select_servers_loop( selector, server_timeout, operation, operation_id, address @@ -520,6 +530,8 @@ class Topology: and self._description.topology_type not in SRV_POLLING_TOPOLOGIES ): await self._srv_monitor.close() + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) # Clear the pool from a failed heartbeat. if reset_pool: @@ -695,6 +707,8 @@ class Topology: old_td = self._description for server in self._servers.values(): await server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Mark all servers Unknown. self._description = self._description.reset() @@ -705,6 +719,8 @@ class Topology: # Stop SRV polling thread. if self._srv_monitor: await self._srv_monitor.close() + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) self._opened = False self._closed = True @@ -944,6 +960,8 @@ class Topology: for address, server in list(self._servers.items()): if not self._description.has_server(address): await server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) self._servers.pop(address) def _create_pool_for_server(self, address: _Address) -> Pool: @@ -1031,6 +1049,15 @@ class Topology: else: return ",".join(str(server.error) for server in servers if server.error) + async def cleanup_monitors(self) -> None: + tasks = [] + try: + while self._monitor_tasks: + tasks.append(self._monitor_tasks.pop()) + except IndexError: + pass + await asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value] + def __repr__(self) -> str: msg = "" if not self._opened: diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index 9b10f6e7e..f51a98872 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -75,6 +75,8 @@ class AsyncPeriodicExecutor: callback; see monitor.py. """ self._stopped = True + if self._task is not None: + self._task.cancel() async def join(self, timeout: Optional[int] = None) -> None: if self._task is not None: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 706623c21..8cd08ab72 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1559,6 +1559,12 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. self._encrypter.close() self._closed = True + if not _IS_SYNC: + asyncio.gather( + self._topology.cleanup_monitors(), # type: ignore[func-returns-value] + self._kill_cursors_executor.join(), # type: ignore[func-returns-value] + return_exceptions=True, + ) if not _IS_SYNC: # Add support for contextlib.closing. diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index df4130d4a..211635d8b 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -112,9 +112,9 @@ class MonitorBase: """ self.gc_safe_close() - def join(self, timeout: Optional[int] = None) -> None: + def join(self) -> None: """Wait for the monitor to stop.""" - self._executor.join(timeout) + self._executor.join() def request_check(self) -> None: """If the monitor is sleeping, wake it soon.""" @@ -189,6 +189,9 @@ class Monitor(MonitorBase): self._rtt_monitor.gc_safe_close() self.cancel_check() + def join(self) -> None: + asyncio.gather(self._executor.join(), self._rtt_monitor.join(), return_exceptions=True) # type: ignore[func-returns-value] + def close(self) -> None: self.gc_safe_close() self._rtt_monitor.close() diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index b03269ae4..09b61f6d0 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -16,6 +16,7 @@ from __future__ import annotations +import asyncio import logging import os import queue @@ -61,7 +62,7 @@ from pymongo.server_selectors import ( writable_server_selector, ) from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool -from pymongo.synchronous.monitor import SrvMonitor +from pymongo.synchronous.monitor import MonitorBase, SrvMonitor from pymongo.synchronous.pool import Pool from pymongo.synchronous.server import Server from pymongo.topology_description import ( @@ -207,6 +208,9 @@ class Topology: if self._settings.fqdn is not None and not self._settings.load_balanced: self._srv_monitor = SrvMonitor(self, self._settings) + # Stores all monitor tasks that need to be joined on close or server selection + self._monitor_tasks: list[MonitorBase] = [] + def open(self) -> None: """Start monitoring, or restart after a fork. @@ -241,6 +245,8 @@ class Topology: # Close servers and clear the pools. for server in self._servers.values(): server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Reset the session pool to avoid duplicate sessions in # the child process. self._session_pool.reset() @@ -283,6 +289,10 @@ class Topology: else: server_timeout = server_selection_timeout + # Cleanup any completed monitor tasks safely + if not _IS_SYNC and self._monitor_tasks: + self.cleanup_monitors() + with self._lock: server_descriptions = self._select_servers_loop( selector, server_timeout, operation, operation_id, address @@ -520,6 +530,8 @@ class Topology: and self._description.topology_type not in SRV_POLLING_TOPOLOGIES ): self._srv_monitor.close() + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) # Clear the pool from a failed heartbeat. if reset_pool: @@ -693,6 +705,8 @@ class Topology: old_td = self._description for server in self._servers.values(): server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Mark all servers Unknown. self._description = self._description.reset() @@ -703,6 +717,8 @@ class Topology: # Stop SRV polling thread. if self._srv_monitor: self._srv_monitor.close() + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) self._opened = False self._closed = True @@ -942,6 +958,8 @@ class Topology: for address, server in list(self._servers.items()): if not self._description.has_server(address): server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) self._servers.pop(address) def _create_pool_for_server(self, address: _Address) -> Pool: @@ -1029,6 +1047,15 @@ class Topology: else: return ",".join(str(server.error) for server in servers if server.error) + def cleanup_monitors(self) -> None: + tasks = [] + try: + while self._monitor_tasks: + tasks.append(self._monitor_tasks.pop()) + except IndexError: + pass + asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value] + def __repr__(self) -> str: msg = "" if not self._opened: