diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 8d7a9f082..5e8429dd2 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -257,9 +257,9 @@ if [ -z "$GREEN_FRAMEWORK" ]; then # Use --capture=tee-sys so pytest prints test output inline: # https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html if [ -z "$TEST_SUITES" ]; then - python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS + python -m pytest -v --capture=tee-sys --durations=5 $TEST_ARGS else - python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 -m $TEST_SUITES $TEST_ARGS + python -m pytest -v --capture=tee-sys --durations=5 -m $TEST_SUITES $TEST_ARGS fi else python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS diff --git a/.github/workflows/dist.yml b/.github/workflows/dist.yml index fbc7ff739..858d269e0 100644 --- a/.github/workflows/dist.yml +++ b/.github/workflows/dist.yml @@ -67,7 +67,7 @@ jobs: # Note: the default manylinux is manylinux2014 run: | python -m pip install -U pip - python -m pip install "cibuildwheel>=2.17,<3" + python -m pip install "cibuildwheel>=2.20,<3" - name: Build wheels env: @@ -89,6 +89,9 @@ jobs: ls wheelhouse/*cp310*.whl ls wheelhouse/*cp311*.whl ls wheelhouse/*cp312*.whl + ls wheelhouse/*cp313*.whl + # Free-threading builds: + ls wheelhouse/*cp313t*.whl - uses: actions/upload-artifact@v4 with: diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index e55444cec..40991440d 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -51,11 +51,18 @@ jobs: strategy: matrix: os: [ubuntu-20.04] - python-version: ["3.9", "pypy-3.9", "3.13"] + python-version: ["3.9", "pypy-3.9", "3.13", "3.13t"] name: CPython ${{ matrix.python-version }}-${{ matrix.os }} steps: - uses: actions/checkout@v4 - - name: Setup Python + - if: ${{ matrix.python-version == '3.13t' }} + name: Setup free-threaded Python + uses: deadsnakes/action@v3.2.0 + with: + 
python-version: 3.13 + nogil: true + - if: ${{ matrix.python-version != '3.13t' }} + name: Setup Python uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -65,9 +72,13 @@ jobs: - name: Install dependencies run: | pip install -U pip - if [ "${{ matrix.python-version }}" == "3.13" ]; then + if [[ "${{ matrix.python-version }}" == "3.13" ]]; then pip install --pre cffi setuptools pip install --no-build-isolation hatch + elif [[ "${{ matrix.python-version }}" == "3.13t" ]]; then + # Hatch can't be installed on 3.13t, use pytest directly. + pip install . + pip install -r requirements/test.txt else pip install hatch fi @@ -77,7 +88,11 @@ jobs: mongodb-version: 6.0 - name: Run tests run: | - hatch run test:test + if [[ "${{ matrix.python-version }}" == "3.13t" ]]; then + pytest -v --durations=5 --maxfail=10 + else + hatch run test:test + fi doctest: runs-on: ubuntu-latest diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index 223c39228..a66071c28 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -3184,6 +3184,9 @@ static PyModuleDef_Slot _cbson_slots[] = { {Py_mod_exec, _cbson_exec}, #if defined(Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED) {Py_mod_multiple_interpreters, Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED}, +#endif +#if PY_VERSION_HEX >= 0x030D0000 + {Py_mod_gil, Py_MOD_GIL_NOT_USED}, #endif {0, NULL}, }; diff --git a/doc/changelog.rst b/doc/changelog.rst index 6a118f56c..e7b160b17 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -12,6 +12,8 @@ PyMongo 4.11 brings a number of changes including: - Dropped support for Python 3.8. - Dropped support for MongoDB 3.6. +- Added support for free-threaded Python with the GIL disabled. For more information see: + `Free-threaded CPython <https://docs.python.org/3.13/whatsnew/3.13.html#whatsnew313-free-threaded-cpython>`_. Issues Resolved ............... 
diff --git a/pymongo/_cmessagemodule.c b/pymongo/_cmessagemodule.c index b5adbeec3..eb457b341 100644 --- a/pymongo/_cmessagemodule.c +++ b/pymongo/_cmessagemodule.c @@ -1022,6 +1022,9 @@ static PyModuleDef_Slot _cmessage_slots[] = { {Py_mod_exec, _cmessage_exec}, #ifdef Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED {Py_mod_multiple_interpreters, Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED}, +#endif +#if PY_VERSION_HEX >= 0x030D0000 + {Py_mod_gil, Py_MOD_GIL_NOT_USED}, #endif {0, NULL}, }; diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 9b00c13e1..735e54304 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -180,10 +180,20 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc] while kms_context.bytes_needed > 0: # CSOT: update timeout. conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = conn.recv(kms_context.bytes_needed) + if _IS_SYNC: + data = conn.recv(kms_context.bytes_needed) + else: + from pymongo.network_layer import ( # type: ignore[attr-defined] + async_receive_data_socket, + ) + + data = await async_receive_data_socket(conn, kms_context.bytes_needed) if not data: raise OSError("KMS connection closed") kms_context.feed(data) + # Async raises an OSError instead of returning empty bytes + except OSError as err: + raise OSError("KMS connection closed") from err except BLOCKING_IO_ERRORS: raise socket.timeout("timed out") from None finally: diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 4b57620d8..d14a21f41 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -130,7 +130,7 @@ if sys.platform != "win32": loop.remove_writer(fd) async def _async_receive_ssl( - conn: _sslConn, length: int, loop: AbstractEventLoop + conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False ) -> memoryview: mv = memoryview(bytearray(length)) total_read = 0 @@ -145,6 +145,9 @@ if sys.platform != "win32": 
read = conn.recv_into(mv[total_read:]) if read == 0: raise OSError("connection closed") + # KMS responses update their expected size after the first batch, stop reading after one loop + if once: + return mv[:read] total_read += read except BLOCKING_IO_ERRORS as exc: fd = conn.fileno() @@ -275,6 +278,28 @@ async def async_receive_data( sock.settimeout(sock_timeout) +async def async_receive_data_socket( + sock: Union[socket.socket, _sslConn], length: int +) -> memoryview: + sock_timeout = sock.gettimeout() + timeout = sock_timeout + + sock.settimeout(0.0) + loop = asyncio.get_event_loop() + try: + if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): + return await asyncio.wait_for( + _async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type] + timeout=timeout, + ) + else: + return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type] + except asyncio.TimeoutError as err: + raise socket.timeout("timed out") from err + finally: + sock.settimeout(sock_timeout) + + async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview: mv = memoryview(bytearray(length)) bytes_read = 0 diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index efef6df9e..506ff8bcb 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -180,10 +180,20 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc] while kms_context.bytes_needed > 0: # CSOT: update timeout. 
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = conn.recv(kms_context.bytes_needed) + if _IS_SYNC: + data = conn.recv(kms_context.bytes_needed) + else: + from pymongo.network_layer import ( # type: ignore[attr-defined] + receive_data_socket, + ) + + data = receive_data_socket(conn, kms_context.bytes_needed) if not data: raise OSError("KMS connection closed") kms_context.feed(data) + # Async raises an OSError instead of returning empty bytes + except OSError as err: + raise OSError("KMS connection closed") from err except BLOCKING_IO_ERRORS: raise socket.timeout("timed out") from None finally: diff --git a/pyproject.toml b/pyproject.toml index 2688aab27..b4f59f67d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -236,6 +236,8 @@ partial_branches = ["if (.*and +)*not _use_c( and.*)*:"] directory = "htmlcov" [tool.cibuildwheel] +# Enable free-threaded support +free-threaded-support = true skip = "pp* *-musllinux*" build-frontend = "build" test-command = "python {project}/tools/fail_if_no_c.py" diff --git a/test/__init__.py b/test/__init__.py index af12bc032..fd33fde29 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -464,11 +464,12 @@ class ClientContext: if not self.connected: pair = self.pair raise SkipTest(f"Cannot connect to MongoDB on {pair}") - if iscoroutinefunction(condition) and condition(): - if wraps_async: - return f(*args, **kwargs) - else: - return f(*args, **kwargs) + if iscoroutinefunction(condition): + if condition(): + if wraps_async: + return f(*args, **kwargs) + else: + return f(*args, **kwargs) elif condition(): if wraps_async: return f(*args, **kwargs) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 2a44785b2..0579828c4 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -466,11 +466,12 @@ class AsyncClientContext: if not self.connected: pair = await self.pair raise SkipTest(f"Cannot connect to MongoDB on {pair}") - if 
iscoroutinefunction(condition) and await condition(): - if wraps_async: - return await f(*args, **kwargs) - else: - return f(*args, **kwargs) + if iscoroutinefunction(condition): + if await condition(): + if wraps_async: + return await f(*args, **kwargs) + else: + return f(*args, **kwargs) elif condition(): if wraps_async: return await f(*args, **kwargs) diff --git a/test/asynchronous/test_client_context.py b/test/asynchronous/test_client_context.py index a0cb53a14..6d7781843 100644 --- a/test/asynchronous/test_client_context.py +++ b/test/asynchronous/test_client_context.py @@ -61,6 +61,13 @@ class TestAsyncClientContext(AsyncUnitTest): self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"]) + def test_free_threading_is_enabled(self): + if "free-threading build" not in sys.version: + raise SkipTest("this test requires the Python free-threading build") + + # If the GIL is enabled then pymongo or one of our deps does not support free-threading. + self.assertFalse(sys._is_gil_enabled()) # type: ignore[attr-defined] + if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_collation.py b/test/asynchronous/test_collation.py new file mode 100644 index 000000000..be3ea22e4 --- /dev/null +++ b/test/asynchronous/test_collation.py @@ -0,0 +1,290 @@ +# Copyright 2016-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test the collation module.""" +from __future__ import annotations + +import functools +import warnings +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils import EventListener +from typing import Any + +from pymongo.asynchronous.helpers import anext +from pymongo.collation import ( + Collation, + CollationAlternate, + CollationCaseFirst, + CollationMaxVariable, + CollationStrength, +) +from pymongo.errors import ConfigurationError +from pymongo.operations import ( + DeleteMany, + DeleteOne, + IndexModel, + ReplaceOne, + UpdateMany, + UpdateOne, +) +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class TestCollationObject(unittest.TestCase): + def test_constructor(self): + self.assertRaises(TypeError, Collation, locale=42) + # Fill in a locale to test the other options. + _Collation = functools.partial(Collation, "en_US") + # No error. + _Collation(caseFirst=CollationCaseFirst.UPPER) + self.assertRaises(TypeError, _Collation, caseLevel="true") + self.assertRaises(ValueError, _Collation, strength="six") + self.assertRaises(TypeError, _Collation, numericOrdering="true") + self.assertRaises(TypeError, _Collation, alternate=5) + self.assertRaises(TypeError, _Collation, maxVariable=2) + self.assertRaises(TypeError, _Collation, normalization="false") + self.assertRaises(TypeError, _Collation, backwards="true") + + # No errors. 
+ Collation("en_US", future_option="bar", another_option=42) + collation = Collation( + "en_US", + caseLevel=True, + caseFirst=CollationCaseFirst.UPPER, + strength=CollationStrength.QUATERNARY, + numericOrdering=True, + alternate=CollationAlternate.SHIFTED, + maxVariable=CollationMaxVariable.SPACE, + normalization=True, + backwards=True, + ) + + self.assertEqual( + { + "locale": "en_US", + "caseLevel": True, + "caseFirst": "upper", + "strength": 4, + "numericOrdering": True, + "alternate": "shifted", + "maxVariable": "space", + "normalization": True, + "backwards": True, + }, + collation.document, + ) + + self.assertEqual( + {"locale": "en_US", "backwards": True}, Collation("en_US", backwards=True).document + ) + + +class TestCollation(AsyncIntegrationTest): + listener: EventListener + warn_context: Any + collation: Collation + + @classmethod + @async_client_context.require_connection + async def _setup_class(cls): + await super()._setup_class() + cls.listener = EventListener() + cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) + cls.db = cls.client.pymongo_test + cls.collation = Collation("en_US") + cls.warn_context = warnings.catch_warnings() + cls.warn_context.__enter__() + warnings.simplefilter("ignore", DeprecationWarning) + + @classmethod + async def _tearDown_class(cls): + cls.warn_context.__exit__() + cls.warn_context = None + await cls.client.close() + await super()._tearDown_class() + + def tearDown(self): + self.listener.reset() + super().tearDown() + + def last_command_started(self): + return self.listener.started_events[-1].command + + def assertCollationInLastCommand(self): + self.assertEqual(self.collation.document, self.last_command_started()["collation"]) + + async def test_create_collection(self): + await self.db.test.drop() + await self.db.create_collection("test", collation=self.collation) + self.assertCollationInLastCommand() + + # Test passing collation as a dict as well. 
+ await self.db.test.drop() + self.listener.reset() + await self.db.create_collection("test", collation=self.collation.document) + self.assertCollationInLastCommand() + + def test_index_model(self): + model = IndexModel([("a", 1), ("b", -1)], collation=self.collation) + self.assertEqual(self.collation.document, model.document["collation"]) + + async def test_create_index(self): + await self.db.test.create_index("foo", collation=self.collation) + ci_cmd = self.listener.started_events[0].command + self.assertEqual(self.collation.document, ci_cmd["indexes"][0]["collation"]) + + async def test_aggregate(self): + await self.db.test.aggregate([{"$group": {"_id": 42}}], collation=self.collation) + self.assertCollationInLastCommand() + + async def test_count_documents(self): + await self.db.test.count_documents({}, collation=self.collation) + self.assertCollationInLastCommand() + + async def test_distinct(self): + await self.db.test.distinct("foo", collation=self.collation) + self.assertCollationInLastCommand() + + self.listener.reset() + await self.db.test.find(collation=self.collation).distinct("foo") + self.assertCollationInLastCommand() + + async def test_find_command(self): + await self.db.test.insert_one({"is this thing on?": True}) + self.listener.reset() + await anext(self.db.test.find(collation=self.collation)) + self.assertCollationInLastCommand() + + async def test_explain_command(self): + self.listener.reset() + await self.db.test.find(collation=self.collation).explain() + # The collation should be part of the explained command. 
+ self.assertEqual( + self.collation.document, self.last_command_started()["explain"]["collation"] + ) + + async def test_delete(self): + await self.db.test.delete_one({"foo": 42}, collation=self.collation) + command = self.listener.started_events[0].command + self.assertEqual(self.collation.document, command["deletes"][0]["collation"]) + + self.listener.reset() + await self.db.test.delete_many({"foo": 42}, collation=self.collation) + command = self.listener.started_events[0].command + self.assertEqual(self.collation.document, command["deletes"][0]["collation"]) + + async def test_update(self): + await self.db.test.replace_one({"foo": 42}, {"foo": 43}, collation=self.collation) + command = self.listener.started_events[0].command + self.assertEqual(self.collation.document, command["updates"][0]["collation"]) + + self.listener.reset() + await self.db.test.update_one({"foo": 42}, {"$set": {"foo": 43}}, collation=self.collation) + command = self.listener.started_events[0].command + self.assertEqual(self.collation.document, command["updates"][0]["collation"]) + + self.listener.reset() + await self.db.test.update_many({"foo": 42}, {"$set": {"foo": 43}}, collation=self.collation) + command = self.listener.started_events[0].command + self.assertEqual(self.collation.document, command["updates"][0]["collation"]) + + async def test_find_and(self): + await self.db.test.find_one_and_delete({"foo": 42}, collation=self.collation) + self.assertCollationInLastCommand() + + self.listener.reset() + await self.db.test.find_one_and_update( + {"foo": 42}, {"$set": {"foo": 43}}, collation=self.collation + ) + self.assertCollationInLastCommand() + + self.listener.reset() + await self.db.test.find_one_and_replace({"foo": 42}, {"foo": 43}, collation=self.collation) + self.assertCollationInLastCommand() + + async def test_bulk_write(self): + await self.db.test.collection.bulk_write( + [ + DeleteOne({"noCollation": 42}), + DeleteMany({"noCollation": 42}), + DeleteOne({"foo": 42}, 
collation=self.collation), + DeleteMany({"foo": 42}, collation=self.collation), + ReplaceOne({"noCollation": 24}, {"bar": 42}), + UpdateOne({"noCollation": 84}, {"$set": {"bar": 10}}, upsert=True), + UpdateMany({"noCollation": 45}, {"$set": {"bar": 42}}), + ReplaceOne({"foo": 24}, {"foo": 42}, collation=self.collation), + UpdateOne( + {"foo": 84}, {"$set": {"foo": 10}}, upsert=True, collation=self.collation + ), + UpdateMany({"foo": 45}, {"$set": {"foo": 42}}, collation=self.collation), + ] + ) + + delete_cmd = self.listener.started_events[0].command + update_cmd = self.listener.started_events[1].command + + def check_ops(ops): + for op in ops: + if "noCollation" in op["q"]: + self.assertNotIn("collation", op) + else: + self.assertEqual(self.collation.document, op["collation"]) + + check_ops(delete_cmd["deletes"]) + check_ops(update_cmd["updates"]) + + async def test_indexes_same_keys_different_collations(self): + await self.db.test.drop() + usa_collation = Collation("en_US") + ja_collation = Collation("ja") + await self.db.test.create_indexes( + [ + IndexModel("fieldname", collation=usa_collation), + IndexModel("fieldname", name="japanese_version", collation=ja_collation), + IndexModel("fieldname", name="simple"), + ] + ) + indexes = await self.db.test.index_information() + self.assertEqual( + usa_collation.document["locale"], indexes["fieldname_1"]["collation"]["locale"] + ) + self.assertEqual( + ja_collation.document["locale"], indexes["japanese_version"]["collation"]["locale"] + ) + self.assertNotIn("collation", indexes["simple"]) + await self.db.test.drop_index("fieldname_1") + indexes = await self.db.test.index_information() + self.assertIn("japanese_version", indexes) + self.assertIn("simple", indexes) + self.assertNotIn("fieldname", indexes) + + async def test_unacknowledged_write(self): + unacknowledged = WriteConcern(w=0) + collection = self.db.get_collection("test", write_concern=unacknowledged) + with self.assertRaises(ConfigurationError): + await 
collection.update_one( + {"hello": "world"}, {"$set": {"hello": "moon"}}, collation=self.collation + ) + update_one = UpdateOne( + {"hello": "world"}, {"$set": {"hello": "moon"}}, collation=self.collation + ) + with self.assertRaises(ConfigurationError): + await collection.bulk_write([update_one]) + + async def test_cursor_collation(self): + await self.db.test.insert_one({"hello": "world"}) + await anext(self.db.test.find().collation(self.collation)) + self.assertCollationInLastCommand() diff --git a/test/asynchronous/test_command_logging.py b/test/asynchronous/test_command_logging.py new file mode 100644 index 000000000..f9b459c15 --- /dev/null +++ b/test/asynchronous/test_command_logging.py @@ -0,0 +1,44 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the command logging unified format spec tests.""" +from __future__ import annotations + +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. 
+if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging") + + +globals().update( + generate_test_classes( + _TEST_PATH, + module=__name__, + ) +) + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_command_monitoring.py b/test/asynchronous/test_command_monitoring.py new file mode 100644 index 000000000..311fd1fdc --- /dev/null +++ b/test/asynchronous/test_command_monitoring.py @@ -0,0 +1,45 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the command monitoring unified format spec tests.""" +from __future__ import annotations + +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. 
+if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring") + + +globals().update( + generate_test_classes( + _TEST_PATH, + module=__name__, + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_common.py b/test/asynchronous/test_common.py new file mode 100644 index 000000000..00495e7c3 --- /dev/null +++ b/test/asynchronous/test_common.py @@ -0,0 +1,185 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test the pymongo common module.""" +from __future__ import annotations + +import sys +import uuid + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, connected, unittest + +from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation +from bson.codec_options import CodecOptions +from bson.objectid import ObjectId +from pymongo.errors import OperationFailure +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class TestCommon(AsyncIntegrationTest): + async def test_uuid_representation(self): + coll = self.db.uuid + await coll.drop() + + # Test property + self.assertEqual(UuidRepresentation.UNSPECIFIED, coll.codec_options.uuid_representation) + + # Test basic query + uu = uuid.uuid4() + # Insert as binary subtype 3 + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + legacy_opts = coll.codec_options + await coll.insert_one({"uu": uu}) + self.assertEqual(uu, (await coll.find_one({"uu": uu}))["uu"]) # type: ignore + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD)) + self.assertEqual(STANDARD, coll.codec_options.uuid_representation) + self.assertEqual(None, await coll.find_one({"uu": uu})) + uul = Binary.from_uuid(uu, PYTHON_LEGACY) + self.assertEqual(uul, (await coll.find_one({"uu": uul}))["uu"]) # type: ignore + + # Test count_documents + self.assertEqual(0, await coll.count_documents({"uu": uu})) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + self.assertEqual(1, await coll.count_documents({"uu": uu})) + + # Test delete + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD)) + await coll.delete_one({"uu": uu}) + self.assertEqual(1, await coll.count_documents({})) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + await coll.delete_one({"uu": uu}) + self.assertEqual(0, await coll.count_documents({})) 
+ + # Test update_one + await coll.insert_one({"_id": uu, "i": 1}) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD)) + await coll.update_one({"_id": uu}, {"$set": {"i": 2}}) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + self.assertEqual(1, (await coll.find_one({"_id": uu}))["i"]) # type: ignore + await coll.update_one({"_id": uu}, {"$set": {"i": 2}}) + self.assertEqual(2, (await coll.find_one({"_id": uu}))["i"]) # type: ignore + + # Test Cursor.distinct + self.assertEqual([2], await coll.find({"_id": uu}).distinct("i")) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD)) + self.assertEqual([], await coll.find({"_id": uu}).distinct("i")) + + # Test findAndModify + self.assertEqual(None, await coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}})) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + self.assertEqual(2, (await coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}}))["i"]) + self.assertEqual(5, (await coll.find_one({"_id": uu}))["i"]) # type: ignore + + # Test command + self.assertEqual( + 5, + ( + await self.db.command( + "findAndModify", + "uuid", + update={"$set": {"i": 6}}, + query={"_id": uu}, + codec_options=legacy_opts, + ) + )["value"]["i"], + ) + self.assertEqual( + 6, + ( + await self.db.command( + "findAndModify", + "uuid", + update={"$set": {"i": 7}}, + query={"_id": Binary.from_uuid(uu, PYTHON_LEGACY)}, + ) + )["value"]["i"], + ) + + async def test_write_concern(self): + c = await self.async_rs_or_single_client(connect=False) + self.assertEqual(WriteConcern(), c.write_concern) + + c = await self.async_rs_or_single_client(connect=False, w=2, wTimeoutMS=1000) + wc = WriteConcern(w=2, wtimeout=1000) + self.assertEqual(wc, c.write_concern) + + # Can we override back to the server default? 
+ db = c.get_database("pymongo_test", write_concern=WriteConcern()) + self.assertEqual(db.write_concern, WriteConcern()) + + db = c.pymongo_test + self.assertEqual(wc, db.write_concern) + coll = db.test + self.assertEqual(wc, coll.write_concern) + + cwc = WriteConcern(j=True) + coll = db.get_collection("test", write_concern=cwc) + self.assertEqual(cwc, coll.write_concern) + self.assertEqual(wc, db.write_concern) + + async def test_mongo_client(self): + pair = await async_client_context.pair + m = await self.async_rs_or_single_client(w=0) + coll = m.pymongo_test.write_concern_test + await coll.drop() + doc = {"_id": ObjectId()} + await coll.insert_one(doc) + self.assertTrue(await coll.insert_one(doc)) + coll = coll.with_options(write_concern=WriteConcern(w=1)) + with self.assertRaises(OperationFailure): + await coll.insert_one(doc) + + m = await self.async_rs_or_single_client() + coll = m.pymongo_test.write_concern_test + new_coll = coll.with_options(write_concern=WriteConcern(w=0)) + self.assertTrue(await new_coll.insert_one(doc)) + with self.assertRaises(OperationFailure): + await coll.insert_one(doc) + + m = await self.async_rs_or_single_client( + f"mongodb://{pair}/", replicaSet=async_client_context.replica_set_name + ) + + coll = m.pymongo_test.write_concern_test + with self.assertRaises(OperationFailure): + await coll.insert_one(doc) + m = await self.async_rs_or_single_client( + f"mongodb://{pair}/?w=0", replicaSet=async_client_context.replica_set_name + ) + + coll = m.pymongo_test.write_concern_test + await coll.insert_one(doc) + + # Equality tests + direct = await connected(await self.async_single_client(w=0)) + direct2 = await connected( + await self.async_single_client(f"mongodb://{pair}/?w=0", **self.credentials) + ) + self.assertEqual(direct, direct2) + self.assertFalse(direct != direct2) + + async def test_validate_boolean(self): + await self.db.test.update_one({}, {"$set": {"total": 1}}, upsert=True) + with self.assertRaisesRegex( + TypeError, "upsert 
must be True or False, was: upsert={'upsert': True}" + ): + await self.db.test.update_one({}, {"$set": {"total": 1}}, {"upsert": True}) # type: ignore + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_connection_logging.py b/test/asynchronous/test_connection_logging.py new file mode 100644 index 000000000..6bc9835b7 --- /dev/null +++ b/test/asynchronous/test_connection_logging.py @@ -0,0 +1,45 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the connection logging unified format spec tests.""" +from __future__ import annotations + +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test import unittest +from test.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging") + + +globals().update( + generate_test_classes( + _TEST_PATH, + module=__name__, + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_crud_unified.py b/test/asynchronous/test_crud_unified.py new file mode 100644 index 000000000..3d8deb36e --- /dev/null +++ b/test/asynchronous/test_crud_unified.py @@ -0,0 +1,39 @@ +# Copyright 2021-present MongoDB, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the CRUD unified spec tests.""" +from __future__ import annotations + +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test import unittest +from test.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified") + +# Generate unified tests. 
+globals().update(generate_test_classes(_TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index c3f622338..88b005c4b 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -30,6 +30,7 @@ import uuid import warnings from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context from test.asynchronous.test_bulk import AsyncBulkTestBase +from test.asynchronous.utils_spec_runner import AsyncSpecRunner, AsyncSpecTestCreator from threading import Thread from typing import Any, Dict, Mapping, Optional @@ -59,7 +60,6 @@ from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, OvertCommandListener, - SpecTestCreator, TopologyEventListener, async_wait_until, camel_to_snake_args, @@ -626,132 +626,132 @@ AWS_TEMP_NO_SESSION_CREDS = { KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}} -if _IS_SYNC: - # TODO: Add asynchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700) - class TestSpec(AsyncSpecRunner): - @classmethod - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - def setUpClass(cls): - super().setUpClass() +class AsyncTestSpec(AsyncSpecRunner): + @classmethod + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") + async def _setup_class(cls): + await super()._setup_class() - def parse_auto_encrypt_opts(self, opts): - """Parse clientOptions.autoEncryptOpts.""" - opts = camel_to_snake_args(opts) - kms_providers = opts["kms_providers"] - if "aws" in kms_providers: - kms_providers["aws"] = AWS_CREDS - if not any(AWS_CREDS.values()): - self.skipTest("AWS environment credentials are not set") - if "awsTemporary" in kms_providers: - kms_providers["aws"] = AWS_TEMP_CREDS - del kms_providers["awsTemporary"] - if not 
any(AWS_TEMP_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "awsTemporaryNoSessionToken" in kms_providers: - kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS - del kms_providers["awsTemporaryNoSessionToken"] - if not any(AWS_TEMP_NO_SESSION_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "azure" in kms_providers: - kms_providers["azure"] = AZURE_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("Azure environment credentials are not set") - if "gcp" in kms_providers: - kms_providers["gcp"] = GCP_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("GCP environment credentials are not set") - if "kmip" in kms_providers: - kms_providers["kmip"] = KMIP_CREDS - opts["kms_tls_options"] = KMS_TLS_OPTS - if "key_vault_namespace" not in opts: - opts["key_vault_namespace"] = "keyvault.datakeys" - if "extra_options" in opts: - opts.update(camel_to_snake_args(opts.pop("extra_options"))) + def parse_auto_encrypt_opts(self, opts): + """Parse clientOptions.autoEncryptOpts.""" + opts = camel_to_snake_args(opts) + kms_providers = opts["kms_providers"] + if "aws" in kms_providers: + kms_providers["aws"] = AWS_CREDS + if not any(AWS_CREDS.values()): + self.skipTest("AWS environment credentials are not set") + if "awsTemporary" in kms_providers: + kms_providers["aws"] = AWS_TEMP_CREDS + del kms_providers["awsTemporary"] + if not any(AWS_TEMP_CREDS.values()): + self.skipTest("AWS Temp environment credentials are not set") + if "awsTemporaryNoSessionToken" in kms_providers: + kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS + del kms_providers["awsTemporaryNoSessionToken"] + if not any(AWS_TEMP_NO_SESSION_CREDS.values()): + self.skipTest("AWS Temp environment credentials are not set") + if "azure" in kms_providers: + kms_providers["azure"] = AZURE_CREDS + if not any(AZURE_CREDS.values()): + self.skipTest("Azure environment credentials are not set") + if "gcp" in kms_providers: + 
kms_providers["gcp"] = GCP_CREDS + if not any(GCP_CREDS.values()): + self.skipTest("GCP environment credentials are not set") + if "kmip" in kms_providers: + kms_providers["kmip"] = KMIP_CREDS + opts["kms_tls_options"] = KMS_TLS_OPTS + if "key_vault_namespace" not in opts: + opts["key_vault_namespace"] = "keyvault.datakeys" + if "extra_options" in opts: + opts.update(camel_to_snake_args(opts.pop("extra_options"))) - opts = dict(opts) - return AutoEncryptionOpts(**opts) + opts = dict(opts) + return AutoEncryptionOpts(**opts) - def parse_client_options(self, opts): - """Override clientOptions parsing to support autoEncryptOpts.""" - encrypt_opts = opts.pop("autoEncryptOpts", None) - if encrypt_opts: - opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) + def parse_client_options(self, opts): + """Override clientOptions parsing to support autoEncryptOpts.""" + encrypt_opts = opts.pop("autoEncryptOpts", None) + if encrypt_opts: + opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) - return super().parse_client_options(opts) + return super().parse_client_options(opts) - def get_object_name(self, op): - """Default object is collection.""" - return op.get("object", "collection") + def get_object_name(self, op): + """Default object is collection.""" + return op.get("object", "collection") - def maybe_skip_scenario(self, test): - super().maybe_skip_scenario(test) - desc = test["description"].lower() - if ( - "timeoutms applied to listcollections to get collection schema" in desc - and sys.platform in ("win32", "darwin") - ): - self.skipTest("PYTHON-3706 flaky test on Windows/macOS") - if "type=symbol" in desc: - self.skipTest("PyMongo does not support the symbol type") + def maybe_skip_scenario(self, test): + super().maybe_skip_scenario(test) + desc = test["description"].lower() + if ( + "timeoutms applied to listcollections to get collection schema" in desc + and sys.platform in ("win32", "darwin") + ): +
self.skipTest("PYTHON-3706 flaky test on Windows/macOS") + if "type=symbol" in desc: + self.skipTest("PyMongo does not support the symbol type") + if "timeoutms applied to listcollections to get collection schema" in desc and not _IS_SYNC: + self.skipTest("PYTHON-4844 flaky test on async") - def setup_scenario(self, scenario_def): - """Override a test's setup.""" - key_vault_data = scenario_def["key_vault_data"] - encrypted_fields = scenario_def["encrypted_fields"] - json_schema = scenario_def["json_schema"] - data = scenario_def["data"] - coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)[ - "datakeys" - ] - coll.delete_many({}) - if key_vault_data: - coll.insert_many(key_vault_data) + async def setup_scenario(self, scenario_def): + """Override a test's setup.""" + key_vault_data = scenario_def["key_vault_data"] + encrypted_fields = scenario_def["encrypted_fields"] + json_schema = scenario_def["json_schema"] + data = scenario_def["data"] + coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"] + await coll.delete_many({}) + if key_vault_data: + await coll.insert_many(key_vault_data) - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - db = async_client_context.client.get_database(db_name, codec_options=OPTS) - coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields) - wc = WriteConcern(w="majority") - kwargs: Dict[str, Any] = {} - if json_schema: - kwargs["validator"] = {"$jsonSchema": json_schema} - kwargs["codec_options"] = OPTS - if not data: - kwargs["write_concern"] = wc - if encrypted_fields: - kwargs["encryptedFields"] = encrypted_fields - db.create_collection(coll_name, **kwargs) - coll = db[coll_name] - if data: - # Load data. 
- coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) + db_name = self.get_scenario_db_name(scenario_def) + coll_name = self.get_scenario_coll_name(scenario_def) + db = async_client_context.client.get_database(db_name, codec_options=OPTS) + await db.drop_collection(coll_name, encrypted_fields=encrypted_fields) + wc = WriteConcern(w="majority") + kwargs: Dict[str, Any] = {} + if json_schema: + kwargs["validator"] = {"$jsonSchema": json_schema} + kwargs["codec_options"] = OPTS + if not data: + kwargs["write_concern"] = wc + if encrypted_fields: + kwargs["encryptedFields"] = encrypted_fields + await db.create_collection(coll_name, **kwargs) + coll = db[coll_name] + if data: + # Load data. + await coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) - def allowable_errors(self, op): - """Override expected error classes.""" - errors = super().allowable_errors(op) - # An updateOne test expects encryption to error when no $ operator - # appears but pymongo raises a client side ValueError in this case. - if op["name"] == "updateOne": - errors += (ValueError,) - return errors + def allowable_errors(self, op): + """Override expected error classes.""" + errors = super().allowable_errors(op) + # An updateOne test expects encryption to error when no $ operator + # appears but pymongo raises a client side ValueError in this case. 
+ if op["name"] == "updateOne": + errors += (ValueError,) + return errors - def create_test(scenario_def, test, name): - @async_client_context.require_test_commands - def run_scenario(self): - self.run_scenario(scenario_def, test) - return run_scenario +async def create_test(scenario_def, test, name): + @async_client_context.require_test_commands + async def run_scenario(self): + await self.run_scenario(scenario_def, test) - test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy")) - test_creator.create_tests() + return run_scenario - if _HAVE_PYMONGOCRYPT: - globals().update( - generate_test_classes( - os.path.join(SPEC_PATH, "unified"), - module=__name__, - ) + +test_creator = AsyncSpecTestCreator(create_test, AsyncTestSpec, os.path.join(SPEC_PATH, "legacy")) +test_creator.create_tests() + +if _HAVE_PYMONGOCRYPT: + globals().update( + generate_test_classes( + os.path.join(SPEC_PATH, "unified"), + module=__name__, ) + ) # Prose Tests ALL_KMS_PROVIDERS = { diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py new file mode 100644 index 000000000..4c3742295 --- /dev/null +++ b/test/asynchronous/unified_format.py @@ -0,0 +1,1573 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified test format runner. 
+ +https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst +""" +from __future__ import annotations + +import asyncio +import binascii +import copy +import functools +import os +import re +import sys +import time +import traceback +from asyncio import iscoroutinefunction +from collections import defaultdict +from test.asynchronous import ( + AsyncIntegrationTest, + async_client_context, + client_knobs, + unittest, +) +from test.unified_format_shared import ( + IS_INTERRUPTED, + KMS_TLS_OPTS, + PLACEHOLDER_MAP, + SKIP_CSOT_TESTS, + EventListenerUtil, + MatchEvaluatorUtil, + coerce_result, + parse_bulk_write_error_result, + parse_bulk_write_result, + parse_client_bulk_write_error_result, + parse_collection_or_database_options, + with_metaclass, +) +from test.utils import ( + async_get_pool, + camel_to_snake, + camel_to_snake_args, + parse_spec_options, + prepare_spec_arguments, + snake_to_camel, + wait_until, +) +from test.utils_spec_runner import SpecRunnerThread +from test.version import Version +from typing import Any, Dict, List, Mapping, Optional + +import pymongo +from bson import SON, json_util +from bson.codec_options import DEFAULT_CODEC_OPTIONS +from bson.objectid import ObjectId +from gridfs import AsyncGridFSBucket, GridOut +from pymongo import ASCENDING, AsyncMongoClient, CursorType, _csot +from pymongo.asynchronous.change_stream import AsyncChangeStream +from pymongo.asynchronous.client_session import AsyncClientSession, TransactionOptions, _TxnState +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.database import AsyncDatabase +from pymongo.asynchronous.encryption import AsyncClientEncryption +from pymongo.asynchronous.helpers import anext +from pymongo.encryption_options import _HAVE_PYMONGOCRYPT +from pymongo.errors import ( + BulkWriteError, + ClientBulkWriteException, + ConfigurationError, + 
ConnectionFailure, + EncryptionError, + InvalidOperation, + NotPrimaryError, + OperationFailure, + PyMongoError, +) +from pymongo.monitoring import ( + CommandStartedEvent, +) +from pymongo.operations import ( + SearchIndexModel, +) +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference +from pymongo.server_api import ServerApi +from pymongo.server_selectors import Selection, writable_server_selector +from pymongo.server_type import SERVER_TYPE +from pymongo.topology_description import TopologyDescription +from pymongo.typings import _Address +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +async def is_run_on_requirement_satisfied(requirement): + topology_satisfied = True + req_topologies = requirement.get("topologies") + if req_topologies: + topology_satisfied = await async_client_context.is_topology_type(req_topologies) + + server_version = Version(*async_client_context.version[:3]) + + min_version_satisfied = True + req_min_server_version = requirement.get("minServerVersion") + if req_min_server_version: + min_version_satisfied = Version.from_string(req_min_server_version) <= server_version + + max_version_satisfied = True + req_max_server_version = requirement.get("maxServerVersion") + if req_max_server_version: + max_version_satisfied = Version.from_string(req_max_server_version) >= server_version + + serverless = requirement.get("serverless") + if serverless == "require": + serverless_satisfied = async_client_context.serverless + elif serverless == "forbid": + serverless_satisfied = not async_client_context.serverless + else: # unset or "allow" + serverless_satisfied = True + + params_satisfied = True + params = requirement.get("serverParameters") + if params: + for param, val in params.items(): + if param not in async_client_context.server_parameters: + params_satisfied = False + elif async_client_context.server_parameters[param] != val: + params_satisfied = False + + auth_satisfied = True 
+ req_auth = requirement.get("auth") + if req_auth is not None: + if req_auth: + auth_satisfied = async_client_context.auth_enabled + if auth_satisfied and "authMechanism" in requirement: + auth_satisfied = async_client_context.check_auth_type(requirement["authMechanism"]) + else: + auth_satisfied = not async_client_context.auth_enabled + + csfle_satisfied = True + req_csfle = requirement.get("csfle") + if req_csfle is True: + csfle_min_version_satisfied = Version.from_string("4.2") <= server_version + csfle_satisfied = _HAVE_PYMONGOCRYPT and csfle_min_version_satisfied + + return ( + topology_satisfied + and min_version_satisfied + and max_version_satisfied + and serverless_satisfied + and params_satisfied + and auth_satisfied + and csfle_satisfied + ) + + +class NonLazyCursor: + """A find cursor proxy that creates the remote cursor when initialized.""" + + def __init__(self, find_cursor, client): + self.client = client + self.find_cursor = find_cursor + # Create the server side cursor. + self.first_result = None + + @classmethod + async def create(cls, find_cursor, client): + cursor = cls(find_cursor, client) + try: + cursor.first_result = await anext(cursor.find_cursor) + except StopAsyncIteration: + cursor.first_result = None + return cursor + + @property + def alive(self): + return self.first_result is not None or self.find_cursor.alive + + async def __anext__(self): + if self.first_result is not None: + first = self.first_result + self.first_result = None + return first + return await anext(self.find_cursor) + + # Added to support the iterateOnce operation. + try_next = __anext__ + + async def close(self): + await self.find_cursor.close() + self.client = None + + +class EntityMapUtil: + """Utility class that implements an entity map as per the unified + test format specification.
+ """ + + def __init__(self, test_class): + self._entities: Dict[str, Any] = {} + self._listeners: Dict[str, EventListenerUtil] = {} + self._session_lsids: Dict[str, Mapping[str, Any]] = {} + self.test: UnifiedSpecTestMixinV1 = test_class + self._cluster_time: Mapping[str, Any] = {} + + def __contains__(self, item): + return item in self._entities + + def __len__(self): + return len(self._entities) + + def __getitem__(self, item): + try: + return self._entities[item] + except KeyError: + self.test.fail(f"Could not find entity named {item} in map") + + def __setitem__(self, key, value): + if not isinstance(key, str): + self.test.fail("Expected entity name of type str, got %s" % (type(key))) + + if key in self._entities: + self.test.fail(f"Entity named {key} already in map") + + self._entities[key] = value + + def _handle_placeholders(self, spec: dict, current: dict, path: str) -> Any: + if "$$placeholder" in current: + if path not in PLACEHOLDER_MAP: + raise ValueError(f"Could not find a placeholder value for {path}") + return PLACEHOLDER_MAP[path] + + for key in list(current): + value = current[key] + if isinstance(value, dict): + subpath = f"{path}/{key}" + current[key] = self._handle_placeholders(spec, value, subpath) + return current + + async def _create_entity(self, entity_spec, uri=None): + if len(entity_spec) != 1: + self.test.fail(f"Entity spec {entity_spec} did not contain exactly one top-level key") + + entity_type, spec = next(iter(entity_spec.items())) + spec = self._handle_placeholders(spec, spec, "") + if entity_type == "client": + kwargs: dict = {} + observe_events = spec.get("observeEvents", []) + + # The unified tests use topologyOpeningEvent, we use topologyOpenedEvent + for i in range(len(observe_events)): + if "topologyOpeningEvent" == observe_events[i]: + observe_events[i] = "topologyOpenedEvent" + ignore_commands = spec.get("ignoreCommandMonitoringEvents", []) + observe_sensitive_commands = spec.get("observeSensitiveCommands", False) + 
ignore_commands = [cmd.lower() for cmd in ignore_commands] + listener = EventListenerUtil( + observe_events, + ignore_commands, + observe_sensitive_commands, + spec.get("storeEventsAsEntities"), + self, + ) + self._listeners[spec["id"]] = listener + kwargs["event_listeners"] = [listener] + if spec.get("useMultipleMongoses"): + if async_client_context.load_balancer or async_client_context.serverless: + kwargs["h"] = async_client_context.MULTI_MONGOS_LB_URI + elif async_client_context.is_mongos: + kwargs["h"] = async_client_context.mongos_seeds() + kwargs.update(spec.get("uriOptions", {})) + server_api = spec.get("serverApi") + if "waitQueueSize" in kwargs: + raise unittest.SkipTest("PyMongo does not support waitQueueSize") + if "waitQueueMultiple" in kwargs: + raise unittest.SkipTest("PyMongo does not support waitQueueMultiple") + if server_api: + kwargs["server_api"] = ServerApi( + server_api["version"], + strict=server_api.get("strict"), + deprecation_errors=server_api.get("deprecationErrors"), + ) + if uri: + kwargs["h"] = uri + client = await self.test.async_rs_or_single_client(**kwargs) + self[spec["id"]] = client + self.test.addAsyncCleanup(client.close) + return + elif entity_type == "database": + client = self[spec["client"]] + if type(client).__name__ != "AsyncMongoClient": + self.test.fail( + "Expected entity {} to be of type AsyncMongoClient, got {}".format( + spec["client"], type(client) + ) + ) + options = parse_collection_or_database_options(spec.get("databaseOptions", {})) + self[spec["id"]] = client.get_database(spec["databaseName"], **options) + return + elif entity_type == "collection": + database = self[spec["database"]] + if not isinstance(database, AsyncDatabase): + self.test.fail( + "Expected entity {} to be of type AsyncDatabase, got {}".format( + spec["database"], type(database) + ) + ) + options = parse_collection_or_database_options(spec.get("collectionOptions", {})) + self[spec["id"]] = database.get_collection(spec["collectionName"], 
**options) + return + elif entity_type == "session": + client = self[spec["client"]] + if type(client).__name__ != "AsyncMongoClient": + self.test.fail( + "Expected entity {} to be of type AsyncMongoClient, got {}".format( + spec["client"], type(client) + ) + ) + opts = camel_to_snake_args(spec.get("sessionOptions", {})) + if "default_transaction_options" in opts: + txn_opts = parse_spec_options(opts["default_transaction_options"]) + txn_opts = TransactionOptions(**txn_opts) + opts = copy.deepcopy(opts) + opts["default_transaction_options"] = txn_opts + session = client.start_session(**dict(opts)) + self[spec["id"]] = session + self._session_lsids[spec["id"]] = copy.deepcopy(session.session_id) + self.test.addAsyncCleanup(session.end_session) + return + elif entity_type == "bucket": + db = self[spec["database"]] + kwargs = parse_spec_options(spec.get("bucketOptions", {}).copy()) + bucket = AsyncGridFSBucket(db, **kwargs) + + # PyMongo does not support AsyncGridFSBucket.drop(), emulate it. + @_csot.apply + async def drop(self: AsyncGridFSBucket, *args: Any, **kwargs: Any) -> None: + await self._files.drop(*args, **kwargs) + await self._chunks.drop(*args, **kwargs) + + if not hasattr(bucket, "drop"): + bucket.drop = drop.__get__(bucket) + self[spec["id"]] = bucket + return + elif entity_type == "clientEncryption": + opts = camel_to_snake_args(spec["clientEncryptionOpts"].copy()) + if isinstance(opts["key_vault_client"], str): + opts["key_vault_client"] = self[opts["key_vault_client"]] + # Set TLS options for providers like "kmip:name1". 
+ kms_tls_options = {} + for provider in opts["kms_providers"]: + provider_type = provider.split(":")[0] + if provider_type in KMS_TLS_OPTS: + kms_tls_options[provider] = KMS_TLS_OPTS[provider_type] + self[spec["id"]] = AsyncClientEncryption( + opts["kms_providers"], + opts["key_vault_namespace"], + opts["key_vault_client"], + DEFAULT_CODEC_OPTIONS, + opts.get("kms_tls_options", kms_tls_options), + ) + return + elif entity_type == "thread": + name = spec["id"] + thread = SpecRunnerThread(name) + thread.start() + self[name] = thread + return + + self.test.fail(f"Unable to create entity of unknown type {entity_type}") + + async def create_entities_from_spec(self, entity_spec, uri=None): + for spec in entity_spec: + await self._create_entity(spec, uri=uri) + + def get_listener_for_client(self, client_name: str) -> EventListenerUtil: + client = self[client_name] + if type(client).__name__ != "AsyncMongoClient": + self.test.fail( + f"Expected entity {client_name} to be of type AsyncMongoClient, got {type(client)}" + ) + + listener = self._listeners.get(client_name) + if not listener: + self.test.fail(f"No listeners configured for client {client_name}") + + return listener + + def get_lsid_for_session(self, session_name): + session = self[session_name] + if not isinstance(session, AsyncClientSession): + self.test.fail( + f"Expected entity {session_name} to be of type AsyncClientSession, got {type(session)}" + ) + + try: + return session.session_id + except InvalidOperation: + # session has been closed. 
+ return self._session_lsids[session_name] + + async def advance_cluster_times(self) -> None: + """Manually synchronize entities when desired""" + if not self._cluster_time: + self._cluster_time = (await self.test.client.admin.command("ping")).get("$clusterTime") + for entity in self._entities.values(): + if isinstance(entity, AsyncClientSession) and self._cluster_time: + entity.advance_cluster_time(self._cluster_time) + + +class UnifiedSpecTestMixinV1(AsyncIntegrationTest): + """Mixin class to run test cases from test specification files. + + Assumes that tests conform to the `unified test format + `_. + + Specification of the test suite being currently run is available as + a class attribute ``TEST_SPEC``. + """ + + SCHEMA_VERSION = Version.from_string("1.21") + RUN_ON_LOAD_BALANCER = True + RUN_ON_SERVERLESS = True + TEST_SPEC: Any + mongos_clients: list[AsyncMongoClient] = [] + + @staticmethod + async def should_run_on(run_on_spec): + if not run_on_spec: + # Always run these tests. + return True + + for req in run_on_spec: + if await is_run_on_requirement_satisfied(req): + return True + return False + + async def insert_initial_data(self, initial_data): + for i, collection_data in enumerate(initial_data): + coll_name = collection_data["collectionName"] + db_name = collection_data["databaseName"] + opts = collection_data.get("createOptions", {}) + documents = collection_data["documents"] + + # Setup the collection with as few majority writes as possible. + db = self.client[db_name] + await db.drop_collection(coll_name) + # Only use majority wc only on the final write. 
+ if i == len(initial_data) - 1: + wc = WriteConcern(w="majority") + else: + wc = WriteConcern(w=1) + if documents: + if opts: + await db.create_collection(coll_name, **opts) + await db.get_collection(coll_name, write_concern=wc).insert_many(documents) + else: + # Ensure collection exists + await db.create_collection(coll_name, write_concern=wc, **opts) + + @classmethod + async def _setup_class(cls): + # super call creates internal client cls.client + await super()._setup_class() + # process file-level runOnRequirements + run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) + if not await cls.should_run_on(run_on_spec): + raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied") + + # add any special-casing for skipping tests here + if async_client_context.storage_engine == "mmapv1": + if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str( + cls.TEST_PATH + ): + raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") + + # Handle mongos_clients for transactions tests. + cls.mongos_clients = [] + if ( + async_client_context.supports_transactions() + and not async_client_context.load_balancer + and not async_client_context.serverless + ): + for address in async_client_context.mongoses: + cls.mongos_clients.append( + await cls.unmanaged_async_single_client("{}:{}".format(*address)) + ) + + # Speed up the tests by decreasing the heartbeat frequency. 
+ cls.knobs = client_knobs( + heartbeat_frequency=0.1, + min_heartbeat_interval=0.1, + kill_cursor_frequency=0.1, + events_queue_frequency=0.1, + ) + cls.knobs.enable() + + @classmethod + async def _tearDown_class(cls): + cls.knobs.disable() + for client in cls.mongos_clients: + await client.close() + await super()._tearDown_class() + + async def asyncSetUp(self): + await super().asyncSetUp() + # process schemaVersion + # note: we check major schema version during class generation + # note: we do this here because we cannot run assertions in setUpClass + version = Version.from_string(self.TEST_SPEC["schemaVersion"]) + self.assertLessEqual( + version, + self.SCHEMA_VERSION, + f"expected schema version {self.SCHEMA_VERSION} or lower, got {version}", + ) + + # initialize internals + self.match_evaluator = MatchEvaluatorUtil(self) + + def maybe_skip_test(self, spec): + # add any special-casing for skipping tests here + if async_client_context.storage_engine == "mmapv1": + if ( + "Dirty explicit session is discarded" in spec["description"] + or "Dirty implicit session is discarded" in spec["description"] + or "Cancel server check" in spec["description"] + ): + self.skipTest("MMAPv1 does not support retryWrites=True") + if ( + "AsyncDatabase-level aggregate with $out includes read preference for 5.0+ server" + in spec["description"] + ): + if async_client_context.version[0] == 8: + self.skipTest("waiting on PYTHON-4356") + if "Aggregate with $out includes read preference for 5.0+ server" in spec["description"]: + if async_client_context.version[0] == 8: + self.skipTest("waiting on PYTHON-4356") + if "Client side error in command starting transaction" in spec["description"]: + self.skipTest("Implement PYTHON-1894") + if "timeoutMS applied to entire download" in spec["description"]: + self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime") + + class_name = self.__class__.__name__.lower() + description = spec["description"].lower() + if "csot" in 
def process_error(self, exception, spec):
    """Check a raised exception against an ``expectError`` document.

    ``spec`` may contain ``isError``, ``isClientError``, ``isTimeoutError``,
    ``errorContains``, ``errorCode``, ``errorCodeName``,
    ``errorLabelsContain``, ``errorLabelsOmit``, ``expectResult`` and
    ``errorResponse``; each key that is present and truthy adds one check.

    Returns ``exception`` when every requested check passes. A
    ``unittest.SkipTest`` is re-raised untouched, and a non-timeout error is
    re-raised when ``isTimeoutError`` was requested.

    Must be called while ``exception`` is being handled, because the
    ``SkipTest`` branch uses a bare ``raise``.
    """
    if isinstance(exception, unittest.SkipTest):
        raise
    is_error = spec.get("isError")
    is_client_error = spec.get("isClientError")
    is_timeout_error = spec.get("isTimeoutError")
    error_contains = spec.get("errorContains")
    error_code = spec.get("errorCode")
    error_code_name = spec.get("errorCodeName")
    error_labels_contain = spec.get("errorLabelsContain")
    error_labels_omit = spec.get("errorLabelsOmit")
    expect_result = spec.get("expectResult")
    error_response = spec.get("errorResponse")

    if error_response:
        # A client-level bulk write wraps the server error in ``.error``.
        if isinstance(exception, ClientBulkWriteException):
            self.match_evaluator.match_result(error_response, exception.error.details)
        else:
            self.match_evaluator.match_result(error_response, exception.details)

    if is_error:
        # already satisfied because exception was raised
        pass

    if is_client_error:
        if isinstance(exception, ClientBulkWriteException):
            error = exception.error
        else:
            error = exception
        # Connection errors are considered client errors.
        if isinstance(error, ConnectionFailure):
            self.assertNotIsInstance(error, NotPrimaryError)
        elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError)):
            pass
        else:
            self.assertNotIsInstance(error, PyMongoError)

    if is_timeout_error:
        self.assertIsInstance(exception, PyMongoError)
        if not exception.timeout:
            # Re-raise the exception for better diagnostics.
            raise exception

    if error_contains:
        # Both bulk write error types are matched against their details.
        if isinstance(exception, (BulkWriteError, ClientBulkWriteException)):
            errmsg = str(exception.details).lower()
        else:
            errmsg = str(exception).lower()
        self.assertIn(error_contains.lower(), errmsg)

    if error_code:
        if isinstance(exception, ClientBulkWriteException):
            self.assertEqual(error_code, exception.error.details.get("code"))
        else:
            self.assertEqual(error_code, exception.details.get("code"))

    if error_code_name:
        if isinstance(exception, ClientBulkWriteException):
            # Compare the expected code *name*; this branch previously
            # compared the numeric errorCode against "codeName".
            self.assertEqual(error_code_name, exception.error.details.get("codeName"))
        else:
            self.assertEqual(error_code_name, exception.details.get("codeName"))

    if error_labels_contain:
        if isinstance(exception, ClientBulkWriteException):
            error = exception.error
        else:
            error = exception
        labels = [
            err_label for err_label in error_labels_contain if error.has_error_label(err_label)
        ]
        self.assertEqual(labels, error_labels_contain)

    if error_labels_omit:
        for err_label in error_labels_omit:
            if exception.has_error_label(err_label):
                self.fail(f"Exception '{exception}' unexpectedly had label '{err_label}'")

    if expect_result:
        if isinstance(exception, BulkWriteError):
            result = parse_bulk_write_error_result(exception)
            self.match_evaluator.match_result(expect_result, result)
        elif isinstance(exception, ClientBulkWriteException):
            result = parse_client_bulk_write_error_result(exception)
            self.match_evaluator.match_result(expect_result, result)
        else:
            self.fail(
                f"expectResult can only be specified with {BulkWriteError} or {ClientBulkWriteException} exceptions"
            )

    return exception
async def _databaseOperation_runCursorCommand(self, target, **kwargs):
    """Run the runCursorCommand operation and return every document.

    Delegates cursor creation to ``_databaseOperation_createCommandCursor``
    and drains the result with ``to_list()``. The cursor is asynchronous, so
    it cannot be consumed with the builtin ``list()``.
    """
    cursor = await self._databaseOperation_createCommandCursor(target, **kwargs)
    return await cursor.to_list()
async def _databaseOperation_listCollections(self, target, *args, **kwargs):
    """Run the listCollections operation and return the documents as a list.

    A ``batch_size`` argument is translated into the ``cursor`` option
    (``{"batchSize": n}``) before calling ``target.list_collections``.
    """
    if "batch_size" in kwargs:
        kwargs["cursor"] = {"batchSize": kwargs.pop("batch_size")}
    cursor = await target.list_collections(*args, **kwargs)
    # The awaited cursor is asynchronous; drain it with to_list() rather
    # than the builtin list(), which requires a synchronous iterable.
    return await cursor.to_list()
async def _collectionOperation_createFindCursor(self, target, *args, **kwargs):
    """Run the createFindCursor operation and return the opened cursor.

    Fails the test when ``target`` is not an ``AsyncCollection`` or when no
    ``filter`` argument was supplied. The cursor's ``close`` is registered as
    an async cleanup.
    """
    # Report the operation that was actually requested; this previously
    # passed "find", so the failure message named the wrong operation.
    self.__raise_if_unsupported("createFindCursor", target, AsyncCollection)
    if "filter" not in kwargs:
        self.fail('createFindCursor requires a "filter" argument')
    cursor = await NonLazyCursor.create(target.find(*args, **kwargs), target.database.client)
    self.addAsyncCleanup(cursor.close)
    return cursor
async def _cursor_iterateUntilDocumentOrError(self, target, *args, **kwargs):
    """Return the next document from ``target``, or None once it is dead.

    Keeps polling while ``target.alive`` is true. A ``StopAsyncIteration``
    from an empty batch is swallowed and the loop tries again; any other
    exception propagates to the caller.
    """
    self.__raise_if_unsupported(
        "iterateUntilDocumentOrError", target, NonLazyCursor, AsyncCommandCursor
    )
    while True:
        if not target.alive:
            return None
        try:
            document = await anext(target)
        except StopAsyncIteration:
            # Nothing available yet; re-check liveness and retry.
            continue
        return document
async def _bucketOperation_upload(
    self, target: AsyncGridFSBucket, *args: Any, **kwargs: Any
) -> ObjectId:
    """Run the GridFS upload operation via ``target.upload_from_stream``.

    The ``source`` argument arrives as ``{"$$hexBytes": "<hex>"}`` and is
    decoded to bytes. A ``content_type`` argument is moved into
    ``metadata["contentType"]``.
    """
    hex_payload = kwargs.pop("source")["$$hexBytes"]
    kwargs["source"] = binascii.unhexlify(hex_payload)
    if "content_type" in kwargs:
        content_type = kwargs.pop("content_type")
        metadata = kwargs.setdefault("metadata", {})
        metadata["contentType"] = content_type
    return await target.upload_from_stream(*args, **kwargs)
async def run_entity_operation(self, spec):
    """Run one operation from a test against the entity it names.

    ``spec`` names the target entity (``object``), the operation (``name``)
    and optionally ``arguments``, ``expectError``, ``expectResult``,
    ``saveResultAsEntity`` and ``ignoreResultAndError``.

    Returns the exception when the operation fails and the error was either
    ignored or expected and validated; returns None otherwise. Raises
    ``ValueError`` when ``ignoreResultAndError`` is combined with any of the
    expectation keys, and re-raises unexpected operation errors.
    """
    target = self.entity_map[spec["object"]]
    opname = spec["name"]
    opargs = spec.get("arguments")
    expect_error = spec.get("expectError")
    save_as_entity = spec.get("saveResultAsEntity")
    expect_result = spec.get("expectResult")
    ignore = spec.get("ignoreResultAndError")
    if ignore and (expect_error or save_as_entity or expect_result):
        raise ValueError(
            "ignoreResultAndError is incompatible with saveResultAsEntity"
            ", expectError, and expectResult"
        )
    if opargs:
        # Work on a copy so the spec document itself is never mutated.
        arguments = parse_spec_options(copy.deepcopy(opargs))
        prepare_spec_arguments(
            spec,
            arguments,
            camel_to_snake(opname),
            self.entity_map,
            self.run_operations_and_throw,
        )
    else:
        arguments = {}

    # Pick the runner-side override method for this entity type, if any.
    if isinstance(target, AsyncMongoClient):
        method_name = f"_clientOperation_{opname}"
    elif isinstance(target, AsyncDatabase):
        method_name = f"_databaseOperation_{opname}"
    elif isinstance(target, AsyncCollection):
        method_name = f"_collectionOperation_{opname}"
        # contentType is always stored in metadata in pymongo.
        if target.name.endswith(".files") and opname == "find":
            for doc in spec.get("expectResult", []):
                if "contentType" in doc:
                    doc.setdefault("metadata", {})["contentType"] = doc.pop("contentType")
    elif isinstance(target, AsyncChangeStream):
        method_name = f"_changeStreamOperation_{opname}"
    elif isinstance(target, (NonLazyCursor, AsyncCommandCursor)):
        method_name = f"_cursor_{opname}"
    elif isinstance(target, AsyncClientSession):
        method_name = f"_sessionOperation_{opname}"
    elif isinstance(target, AsyncGridFSBucket):
        method_name = f"_bucketOperation_{opname}"
        if "id" in arguments:
            arguments["file_id"] = arguments.pop("id")
        # MD5 is always disabled in pymongo.
        arguments.pop("disable_md5", None)
    elif isinstance(target, AsyncClientEncryption):
        method_name = f"_clientEncryptionOperation_{opname}"
    else:
        method_name = "doesNotExist"

    try:
        method = getattr(self, method_name)
    except AttributeError:
        # No override: call the snake_case method on the entity itself.
        target_opname = camel_to_snake(opname)
        if target_opname == "iterate_once":
            target_opname = "try_next"
        if target_opname == "client_bulk_write":
            target_opname = "bulk_write"
        try:
            cmd = getattr(target, target_opname)
        except AttributeError:
            self.fail(f"Unsupported operation {opname} on entity {target}")
    else:
        cmd = functools.partial(method, target)

    try:
        # CSOT: Translate the spec test "timeout" arg into pymongo's context timeout API.
        if "timeout" in arguments:
            timeout = arguments.pop("timeout")
            with pymongo.timeout(timeout):
                result = await cmd(**dict(arguments))
        else:
            result = await cmd(**dict(arguments))
    except Exception as exc:
        # Ignore all operation errors but to avoid masking bugs don't
        # ignore things like TypeError and ValueError.
        if ignore and isinstance(exc, PyMongoError):
            return exc
        if expect_error:
            if method_name == "_collectionOperation_bulkWrite":
                self.skipTest("Skipping test pending PYTHON-4598")
            return self.process_error(exc, expect_error)
        raise
    else:
        if method_name == "_collectionOperation_bulkWrite":
            self.skipTest("Skipping test pending PYTHON-4598")
        if expect_error:
            self.fail(f'Expected error {expect_error} but "{opname}" succeeded: {result}')

    if expect_result:
        actual = coerce_result(opname, result)
        self.match_evaluator.match_result(expect_result, actual)

    if save_as_entity:
        self.entity_map[save_as_entity] = result
    return None
def __get_last_two_command_lsids(self, listener):
    """Return the ``lsid`` values of the two newest CommandStartedEvents.

    The result is ordered newest first. The test fails when the listener
    recorded fewer than two such events.
    """
    newest_first = [
        evt for evt in reversed(listener.events) if isinstance(evt, CommandStartedEvent)
    ]
    found = len(newest_first)
    if found < 2:
        self.fail("Needed 2 CommandStartedEvents to compare lsids, " "got %s" % (found))
    # lsid is read from every started event before slicing, as before.
    lsids = [evt.command["lsid"] for evt in newest_first]
    return tuple(lsids[:2])
def _event_count(self, client_name, event):
    """Return how many events recorded for ``client_name`` match ``event``.

    All event types are fetched from the client's listener. An event counts
    as a match when ``match_evaluator.match_event`` does not raise
    ``AssertionError``.
    """
    recorded = self.entity_map.get_listener_for_client(client_name).get_events("all")

    def matches(candidate):
        # match_event reports a mismatch by raising AssertionError.
        try:
            self.match_evaluator.match_event(event, candidate)
        except AssertionError:
            return False
        return True

    return sum(1 for candidate in recorded if matches(candidate))
+ """ + client, event, count = spec["client"], spec["event"], spec["count"] + wait_until( + lambda: self._event_count(client, event) >= count, + f"find {count} {event} event(s)", + ) + + async def _testOperation_wait(self, spec): + """Run the "wait" test operation.""" + await asyncio.sleep(spec["ms"] / 1000.0) + + def _testOperation_recordTopologyDescription(self, spec): + """Run the recordTopologyDescription test operation.""" + self.entity_map[spec["id"]] = self.entity_map[spec["client"]].topology_description + + def _testOperation_assertTopologyType(self, spec): + """Run the assertTopologyType test operation.""" + description = self.entity_map[spec["topologyDescription"]] + self.assertIsInstance(description, TopologyDescription) + self.assertEqual(description.topology_type_name, spec["topologyType"]) + + def _testOperation_waitForPrimaryChange(self, spec: dict) -> None: + """Run the waitForPrimaryChange test operation.""" + client = self.entity_map[spec["client"]] + old_description: TopologyDescription = self.entity_map[spec["priorTopologyDescription"]] + timeout = spec["timeoutMS"] / 1000.0 + + def get_primary(td: TopologyDescription) -> Optional[_Address]: + servers = writable_server_selector(Selection.from_topology_description(td)) + if servers and servers[0].server_type == SERVER_TYPE.RSPrimary: + return servers[0].address + return None + + old_primary = get_primary(old_description) + + def primary_changed() -> bool: + primary = client.primary + if primary is None: + return False + return primary != old_primary + + wait_until(primary_changed, "change primary", timeout=timeout) + + def _testOperation_runOnThread(self, spec): + """Run the 'runOnThread' operation.""" + thread = self.entity_map[spec["thread"]] + thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + + def _testOperation_waitForThread(self, spec): + """Run the 'waitForThread' operation.""" + thread = self.entity_map[spec["thread"]] + thread.stop() + thread.join(10) + if 
async def run_special_operation(self, spec):
    """Run a ``testRunner`` operation by dispatching on its name.

    Looks up ``_testOperation_<name>`` on the runner and calls it with
    ``spec["arguments"]``, awaiting it when it is a coroutine function.
    The test fails when no such handler exists.
    """
    opname = spec["name"]
    absent = object()
    handler = getattr(self, f"_testOperation_{opname}", absent)
    if handler is absent:
        self.fail(f"Unsupported special test operation {opname}")
        return
    if not iscoroutinefunction(handler):
        handler(spec["arguments"])
        return
    await handler(spec["arguments"])
def process_ignore_messages(self, ignore_logs, actual_logs):
    """Return ``actual_logs`` without the entries matched by ``ignore_logs``.

    A log is dropped when some ignore entry has the same ``data.message``
    and ``match_evaluator.match_result(ignore_entry, log, test=False)`` is
    truthy. The relative order of the surviving logs is kept.
    """

    def is_ignored(log):
        # The cheap message comparison runs first, so match_result is only
        # consulted for ignore entries with an identical message.
        return any(
            log["data"]["message"] == pattern["data"]["message"]
            and self.match_evaluator.match_result(pattern, log, test=False)
            for pattern in ignore_logs
        )

    return [log for log in actual_logs if not is_ignored(log)]
"", 1), + "data": data, + } + ) + return client_to_log + + with self.assertLogs("pymongo", level="DEBUG") as cm: + await self.run_operations(operations) + formatted_logs = format_logs(cm.records) + for client in spec: + components = set() + for message in client["messages"]: + components.add(message["component"]) + + clientid = self.entity_map[client["client"]]._topology_settings._topology_id + actual_logs = formatted_logs[clientid] + actual_logs = [log for log in actual_logs if log["component"] in components] + + ignore_logs = client.get("ignoreMessages", []) + if ignore_logs: + actual_logs = self.process_ignore_messages(ignore_logs, actual_logs) + + if client.get("ignoreExtraMessages", False): + actual_logs = actual_logs[: len(client["messages"])] + self.assertEqual( + len(client["messages"]), + len(actual_logs), + f"expected {client['messages']} but got {actual_logs}", + ) + for expected_msg, actual_msg in zip(client["messages"], actual_logs): + expected_data, actual_data = expected_msg.pop("data"), actual_msg.pop("data") + + if "failureIsRedacted" in expected_msg: + self.assertIn("failure", actual_data) + should_redact = expected_msg.pop("failureIsRedacted") + if should_redact: + actual_fields = set(json_util.loads(actual_data["failure"]).keys()) + self.assertTrue( + {"code", "codeName", "errorLabels"}.issuperset(actual_fields) + ) + + self.match_evaluator.match_result(expected_data, actual_data) + self.match_evaluator.match_result(expected_msg, actual_msg) + + async def verify_outcome(self, spec): + for collection_data in spec: + coll_name = collection_data["collectionName"] + db_name = collection_data["databaseName"] + expected_documents = collection_data["documents"] + + coll = self.client.get_database(db_name).get_collection( + coll_name, + read_preference=ReadPreference.PRIMARY, + read_concern=ReadConcern(level="local"), + ) + + if expected_documents: + sorted_expected_documents = sorted(expected_documents, key=lambda doc: doc["_id"]) + actual_documents = 
async def run_scenario(self, spec, uri=None):
    """Run one test scenario, retrying CSOT scenarios on assertion failure.

    Raises ``unittest.SkipTest`` for CSOT tests when ``SKIP_CSOT_TESTS`` is
    set. A test counts as CSOT when "csot" appears in its lower-cased id.
    """
    if "csot" in self.id().lower() and SKIP_CSOT_TESTS:
        raise unittest.SkipTest("SKIP_CSOT_TESTS is set, skipping...")

    # Kill all sessions before and after each test to prevent an open
    # transaction (from a test failure) from blocking collection/database
    # operations during test set up and tear down.
    await self.kill_all_sessions()
    self.addAsyncCleanup(self.kill_all_sessions)

    if "csot" not in self.id().lower():
        await self._run_scenario(spec, uri)
        return None

    # CSOT tests can be flaky: allow three attempts in total, i.e. up to two
    # retries, and re-run asyncSetUp between attempts.
    max_attempts = 3
    attempt = 0
    while attempt < max_attempts:
        attempt += 1
        try:
            return await self._run_scenario(spec, uri)
        except AssertionError:
            if attempt == max_attempts:
                raise
            print(
                f"Retrying after attempt {attempt} of {self.id()} failed with:\n"
                f"{traceback.format_exc()}",
                file=sys.stderr,
            )
            await self.asyncSetUp()
    return None
class UnifiedSpecTestMeta(type):
    """Metaclass that attaches one async test method per spec test.

    Each entry of ``cls.TEST_SPEC["tests"]`` becomes a ``test_<description>``
    coroutine method that runs the scenario. Tests whose description matches
    a pattern in ``cls.EXPECTED_FAILURES`` are marked as expected failures.
    """

    TEST_SPEC: Any
    EXPECTED_FAILURES: Any

    def __init__(cls, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def create_test(spec):
            # A factory is used so each test closes over its own spec.
            async def test_case(self):
                await self.run_scenario(spec)

            return test_case

        for test_spec in cls.TEST_SPEC["tests"]:
            description = test_spec["description"]
            # Spaces and dots in the description become underscores.
            suffix = description.strip(". ").replace(" ", "_").replace(".", "_")
            test_name = f"test_{suffix}"

            # Deep-copy so running a test cannot mutate the shared spec.
            test_method = create_test(copy.deepcopy(test_spec))
            test_method.__name__ = str(test_name)

            if any(re.search(pattern, description) for pattern in cls.EXPECTED_FAILURES):
                test_method = unittest.expectedFailure(test_method)

            setattr(cls, test_name, test_method)
+ """ + test_klasses = {} + + def test_base_class_factory(test_spec): + """Utility that creates the base class to use for test generation. + This is needed to ensure that cls.TEST_SPEC is appropriately set when + the metaclass __init__ is invoked. + """ + + class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore + TEST_SPEC = test_spec + EXPECTED_FAILURES = expected_failures + + return SpecTestBase + + for dirpath, _, filenames in os.walk(test_path): + dirname = os.path.split(dirpath)[-1] + + for filename in filenames: + fpath = os.path.join(dirpath, filename) + with open(fpath) as scenario_stream: + # Use tz_aware=False to match how CodecOptions decodes + # dates. + opts = json_util.JSONOptions(tz_aware=False) + scenario_def = json_util.loads(scenario_stream.read(), json_options=opts) + + test_type = os.path.splitext(filename)[0] + snake_class_name = "Test{}_{}_{}".format( + class_name_prefix, + dirname.replace("-", "_"), + test_type.replace("-", "_").replace(".", "_"), + ) + class_name = snake_to_camel(snake_class_name) + + try: + schema_version = Version.from_string(scenario_def["schemaVersion"]) + mixin_class = _SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS.get(schema_version[0]) + if mixin_class is None: + raise ValueError( + f"test file '{fpath}' has unsupported schemaVersion '{schema_version}'" + ) + module_dict = {"__module__": module, "TEST_PATH": test_path} + module_dict.update(kwargs) + test_klasses[class_name] = type( + class_name, + ( + mixin_class, + test_base_class_factory(scenario_def), + ), + module_dict, + ) + except Exception: + if bypass_test_generation_errors: + continue + raise + + return test_klasses diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 12cb13c2c..4d9c4c8f2 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -15,8 +15,12 @@ """Utilities for testing driver specs.""" from __future__ import annotations +import asyncio import 
functools +import os import threading +import unittest +from asyncio import iscoroutinefunction from collections import abc from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs from test.utils import ( @@ -24,6 +28,7 @@ from test.utils import ( CompareType, EventListener, OvertCommandListener, + ScenarioDict, ServerAndTopologyEventListener, camel_to_snake, camel_to_snake_args, @@ -32,11 +37,12 @@ from test.utils import ( ) from typing import List -from bson import ObjectId, decode, encode +from bson import ObjectId, decode, encode, json_util from bson.binary import Binary from bson.int64 import Int64 from bson.son import SON from gridfs import GridFSBucket +from gridfs.asynchronous.grid_file import AsyncGridFSBucket from pymongo.asynchronous import client_session from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.cursor import AsyncCursor @@ -83,6 +89,161 @@ class SpecRunnerThread(threading.Thread): self.stop() +class AsyncSpecTestCreator: + """Class to create test cases from specifications.""" + + def __init__(self, create_test, test_class, test_path): + """Create a TestCreator object. + + :Parameters: + - `create_test`: callback that returns a test case. The callback + must accept the following arguments - a dictionary containing the + entire test specification (the `scenario_def`), a dictionary + containing the specification for which the test case will be + generated (the `test_def`). + - `test_class`: the unittest.TestCase class in which to create the + test case. + - `test_path`: path to the directory containing the JSON files with + the test specifications. + """ + self._create_test = create_test + self._test_class = test_class + self.test_path = test_path + + def _ensure_min_max_server_version(self, scenario_def, method): + """Test modifier that enforces a version range for the server on a + test case. 
+ """ + if "minServerVersion" in scenario_def: + min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) + if min_ver is not None: + method = async_client_context.require_version_min(*min_ver)(method) + + if "maxServerVersion" in scenario_def: + max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) + if max_ver is not None: + method = async_client_context.require_version_max(*max_ver)(method) + + if "serverless" in scenario_def: + serverless = scenario_def["serverless"] + if serverless == "require": + serverless_satisfied = async_client_context.serverless + elif serverless == "forbid": + serverless_satisfied = not async_client_context.serverless + else: # unset or "allow" + serverless_satisfied = True + method = unittest.skipUnless( + serverless_satisfied, "Serverless requirement not satisfied" + )(method) + + return method + + @staticmethod + async def valid_topology(run_on_req): + return await async_client_context.is_topology_type( + run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) + ) + + @staticmethod + def min_server_version(run_on_req): + version = run_on_req.get("minServerVersion") + if version: + min_ver = tuple(int(elt) for elt in version.split(".")) + return async_client_context.version >= min_ver + return True + + @staticmethod + def max_server_version(run_on_req): + version = run_on_req.get("maxServerVersion") + if version: + max_ver = tuple(int(elt) for elt in version.split(".")) + return async_client_context.version <= max_ver + return True + + @staticmethod + def valid_auth_enabled(run_on_req): + if "authEnabled" in run_on_req: + if run_on_req["authEnabled"]: + return async_client_context.auth_enabled + return not async_client_context.auth_enabled + return True + + @staticmethod + def serverless_ok(run_on_req): + serverless = run_on_req["serverless"] + if serverless == "require": + return async_client_context.serverless + elif serverless == "forbid": + return not 
async_client_context.serverless + else: # unset or "allow" + return True + + async def should_run_on(self, scenario_def): + run_on = scenario_def.get("runOn", []) + if not run_on: + # Always run these tests. + return True + + for req in run_on: + if ( + await self.valid_topology(req) + and self.min_server_version(req) + and self.max_server_version(req) + and self.valid_auth_enabled(req) + and self.serverless_ok(req) + ): + return True + return False + + def ensure_run_on(self, scenario_def, method): + """Test modifier that enforces a 'runOn' on a test case.""" + + async def predicate(): + return await self.should_run_on(scenario_def) + + return async_client_context._require(predicate, "runOn not satisfied", method) + + def tests(self, scenario_def): + """Allow CMAP spec test to override the location of test.""" + return scenario_def["tests"] + + async def _create_tests(self): + for dirpath, _, filenames in os.walk(self.test_path): + dirname = os.path.split(dirpath)[-1] + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100 + # Use tz_aware=False to match how CodecOptions decodes + # dates. + opts = json_util.JSONOptions(tz_aware=False) + scenario_def = ScenarioDict( + json_util.loads(scenario_stream.read(), json_options=opts) + ) + + test_type = os.path.splitext(filename)[0] + + # Construct test from scenario. 
+ for test_def in self.tests(scenario_def): + test_name = "test_{}_{}_{}".format( + dirname, + test_type.replace("-", "_").replace(".", "_"), + str(test_def["description"].replace(" ", "_").replace(".", "_")), + ) + + new_test = await self._create_test(scenario_def, test_def, test_name) + new_test = self._ensure_min_max_server_version(scenario_def, new_test) + new_test = self.ensure_run_on(scenario_def, new_test) + + new_test.__name__ = test_name + setattr(self._test_class, new_test.__name__, new_test) + + def create_tests(self): + if _IS_SYNC: + self._create_tests() + else: + asyncio.run(self._create_tests()) + + class AsyncSpecRunner(AsyncIntegrationTest): mongos_clients: List knobs: client_knobs @@ -284,7 +445,7 @@ class AsyncSpecRunner(AsyncIntegrationTest): if object_name == "gridfsbucket": # Only create the GridFSBucket when we need it (for the gridfs # retryable reads tests). - obj = GridFSBucket(database, bucket_name=collection.name) + obj = AsyncGridFSBucket(database, bucket_name=collection.name) else: objects = { "client": database.client, @@ -312,7 +473,10 @@ class AsyncSpecRunner(AsyncIntegrationTest): args.update(arguments) arguments = args - result = cmd(**dict(arguments)) + if not _IS_SYNC and iscoroutinefunction(cmd): + result = await cmd(**dict(arguments)) + else: + result = cmd(**dict(arguments)) # Cleanup open change stream cursors. if name == "watch": self.addAsyncCleanup(result.close) @@ -588,7 +752,7 @@ class AsyncSpecRunner(AsyncIntegrationTest): read_preference=ReadPreference.PRIMARY, read_concern=ReadConcern("local"), ) - actual_data = await (await outcome_coll.find(sort=[("_id", 1)])).to_list() + actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list() # The expected data needs to be the left hand side here otherwise # CompareType(Binary) doesn't work. 
diff --git a/test/client-side-encryption/spec/legacy/timeoutMS.json b/test/client-side-encryption/spec/legacy/timeoutMS.json index b667767cf..841130622 100644 --- a/test/client-side-encryption/spec/legacy/timeoutMS.json +++ b/test/client-side-encryption/spec/legacy/timeoutMS.json @@ -110,7 +110,7 @@ "listCollections" ], "blockConnection": true, - "blockTimeMS": 60 + "blockTimeMS": 600 } }, "clientOptions": { @@ -119,7 +119,7 @@ "aws": {} } }, - "timeoutMS": 50 + "timeoutMS": 500 }, "operations": [ { diff --git a/test/test_client_context.py b/test/test_client_context.py index be8a56214..5996f9243 100644 --- a/test/test_client_context.py +++ b/test/test_client_context.py @@ -61,6 +61,13 @@ class TestClientContext(UnitTest): self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"]) + def test_free_threading_is_enabled(self): + if "free-threading build" not in sys.version: + raise SkipTest("this test requires the Python free-threading build") + + # If the GIL is enabled then pymongo or one of our deps does not support free-threading. 
+ self.assertFalse(sys._is_gil_enabled()) # type: ignore[attr-defined] + if __name__ == "__main__": unittest.main() diff --git a/test/test_collation.py b/test/test_collation.py index 19df25c1c..e5c1c7eb1 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -37,8 +37,11 @@ from pymongo.operations import ( UpdateMany, UpdateOne, ) +from pymongo.synchronous.helpers import next from pymongo.write_concern import WriteConcern +_IS_SYNC = True + class TestCollationObject(unittest.TestCase): def test_constructor(self): @@ -96,8 +99,8 @@ class TestCollation(IntegrationTest): @classmethod @client_context.require_connection - def setUpClass(cls): - super().setUpClass() + def _setup_class(cls): + super()._setup_class() cls.listener = EventListener() cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test @@ -107,11 +110,11 @@ class TestCollation(IntegrationTest): warnings.simplefilter("ignore", DeprecationWarning) @classmethod - def tearDownClass(cls): + def _tearDown_class(cls): cls.warn_context.__exit__() cls.warn_context = None cls.client.close() - super().tearDownClass() + super()._tearDown_class() def tearDown(self): self.listener.reset() diff --git a/test/test_command_logging.py b/test/test_command_logging.py index 9b2d52e66..cf865920c 100644 --- a/test/test_command_logging.py +++ b/test/test_command_logging.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys sys.path[0:0] = [""] @@ -23,8 +24,14 @@ sys.path[0:0] = [""] from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. 
-_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "command_logging") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging") + globals().update( generate_test_classes( diff --git a/test/test_command_monitoring.py b/test/test_command_monitoring.py index d2f578824..4f5ef06f2 100644 --- a/test/test_command_monitoring.py +++ b/test/test_command_monitoring.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys sys.path[0:0] = [""] @@ -23,8 +24,13 @@ sys.path[0:0] = [""] from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "command_monitoring") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring") globals().update( diff --git a/test/test_common.py b/test/test_common.py index 3228dc97f..e69b421c9 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -28,10 +28,7 @@ from bson.objectid import ObjectId from pymongo.errors import OperationFailure from pymongo.write_concern import WriteConcern - -@client_context.require_connection -def setUpModule(): - pass +_IS_SYNC = True class TestCommon(IntegrationTest): @@ -48,12 +45,12 @@ class TestCommon(IntegrationTest): coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) legacy_opts = coll.codec_options coll.insert_one({"uu": uu}) - self.assertEqual(uu, coll.find_one({"uu": uu})["uu"]) # type: ignore + self.assertEqual(uu, (coll.find_one({"uu": uu}))["uu"]) # type: ignore coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD)) self.assertEqual(STANDARD, 
coll.codec_options.uuid_representation) self.assertEqual(None, coll.find_one({"uu": uu})) uul = Binary.from_uuid(uu, PYTHON_LEGACY) - self.assertEqual(uul, coll.find_one({"uu": uul})["uu"]) # type: ignore + self.assertEqual(uul, (coll.find_one({"uu": uul}))["uu"]) # type: ignore # Test count_documents self.assertEqual(0, coll.count_documents({"uu": uu})) @@ -73,9 +70,9 @@ class TestCommon(IntegrationTest): coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD)) coll.update_one({"_id": uu}, {"$set": {"i": 2}}) coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) - self.assertEqual(1, coll.find_one({"_id": uu})["i"]) # type: ignore + self.assertEqual(1, (coll.find_one({"_id": uu}))["i"]) # type: ignore coll.update_one({"_id": uu}, {"$set": {"i": 2}}) - self.assertEqual(2, coll.find_one({"_id": uu})["i"]) # type: ignore + self.assertEqual(2, (coll.find_one({"_id": uu}))["i"]) # type: ignore # Test Cursor.distinct self.assertEqual([2], coll.find({"_id": uu}).distinct("i")) @@ -85,27 +82,31 @@ class TestCommon(IntegrationTest): # Test findAndModify self.assertEqual(None, coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}})) coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) - self.assertEqual(2, coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}})["i"]) - self.assertEqual(5, coll.find_one({"_id": uu})["i"]) # type: ignore + self.assertEqual(2, (coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}}))["i"]) + self.assertEqual(5, (coll.find_one({"_id": uu}))["i"]) # type: ignore # Test command self.assertEqual( 5, - self.db.command( - "findAndModify", - "uuid", - update={"$set": {"i": 6}}, - query={"_id": uu}, - codec_options=legacy_opts, + ( + self.db.command( + "findAndModify", + "uuid", + update={"$set": {"i": 6}}, + query={"_id": uu}, + codec_options=legacy_opts, + ) )["value"]["i"], ) self.assertEqual( 6, - self.db.command( - "findAndModify", - "uuid", - 
update={"$set": {"i": 7}}, - query={"_id": Binary.from_uuid(uu, PYTHON_LEGACY)}, + ( + self.db.command( + "findAndModify", + "uuid", + update={"$set": {"i": 7}}, + query={"_id": Binary.from_uuid(uu, PYTHON_LEGACY)}, + ) )["value"]["i"], ) @@ -140,20 +141,23 @@ class TestCommon(IntegrationTest): coll.insert_one(doc) self.assertTrue(coll.insert_one(doc)) coll = coll.with_options(write_concern=WriteConcern(w=1)) - self.assertRaises(OperationFailure, coll.insert_one, doc) + with self.assertRaises(OperationFailure): + coll.insert_one(doc) m = self.rs_or_single_client() coll = m.pymongo_test.write_concern_test new_coll = coll.with_options(write_concern=WriteConcern(w=0)) self.assertTrue(new_coll.insert_one(doc)) - self.assertRaises(OperationFailure, coll.insert_one, doc) + with self.assertRaises(OperationFailure): + coll.insert_one(doc) m = self.rs_or_single_client( f"mongodb://{pair}/", replicaSet=client_context.replica_set_name ) coll = m.pymongo_test.write_concern_test - self.assertRaises(OperationFailure, coll.insert_one, doc) + with self.assertRaises(OperationFailure): + coll.insert_one(doc) m = self.rs_or_single_client( f"mongodb://{pair}/?w=0", replicaSet=client_context.replica_set_name ) diff --git a/test/test_connection_logging.py b/test/test_connection_logging.py index 262ce821e..253193cc4 100644 --- a/test/test_connection_logging.py +++ b/test/test_connection_logging.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys sys.path[0:0] = [""] @@ -23,8 +24,13 @@ sys.path[0:0] = [""] from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. 
-_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "connection_logging") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging") globals().update( diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 142af0f9a..d576a1184 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -25,14 +25,13 @@ from test import IntegrationTest, client_knobs, unittest from test.pymongo_mocks import DummyMonitor from test.utils import ( CMAPListener, - SpecTestCreator, camel_to_snake, client_context, get_pool, get_pools, wait_until, ) -from test.utils_spec_runner import SpecRunnerThread +from test.utils_spec_runner import SpecRunnerThread, SpecTestCreator from bson.objectid import ObjectId from bson.son import SON diff --git a/test/test_crud_unified.py b/test/test_crud_unified.py index 92a60a47f..26f34cba8 100644 --- a/test/test_crud_unified.py +++ b/test/test_crud_unified.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys sys.path[0:0] = [""] @@ -23,11 +24,16 @@ sys.path[0:0] = [""] from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "crud", "unified") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified") # Generate unified tests. 
-globals().update(generate_test_classes(TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True)) +globals().update(generate_test_classes(_TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True)) if __name__ == "__main__": unittest.main() diff --git a/test/test_encryption.py b/test/test_encryption.py index 43c85e2c5..13a69ca9a 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -30,6 +30,7 @@ import uuid import warnings from test import IntegrationTest, PyMongoTestCase, client_context from test.test_bulk import BulkTestBase +from test.utils_spec_runner import SpecRunner, SpecTestCreator from threading import Thread from typing import Any, Dict, Mapping, Optional @@ -58,7 +59,6 @@ from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, OvertCommandListener, - SpecTestCreator, TopologyEventListener, camel_to_snake_args, is_greenthread_patched, @@ -624,130 +624,132 @@ AWS_TEMP_NO_SESSION_CREDS = { KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}} -if _IS_SYNC: - # TODO: Add synchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700) - class TestSpec(SpecRunner): - @classmethod - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - def setUpClass(cls): - super().setUpClass() +class TestSpec(SpecRunner): + @classmethod + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") + def _setup_class(cls): + super()._setup_class() - def parse_auto_encrypt_opts(self, opts): - """Parse clientOptions.autoEncryptOpts.""" - opts = camel_to_snake_args(opts) - kms_providers = opts["kms_providers"] - if "aws" in kms_providers: - kms_providers["aws"] = AWS_CREDS - if not any(AWS_CREDS.values()): - self.skipTest("AWS environment credentials are not set") - if "awsTemporary" in kms_providers: - kms_providers["aws"] = AWS_TEMP_CREDS - del kms_providers["awsTemporary"] - if not any(AWS_TEMP_CREDS.values()): - self.skipTest("AWS Temp environment 
credentials are not set") - if "awsTemporaryNoSessionToken" in kms_providers: - kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS - del kms_providers["awsTemporaryNoSessionToken"] - if not any(AWS_TEMP_NO_SESSION_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "azure" in kms_providers: - kms_providers["azure"] = AZURE_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("Azure environment credentials are not set") - if "gcp" in kms_providers: - kms_providers["gcp"] = GCP_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("GCP environment credentials are not set") - if "kmip" in kms_providers: - kms_providers["kmip"] = KMIP_CREDS - opts["kms_tls_options"] = KMS_TLS_OPTS - if "key_vault_namespace" not in opts: - opts["key_vault_namespace"] = "keyvault.datakeys" - if "extra_options" in opts: - opts.update(camel_to_snake_args(opts.pop("extra_options"))) + def parse_auto_encrypt_opts(self, opts): + """Parse clientOptions.autoEncryptOpts.""" + opts = camel_to_snake_args(opts) + kms_providers = opts["kms_providers"] + if "aws" in kms_providers: + kms_providers["aws"] = AWS_CREDS + if not any(AWS_CREDS.values()): + self.skipTest("AWS environment credentials are not set") + if "awsTemporary" in kms_providers: + kms_providers["aws"] = AWS_TEMP_CREDS + del kms_providers["awsTemporary"] + if not any(AWS_TEMP_CREDS.values()): + self.skipTest("AWS Temp environment credentials are not set") + if "awsTemporaryNoSessionToken" in kms_providers: + kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS + del kms_providers["awsTemporaryNoSessionToken"] + if not any(AWS_TEMP_NO_SESSION_CREDS.values()): + self.skipTest("AWS Temp environment credentials are not set") + if "azure" in kms_providers: + kms_providers["azure"] = AZURE_CREDS + if not any(AZURE_CREDS.values()): + self.skipTest("Azure environment credentials are not set") + if "gcp" in kms_providers: + kms_providers["gcp"] = GCP_CREDS + if not any(AZURE_CREDS.values()): + 
self.skipTest("GCP environment credentials are not set") + if "kmip" in kms_providers: + kms_providers["kmip"] = KMIP_CREDS + opts["kms_tls_options"] = KMS_TLS_OPTS + if "key_vault_namespace" not in opts: + opts["key_vault_namespace"] = "keyvault.datakeys" + if "extra_options" in opts: + opts.update(camel_to_snake_args(opts.pop("extra_options"))) - opts = dict(opts) - return AutoEncryptionOpts(**opts) + opts = dict(opts) + return AutoEncryptionOpts(**opts) - def parse_client_options(self, opts): - """Override clientOptions parsing to support autoEncryptOpts.""" - encrypt_opts = opts.pop("autoEncryptOpts", None) - if encrypt_opts: - opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) + def parse_client_options(self, opts): + """Override clientOptions parsing to support autoEncryptOpts.""" + encrypt_opts = opts.pop("autoEncryptOpts", None) + if encrypt_opts: + opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) - return super().parse_client_options(opts) + return super().parse_client_options(opts) - def get_object_name(self, op): - """Default object is collection.""" - return op.get("object", "collection") + def get_object_name(self, op): + """Default object is collection.""" + return op.get("object", "collection") - def maybe_skip_scenario(self, test): - super().maybe_skip_scenario(test) - desc = test["description"].lower() - if ( - "timeoutms applied to listcollections to get collection schema" in desc - and sys.platform in ("win32", "darwin") - ): - self.skipTest("PYTHON-3706 flaky test on Windows/macOS") - if "type=symbol" in desc: - self.skipTest("PyMongo does not support the symbol type") + def maybe_skip_scenario(self, test): + super().maybe_skip_scenario(test) + desc = test["description"].lower() + if ( + "timeoutms applied to listcollections to get collection schema" in desc + and sys.platform in ("win32", "darwin") + ): + self.skipTest("PYTHON-3706 flaky test on Windows/macOS") + if "type=symbol" in desc: + 
self.skipTest("PyMongo does not support the symbol type") + if "timeoutms applied to listcollections to get collection schema" in desc and not _IS_SYNC: + self.skipTest("PYTHON-4844 flaky test on async") - def setup_scenario(self, scenario_def): - """Override a test's setup.""" - key_vault_data = scenario_def["key_vault_data"] - encrypted_fields = scenario_def["encrypted_fields"] - json_schema = scenario_def["json_schema"] - data = scenario_def["data"] - coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"] - coll.delete_many({}) - if key_vault_data: - coll.insert_many(key_vault_data) + def setup_scenario(self, scenario_def): + """Override a test's setup.""" + key_vault_data = scenario_def["key_vault_data"] + encrypted_fields = scenario_def["encrypted_fields"] + json_schema = scenario_def["json_schema"] + data = scenario_def["data"] + coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"] + coll.delete_many({}) + if key_vault_data: + coll.insert_many(key_vault_data) - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - db = client_context.client.get_database(db_name, codec_options=OPTS) - coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields) - wc = WriteConcern(w="majority") - kwargs: Dict[str, Any] = {} - if json_schema: - kwargs["validator"] = {"$jsonSchema": json_schema} - kwargs["codec_options"] = OPTS - if not data: - kwargs["write_concern"] = wc - if encrypted_fields: - kwargs["encryptedFields"] = encrypted_fields - db.create_collection(coll_name, **kwargs) - coll = db[coll_name] - if data: - # Load data. 
- coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) + db_name = self.get_scenario_db_name(scenario_def) + coll_name = self.get_scenario_coll_name(scenario_def) + db = client_context.client.get_database(db_name, codec_options=OPTS) + db.drop_collection(coll_name, encrypted_fields=encrypted_fields) + wc = WriteConcern(w="majority") + kwargs: Dict[str, Any] = {} + if json_schema: + kwargs["validator"] = {"$jsonSchema": json_schema} + kwargs["codec_options"] = OPTS + if not data: + kwargs["write_concern"] = wc + if encrypted_fields: + kwargs["encryptedFields"] = encrypted_fields + db.create_collection(coll_name, **kwargs) + coll = db[coll_name] + if data: + # Load data. + coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) - def allowable_errors(self, op): - """Override expected error classes.""" - errors = super().allowable_errors(op) - # An updateOne test expects encryption to error when no $ operator - # appears but pymongo raises a client side ValueError in this case. - if op["name"] == "updateOne": - errors += (ValueError,) - return errors + def allowable_errors(self, op): + """Override expected error classes.""" + errors = super().allowable_errors(op) + # An updateOne test expects encryption to error when no $ operator + # appears but pymongo raises a client side ValueError in this case. 
+ if op["name"] == "updateOne": + errors += (ValueError,) + return errors - def create_test(scenario_def, test, name): - @client_context.require_test_commands - def run_scenario(self): - self.run_scenario(scenario_def, test) - return run_scenario +def create_test(scenario_def, test, name): + @client_context.require_test_commands + def run_scenario(self): + self.run_scenario(scenario_def, test) - test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy")) - test_creator.create_tests() + return run_scenario - if _HAVE_PYMONGOCRYPT: - globals().update( - generate_test_classes( - os.path.join(SPEC_PATH, "unified"), - module=__name__, - ) + +test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy")) +test_creator.create_tests() + +if _HAVE_PYMONGOCRYPT: + globals().update( + generate_test_classes( + os.path.join(SPEC_PATH, "unified"), + module=__name__, ) + ) # Prose Tests ALL_KMS_PROVIDERS = { diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 7cab42cca..05772fa38 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -21,11 +21,11 @@ from test import IntegrationTest, client_context, unittest from test.utils import ( CMAPListener, OvertCommandListener, - SpecTestCreator, get_pool, wait_until, ) from test.utils_selection_tests import create_topology +from test.utils_spec_runner import SpecTestCreator from pymongo.common import clean_node from pymongo.monitoring import ConnectionReadyEvent diff --git a/test/unified_format.py b/test/unified_format.py index 62211d3d2..6a19082b8 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -18,41 +18,41 @@ https://github.com/mongodb/specifications/blob/master/source/unified-test-format """ from __future__ import annotations +import asyncio import binascii -import collections import copy -import datetime import functools import os import re import sys import time 
import traceback -import types -from collections import abc, defaultdict +from asyncio import iscoroutinefunction +from collections import defaultdict from test import ( IntegrationTest, client_context, client_knobs, unittest, ) -from test.helpers import ( - AWS_CREDS, - AWS_CREDS_2, - AZURE_CREDS, - CA_PEM, - CLIENT_PEM, - GCP_CREDS, - KMIP_CREDS, - LOCAL_MASTER_KEY, - client_knobs, +from test.unified_format_shared import ( + IS_INTERRUPTED, + KMS_TLS_OPTS, + PLACEHOLDER_MAP, + SKIP_CSOT_TESTS, + EventListenerUtil, + MatchEvaluatorUtil, + coerce_result, + parse_bulk_write_error_result, + parse_bulk_write_result, + parse_client_bulk_write_error_result, + parse_collection_or_database_options, + with_metaclass, ) from test.utils import ( - CMAPListener, camel_to_snake, camel_to_snake_args, get_pool, - parse_collection_options, parse_spec_options, prepare_spec_arguments, snake_to_camel, @@ -60,14 +60,12 @@ from test.utils import ( ) from test.utils_spec_runner import SpecRunnerThread from test.version import Version -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Mapping, Optional import pymongo -from bson import SON, Code, DBRef, Decimal128, Int64, MaxKey, MinKey, json_util -from bson.binary import Binary +from bson import SON, json_util from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.objectid import ObjectId -from bson.regex import RE_TYPE, Regex from gridfs import GridFSBucket, GridOut from pymongo import ASCENDING, CursorType, MongoClient, _csot from pymongo.encryption_options import _HAVE_PYMONGOCRYPT @@ -83,55 +81,14 @@ from pymongo.errors import ( PyMongoError, ) from pymongo.monitoring import ( - _SENSITIVE_COMMANDS, - CommandFailedEvent, - CommandListener, CommandStartedEvent, - CommandSucceededEvent, - ConnectionCheckedInEvent, - ConnectionCheckedOutEvent, - ConnectionCheckOutFailedEvent, - ConnectionCheckOutStartedEvent, - ConnectionClosedEvent, - ConnectionCreatedEvent, - 
ConnectionReadyEvent, - PoolClearedEvent, - PoolClosedEvent, - PoolCreatedEvent, - PoolReadyEvent, - ServerClosedEvent, - ServerDescriptionChangedEvent, - ServerHeartbeatFailedEvent, - ServerHeartbeatListener, - ServerHeartbeatStartedEvent, - ServerHeartbeatSucceededEvent, - ServerListener, - ServerOpeningEvent, - TopologyClosedEvent, - TopologyDescriptionChangedEvent, - TopologyEvent, - TopologyListener, - TopologyOpenedEvent, - _CommandEvent, - _ConnectionEvent, - _PoolEvent, - _ServerEvent, - _ServerHeartbeatEvent, ) from pymongo.operations import ( - DeleteMany, - DeleteOne, - InsertOne, - ReplaceOne, SearchIndexModel, - UpdateMany, - UpdateOne, ) from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference -from pymongo.results import BulkWriteResult, ClientBulkWriteResult from pymongo.server_api import ServerApi -from pymongo.server_description import ServerDescription from pymongo.server_selectors import Selection, writable_server_selector from pymongo.server_type import SERVER_TYPE from pymongo.synchronous.change_stream import ChangeStream @@ -140,85 +97,12 @@ from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.database import Database from pymongo.synchronous.encryption import ClientEncryption +from pymongo.synchronous.helpers import next from pymongo.topology_description import TopologyDescription from pymongo.typings import _Address from pymongo.write_concern import WriteConcern -SKIP_CSOT_TESTS = os.getenv("SKIP_CSOT_TESTS") - -JSON_OPTS = json_util.JSONOptions(tz_aware=False) - -IS_INTERRUPTED = False - -KMS_TLS_OPTS = { - "kmip": { - "tlsCAFile": CA_PEM, - "tlsCertificateKeyFile": CLIENT_PEM, - } -} - - -# Build up a placeholder maps. 
-PLACEHOLDER_MAP = {} -for provider_name, provider_data in [ - ("local", {"key": LOCAL_MASTER_KEY}), - ("local:name1", {"key": LOCAL_MASTER_KEY}), - ("aws", AWS_CREDS), - ("aws:name1", AWS_CREDS), - ("aws:name2", AWS_CREDS_2), - ("azure", AZURE_CREDS), - ("azure:name1", AZURE_CREDS), - ("gcp", GCP_CREDS), - ("gcp:name1", GCP_CREDS), - ("kmip", KMIP_CREDS), - ("kmip:name1", KMIP_CREDS), -]: - for key, value in provider_data.items(): - placeholder = f"/clientEncryptionOpts/kmsProviders/{provider_name}/{key}" - PLACEHOLDER_MAP[placeholder] = value - -OIDC_ENV = os.environ.get("OIDC_ENV", "test") -if OIDC_ENV == "test": - PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"ENVIRONMENT": "test"} -elif OIDC_ENV == "azure": - PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = { - "ENVIRONMENT": "azure", - "TOKEN_RESOURCE": os.environ["AZUREOIDC_RESOURCE"], - } -elif OIDC_ENV == "gcp": - PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = { - "ENVIRONMENT": "gcp", - "TOKEN_RESOURCE": os.environ["GCPOIDC_AUDIENCE"], - } - - -def interrupt_loop(): - global IS_INTERRUPTED - IS_INTERRUPTED = True - - -def with_metaclass(meta, *bases): - """Create a base class with a metaclass. - - Vendored from six: https://github.com/benjaminp/six/blob/master/six.py - """ - - # This requires a bit of explanation: the basic idea is to make a dummy - # metaclass for one level of class instantiation that replaces itself with - # the actual metaclass. - class metaclass(type): - def __new__(cls, name, this_bases, d): - # __orig_bases__ is required by PEP 560. 
- resolved_bases = types.resolve_bases(bases) - if resolved_bases is not bases: - d["__orig_bases__"] = bases - return meta(name, resolved_bases, d) - - @classmethod - def __prepare__(cls, name, this_bases): - return meta.__prepare__(name, bases) - - return type.__new__(metaclass, "temporary_class", (), {}) +_IS_SYNC = True def is_run_on_requirement_satisfied(requirement): @@ -283,77 +167,6 @@ def is_run_on_requirement_satisfied(requirement): ) -def parse_collection_or_database_options(options): - return parse_collection_options(options) - - -def parse_bulk_write_result(result): - upserted_ids = {str(int_idx): result.upserted_ids[int_idx] for int_idx in result.upserted_ids} - return { - "deletedCount": result.deleted_count, - "insertedCount": result.inserted_count, - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - "upsertedCount": result.upserted_count, - "upsertedIds": upserted_ids, - } - - -def parse_client_bulk_write_individual(op_type, result): - if op_type == "insert": - return {"insertedId": result.inserted_id} - if op_type == "update": - if result.upserted_id: - return { - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - "upsertedId": result.upserted_id, - } - else: - return { - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - } - if op_type == "delete": - return { - "deletedCount": result.deleted_count, - } - - -def parse_client_bulk_write_result(result): - insert_results, update_results, delete_results = {}, {}, {} - if result.has_verbose_results: - for idx, res in result.insert_results.items(): - insert_results[str(idx)] = parse_client_bulk_write_individual("insert", res) - for idx, res in result.update_results.items(): - update_results[str(idx)] = parse_client_bulk_write_individual("update", res) - for idx, res in result.delete_results.items(): - delete_results[str(idx)] = parse_client_bulk_write_individual("delete", res) - - return { - "deletedCount": 
result.deleted_count, - "insertedCount": result.inserted_count, - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - "upsertedCount": result.upserted_count, - "insertResults": insert_results, - "updateResults": update_results, - "deleteResults": delete_results, - } - - -def parse_bulk_write_error_result(error): - write_result = BulkWriteResult(error.details, True) - return parse_bulk_write_result(write_result) - - -def parse_client_bulk_write_error_result(error): - write_result = error.partial_result - if not write_result: - return None - return parse_client_bulk_write_result(write_result) - - class NonLazyCursor: """A find cursor proxy that creates the remote cursor when initialized.""" @@ -361,7 +174,16 @@ class NonLazyCursor: self.client = client self.find_cursor = find_cursor # Create the server side cursor. - self.first_result = next(find_cursor, None) + self.first_result = None + + @classmethod + def create(cls, find_cursor, client): + cursor = cls(find_cursor, client) + try: + cursor.first_result = next(cursor.find_cursor) + except StopIteration: + cursor.first_result = None + return cursor @property def alive(self): @@ -382,105 +204,6 @@ class NonLazyCursor: self.client = None -class EventListenerUtil( - CMAPListener, CommandListener, ServerListener, ServerHeartbeatListener, TopologyListener -): - def __init__( - self, observe_events, ignore_commands, observe_sensitive_commands, store_events, entity_map - ): - self._event_types = {name.lower() for name in observe_events} - if observe_sensitive_commands: - self._observe_sensitive_commands = True - self._ignore_commands = set(ignore_commands) - else: - self._observe_sensitive_commands = False - self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands) - self._ignore_commands.add("configurefailpoint") - self._event_mapping = collections.defaultdict(list) - self.entity_map = entity_map - if store_events: - for i in store_events: - id = i["id"] - events = (i.lower() for i 
in i["events"]) - for i in events: - self._event_mapping[i].append(id) - self.entity_map[id] = [] - super().__init__() - - def get_events(self, event_type): - assert event_type in ("command", "cmap", "sdam", "all"), event_type - if event_type == "all": - return list(self.events) - if event_type == "command": - return [e for e in self.events if isinstance(e, _CommandEvent)] - if event_type == "cmap": - return [e for e in self.events if isinstance(e, (_ConnectionEvent, _PoolEvent))] - return [ - e - for e in self.events - if isinstance(e, (_ServerEvent, TopologyEvent, _ServerHeartbeatEvent)) - ] - - def add_event(self, event): - event_name = type(event).__name__.lower() - if event_name in self._event_types: - super().add_event(event) - for id in self._event_mapping[event_name]: - self.entity_map[id].append( - { - "name": type(event).__name__, - "observedAt": time.time(), - "description": repr(event), - } - ) - - def _command_event(self, event): - if event.command_name.lower() not in self._ignore_commands: - self.add_event(event) - - def started(self, event): - if isinstance(event, CommandStartedEvent): - if event.command == {}: - # Command is redacted. Observe only if flag is set. - if self._observe_sensitive_commands: - self._command_event(event) - else: - self._command_event(event) - else: - self.add_event(event) - - def succeeded(self, event): - if isinstance(event, CommandSucceededEvent): - if event.reply == {}: - # Command is redacted. Observe only if flag is set. 
- if self._observe_sensitive_commands: - self._command_event(event) - else: - self._command_event(event) - else: - self.add_event(event) - - def failed(self, event): - if isinstance(event, CommandFailedEvent): - self._command_event(event) - else: - self.add_event(event) - - def opened(self, event: Union[ServerOpeningEvent, TopologyOpenedEvent]) -> None: - self.add_event(event) - - def description_changed( - self, event: Union[ServerDescriptionChangedEvent, TopologyDescriptionChangedEvent] - ) -> None: - self.add_event(event) - - def topology_changed(self, event: TopologyDescriptionChangedEvent) -> None: - self.add_event(event) - - def closed(self, event: Union[ServerClosedEvent, TopologyClosedEvent]) -> None: - self.add_event(event) - - class EntityMapUtil: """Utility class that implements an entity map as per the unified test format specification. @@ -692,353 +415,12 @@ class EntityMapUtil: def advance_cluster_times(self) -> None: """Manually synchronize entities when desired""" if not self._cluster_time: - self._cluster_time = self.test.client.admin.command("ping").get("$clusterTime") + self._cluster_time = (self.test.client.admin.command("ping")).get("$clusterTime") for entity in self._entities.values(): if isinstance(entity, ClientSession) and self._cluster_time: entity.advance_cluster_time(self._cluster_time) -binary_types = (Binary, bytes) -long_types = (Int64,) -unicode_type = str - - -BSON_TYPE_ALIAS_MAP = { - # https://mongodb.com/docs/manual/reference/operator/query/type/ - # https://pymongo.readthedocs.io/en/stable/api/bson/index.html - "double": (float,), - "string": (str,), - "object": (abc.Mapping,), - "array": (abc.MutableSequence,), - "binData": binary_types, - "undefined": (type(None),), - "objectId": (ObjectId,), - "bool": (bool,), - "date": (datetime.datetime,), - "null": (type(None),), - "regex": (Regex, RE_TYPE), - "dbPointer": (DBRef,), - "javascript": (unicode_type, Code), - "symbol": (unicode_type,), - "javascriptWithScope": (unicode_type, 
Code), - "int": (int,), - "long": (Int64,), - "decimal": (Decimal128,), - "maxKey": (MaxKey,), - "minKey": (MinKey,), -} - - -class MatchEvaluatorUtil: - """Utility class that implements methods for evaluating matches as per - the unified test format specification. - """ - - def __init__(self, test_class): - self.test = test_class - - def _operation_exists(self, spec, actual, key_to_compare): - if spec is True: - if key_to_compare is None: - assert actual is not None - else: - self.test.assertIn(key_to_compare, actual) - elif spec is False: - if key_to_compare is None: - assert actual is None - else: - self.test.assertNotIn(key_to_compare, actual) - else: - self.test.fail(f"Expected boolean value for $$exists operator, got {spec}") - - def __type_alias_to_type(self, alias): - if alias not in BSON_TYPE_ALIAS_MAP: - self.test.fail(f"Unrecognized BSON type alias {alias}") - return BSON_TYPE_ALIAS_MAP[alias] - - def _operation_type(self, spec, actual, key_to_compare): - if isinstance(spec, abc.MutableSequence): - permissible_types = tuple( - [t for alias in spec for t in self.__type_alias_to_type(alias)] - ) - else: - permissible_types = self.__type_alias_to_type(spec) - value = actual[key_to_compare] if key_to_compare else actual - self.test.assertIsInstance(value, permissible_types) - - def _operation_matchesEntity(self, spec, actual, key_to_compare): - expected_entity = self.test.entity_map[spec] - self.test.assertEqual(expected_entity, actual[key_to_compare]) - - def _operation_matchesHexBytes(self, spec, actual, key_to_compare): - expected = binascii.unhexlify(spec) - value = actual[key_to_compare] if key_to_compare else actual - self.test.assertEqual(value, expected) - - def _operation_unsetOrMatches(self, spec, actual, key_to_compare): - if key_to_compare is None and not actual: - # top-level document can be None when unset - return - - if key_to_compare not in actual: - # we add a dummy value for the compared key to pass map size check - actual[key_to_compare] 
= "dummyValue" - return - self.match_result(spec, actual[key_to_compare], in_recursive_call=True) - - def _operation_sessionLsid(self, spec, actual, key_to_compare): - expected_lsid = self.test.entity_map.get_lsid_for_session(spec) - self.test.assertEqual(expected_lsid, actual[key_to_compare]) - - def _operation_lte(self, spec, actual, key_to_compare): - if key_to_compare not in actual: - self.test.fail(f"Actual command is missing the {key_to_compare} field: {spec}") - self.test.assertLessEqual(actual[key_to_compare], spec) - - def _operation_matchAsDocument(self, spec, actual, key_to_compare): - self._match_document(spec, json_util.loads(actual[key_to_compare]), False) - - def _operation_matchAsRoot(self, spec, actual, key_to_compare): - self._match_document(spec, actual, True) - - def _evaluate_special_operation(self, opname, spec, actual, key_to_compare): - method_name = "_operation_{}".format(opname.strip("$")) - try: - method = getattr(self, method_name) - except AttributeError: - self.test.fail(f"Unsupported special matching operator {opname}") - else: - method(spec, actual, key_to_compare) - - def _evaluate_if_special_operation(self, expectation, actual, key_to_compare=None): - """Returns True if a special operation is evaluated, False - otherwise. If the ``expectation`` map contains a single key, - value pair we check it for a special operation. - If given, ``key_to_compare`` is assumed to be the key in - ``expectation`` whose corresponding value needs to be - evaluated for a possible special operation. ``key_to_compare`` - is ignored when ``expectation`` has only one key. 
- """ - if not isinstance(expectation, abc.Mapping): - return False - - is_special_op, opname, spec = False, False, False - - if key_to_compare is not None: - if key_to_compare.startswith("$$"): - is_special_op = True - opname = key_to_compare - spec = expectation[key_to_compare] - key_to_compare = None - else: - nested = expectation[key_to_compare] - if isinstance(nested, abc.Mapping) and len(nested) == 1: - opname, spec = next(iter(nested.items())) - if opname.startswith("$$"): - is_special_op = True - elif len(expectation) == 1: - opname, spec = next(iter(expectation.items())) - if opname.startswith("$$"): - is_special_op = True - key_to_compare = None - - if is_special_op: - self._evaluate_special_operation( - opname=opname, spec=spec, actual=actual, key_to_compare=key_to_compare - ) - return True - - return False - - def _match_document(self, expectation, actual, is_root, test=False): - if self._evaluate_if_special_operation(expectation, actual): - return - - self.test.assertIsInstance(actual, abc.Mapping) - for key, value in expectation.items(): - if self._evaluate_if_special_operation(expectation, actual, key): - continue - - self.test.assertIn(key, actual) - if not self.match_result(value, actual[key], in_recursive_call=True, test=test): - return False - - if not is_root: - expected_keys = set(expectation.keys()) - for key, value in expectation.items(): - if value == {"$$exists": False}: - expected_keys.remove(key) - if test: - self.test.assertEqual(expected_keys, set(actual.keys())) - else: - return set(expected_keys).issubset(set(actual.keys())) - return True - - def match_result(self, expectation, actual, in_recursive_call=False, test=True): - if isinstance(expectation, abc.Mapping): - return self._match_document( - expectation, actual, is_root=not in_recursive_call, test=test - ) - - if isinstance(expectation, abc.MutableSequence): - self.test.assertIsInstance(actual, abc.MutableSequence) - for e, a in zip(expectation, actual): - if isinstance(e, 
abc.Mapping): - self._match_document(e, a, is_root=not in_recursive_call, test=test) - else: - self.match_result(e, a, in_recursive_call=True, test=test) - return None - - # account for flexible numerics in element-wise comparison - if isinstance(expectation, int) or isinstance(expectation, float): - if test: - self.test.assertEqual(expectation, actual) - else: - return expectation == actual - return None - else: - if test: - self.test.assertIsInstance(actual, type(expectation)) - self.test.assertEqual(expectation, actual) - else: - return isinstance(actual, type(expectation)) and expectation == actual - return None - - def match_server_description(self, actual: ServerDescription, spec: dict) -> None: - for field, expected in spec.items(): - field = camel_to_snake(field) - if field == "type": - field = "server_type_name" - self.test.assertEqual(getattr(actual, field), expected) - - def match_topology_description(self, actual: TopologyDescription, spec: dict) -> None: - for field, expected in spec.items(): - field = camel_to_snake(field) - if field == "type": - field = "topology_type_name" - self.test.assertEqual(getattr(actual, field), expected) - - def match_event_fields(self, actual: Any, spec: dict) -> None: - for field, expected in spec.items(): - if field == "command" and isinstance(actual, CommandStartedEvent): - command = spec["command"] - if command: - self.match_result(command, actual.command) - continue - if field == "reply" and isinstance(actual, CommandSucceededEvent): - reply = spec["reply"] - if reply: - self.match_result(reply, actual.reply) - continue - if field == "hasServiceId": - if spec["hasServiceId"]: - self.test.assertIsNotNone(actual.service_id) - self.test.assertIsInstance(actual.service_id, ObjectId) - else: - self.test.assertIsNone(actual.service_id) - continue - if field == "hasServerConnectionId": - if spec["hasServerConnectionId"]: - self.test.assertIsNotNone(actual.server_connection_id) - 
self.test.assertIsInstance(actual.server_connection_id, int) - else: - self.test.assertIsNone(actual.server_connection_id) - continue - if field in ("previousDescription", "newDescription"): - if isinstance(actual, ServerDescriptionChangedEvent): - self.match_server_description( - getattr(actual, camel_to_snake(field)), spec[field] - ) - continue - if isinstance(actual, TopologyDescriptionChangedEvent): - self.match_topology_description( - getattr(actual, camel_to_snake(field)), spec[field] - ) - continue - - if field == "interruptInUseConnections": - field = "interrupt_connections" - else: - field = camel_to_snake(field) - self.test.assertEqual(getattr(actual, field), expected) - - def match_event(self, expectation, actual): - name, spec = next(iter(expectation.items())) - if name == "commandStartedEvent": - self.test.assertIsInstance(actual, CommandStartedEvent) - elif name == "commandSucceededEvent": - self.test.assertIsInstance(actual, CommandSucceededEvent) - elif name == "commandFailedEvent": - self.test.assertIsInstance(actual, CommandFailedEvent) - elif name == "poolCreatedEvent": - self.test.assertIsInstance(actual, PoolCreatedEvent) - elif name == "poolReadyEvent": - self.test.assertIsInstance(actual, PoolReadyEvent) - elif name == "poolClearedEvent": - self.test.assertIsInstance(actual, PoolClearedEvent) - self.test.assertIsInstance(actual.interrupt_connections, bool) - elif name == "poolClosedEvent": - self.test.assertIsInstance(actual, PoolClosedEvent) - elif name == "connectionCreatedEvent": - self.test.assertIsInstance(actual, ConnectionCreatedEvent) - elif name == "connectionReadyEvent": - self.test.assertIsInstance(actual, ConnectionReadyEvent) - elif name == "connectionClosedEvent": - self.test.assertIsInstance(actual, ConnectionClosedEvent) - elif name == "connectionCheckOutStartedEvent": - self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent) - elif name == "connectionCheckOutFailedEvent": - self.test.assertIsInstance(actual, 
ConnectionCheckOutFailedEvent) - elif name == "connectionCheckedOutEvent": - self.test.assertIsInstance(actual, ConnectionCheckedOutEvent) - elif name == "connectionCheckedInEvent": - self.test.assertIsInstance(actual, ConnectionCheckedInEvent) - elif name == "serverDescriptionChangedEvent": - self.test.assertIsInstance(actual, ServerDescriptionChangedEvent) - elif name == "serverHeartbeatStartedEvent": - self.test.assertIsInstance(actual, ServerHeartbeatStartedEvent) - elif name == "serverHeartbeatSucceededEvent": - self.test.assertIsInstance(actual, ServerHeartbeatSucceededEvent) - elif name == "serverHeartbeatFailedEvent": - self.test.assertIsInstance(actual, ServerHeartbeatFailedEvent) - elif name == "topologyDescriptionChangedEvent": - self.test.assertIsInstance(actual, TopologyDescriptionChangedEvent) - elif name == "topologyOpeningEvent": - self.test.assertIsInstance(actual, TopologyOpenedEvent) - elif name == "topologyClosedEvent": - self.test.assertIsInstance(actual, TopologyClosedEvent) - else: - raise Exception(f"Unsupported event type {name}") - - self.match_event_fields(actual, spec) - - -def coerce_result(opname, result): - """Convert a pymongo result into the spec's result format.""" - if hasattr(result, "acknowledged") and not result.acknowledged: - return {"acknowledged": False} - if opname == "bulkWrite": - return parse_bulk_write_result(result) - if opname == "clientBulkWrite": - return parse_client_bulk_write_result(result) - if opname == "insertOne": - return {"insertedId": result.inserted_id} - if opname == "insertMany": - return dict(enumerate(result.inserted_ids)) - if opname in ("deleteOne", "deleteMany"): - return {"deletedCount": result.deleted_count} - if opname in ("updateOne", "updateMany", "replaceOne"): - value = { - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - "upsertedCount": 0 if result.upserted_id is None else 1, - } - if result.upserted_id is not None: - value["upsertedId"] = 
result.upserted_id - return value - return result - - class UnifiedSpecTestMixinV1(IntegrationTest): """Mixin class to run test cases from test specification files. @@ -1090,9 +472,9 @@ class UnifiedSpecTestMixinV1(IntegrationTest): db.create_collection(coll_name, write_concern=wc, **opts) @classmethod - def setUpClass(cls): + def _setup_class(cls): # super call creates internal client cls.client - super().setUpClass() + super()._setup_class() # process file-level runOnRequirements run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) if not cls.should_run_on(run_on_spec): @@ -1125,11 +507,11 @@ class UnifiedSpecTestMixinV1(IntegrationTest): cls.knobs.enable() @classmethod - def tearDownClass(cls): + def _tearDown_class(cls): cls.knobs.disable() for client in cls.mongos_clients: client.close() - super().tearDownClass() + super()._tearDown_class() def setUp(self): super().setUp() @@ -1391,7 +773,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): def __entityOperation_aggregate(self, target, *args, **kwargs): self.__raise_if_unsupported("aggregate", target, Database, Collection) - return list(target.aggregate(*args, **kwargs)) + return (target.aggregate(*args, **kwargs)).to_list() def _databaseOperation_aggregate(self, target, *args, **kwargs): return self.__entityOperation_aggregate(target, *args, **kwargs) @@ -1402,13 +784,13 @@ class UnifiedSpecTestMixinV1(IntegrationTest): def _collectionOperation_find(self, target, *args, **kwargs): self.__raise_if_unsupported("find", target, Collection) find_cursor = target.find(*args, **kwargs) - return list(find_cursor) + return find_cursor.to_list() def _collectionOperation_createFindCursor(self, target, *args, **kwargs): self.__raise_if_unsupported("find", target, Collection) if "filter" not in kwargs: self.fail('createFindCursor requires a "filter" argument') - cursor = NonLazyCursor(target.find(*args, **kwargs), target.database.client) + cursor = NonLazyCursor.create(target.find(*args, **kwargs), 
target.database.client) self.addCleanup(cursor.close) return cursor @@ -1418,7 +800,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): def _collectionOperation_listIndexes(self, target, *args, **kwargs): if "batch_size" in kwargs: self.skipTest("PyMongo does not support batch_size for list_indexes") - return list(target.list_indexes(*args, **kwargs)) + return (target.list_indexes(*args, **kwargs)).to_list() def _collectionOperation_listIndexNames(self, target, *args, **kwargs): self.skipTest("PyMongo does not support list_index_names") @@ -1430,7 +812,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): def _collectionOperation_listSearchIndexes(self, target, *args, **kwargs): name = kwargs.get("name") agg_kwargs = kwargs.get("aggregation_options", dict()) - return list(target.list_search_indexes(name, **agg_kwargs)) + return (target.list_search_indexes(name, **agg_kwargs)).to_list() def _sessionOperation_withTransaction(self, target, *args, **kwargs): if client_context.storage_engine == "mmapv1": @@ -1470,7 +852,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): return target.create_data_key(*args, **kwargs) def _clientEncryptionOperation_getKeys(self, target, *args, **kwargs): - return list(target.get_keys(*args, **kwargs)) + return (target.get_keys(*args, **kwargs)).to_list() def _clientEncryptionOperation_deleteKey(self, target, *args, **kwargs): result = target.delete_key(*args, **kwargs) @@ -1516,7 +898,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): def _bucketOperation_find( self, target: GridFSBucket, *args: Any, **kwargs: Any ) -> List[GridOut]: - return list(target.find(*args, **kwargs)) + return target.find(*args, **kwargs).to_list() def run_entity_operation(self, spec): target = self.entity_map[spec["object"]] @@ -1849,7 +1231,10 @@ class UnifiedSpecTestMixinV1(IntegrationTest): except AttributeError: self.fail(f"Unsupported special test operation {opname}") else: - method(spec["arguments"]) + if iscoroutinefunction(method): + 
method(spec["arguments"]) + else: + method(spec["arguments"]) def run_operations(self, spec): for op in spec: @@ -1985,7 +1370,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): if expected_documents: sorted_expected_documents = sorted(expected_documents, key=lambda doc: doc["_id"]) - actual_documents = list(coll.find({}, sort=[("_id", ASCENDING)])) + actual_documents = coll.find({}, sort=[("_id", ASCENDING)]).to_list() self.assertListEqual(sorted_expected_documents, actual_documents) def run_scenario(self, spec, uri=None): @@ -2040,7 +1425,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): # process initialData if "initialData" in self.TEST_SPEC: self.insert_initial_data(self.TEST_SPEC["initialData"]) - self._cluster_time = self.client.admin.command("ping").get("$clusterTime") + self._cluster_time = (self.client.admin.command("ping")).get("$clusterTime") self.entity_map.advance_cluster_times() if "expectLogMessages" in spec: diff --git a/test/unified_format_shared.py b/test/unified_format_shared.py new file mode 100644 index 000000000..d11624476 --- /dev/null +++ b/test/unified_format_shared.py @@ -0,0 +1,679 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utility functions and constants for the unified test format runner. 
+ +https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst +""" +from __future__ import annotations + +import binascii +import collections +import datetime +import os +import time +import types +from collections import abc +from test.helpers import ( + AWS_CREDS, + AWS_CREDS_2, + AZURE_CREDS, + CA_PEM, + CLIENT_PEM, + GCP_CREDS, + KMIP_CREDS, + LOCAL_MASTER_KEY, +) +from test.utils import CMAPListener, camel_to_snake, parse_collection_options +from typing import Any, Union + +from bson import ( + RE_TYPE, + Binary, + Code, + DBRef, + Decimal128, + Int64, + MaxKey, + MinKey, + ObjectId, + Regex, + json_util, +) +from pymongo.monitoring import ( + _SENSITIVE_COMMANDS, + CommandFailedEvent, + CommandListener, + CommandStartedEvent, + CommandSucceededEvent, + ConnectionCheckedInEvent, + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutStartedEvent, + ConnectionClosedEvent, + ConnectionCreatedEvent, + ConnectionReadyEvent, + PoolClearedEvent, + PoolClosedEvent, + PoolCreatedEvent, + PoolReadyEvent, + ServerClosedEvent, + ServerDescriptionChangedEvent, + ServerHeartbeatFailedEvent, + ServerHeartbeatListener, + ServerHeartbeatStartedEvent, + ServerHeartbeatSucceededEvent, + ServerListener, + ServerOpeningEvent, + TopologyClosedEvent, + TopologyDescriptionChangedEvent, + TopologyEvent, + TopologyListener, + TopologyOpenedEvent, + _CommandEvent, + _ConnectionEvent, + _PoolEvent, + _ServerEvent, + _ServerHeartbeatEvent, +) +from pymongo.results import BulkWriteResult +from pymongo.server_description import ServerDescription +from pymongo.topology_description import TopologyDescription + +SKIP_CSOT_TESTS = os.getenv("SKIP_CSOT_TESTS") + +JSON_OPTS = json_util.JSONOptions(tz_aware=False) + +IS_INTERRUPTED = False + +KMS_TLS_OPTS = { + "kmip": { + "tlsCAFile": CA_PEM, + "tlsCertificateKeyFile": CLIENT_PEM, + } +} + + +# Build up a placeholder maps. 
+PLACEHOLDER_MAP = {} +for provider_name, provider_data in [ + ("local", {"key": LOCAL_MASTER_KEY}), + ("local:name1", {"key": LOCAL_MASTER_KEY}), + ("aws", AWS_CREDS), + ("aws:name1", AWS_CREDS), + ("aws:name2", AWS_CREDS_2), + ("azure", AZURE_CREDS), + ("azure:name1", AZURE_CREDS), + ("gcp", GCP_CREDS), + ("gcp:name1", GCP_CREDS), + ("kmip", KMIP_CREDS), + ("kmip:name1", KMIP_CREDS), +]: + for key, value in provider_data.items(): + placeholder = f"/clientEncryptionOpts/kmsProviders/{provider_name}/{key}" + PLACEHOLDER_MAP[placeholder] = value + +OIDC_ENV = os.environ.get("OIDC_ENV", "test") +if OIDC_ENV == "test": + PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"ENVIRONMENT": "test"} +elif OIDC_ENV == "azure": + PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": os.environ["AZUREOIDC_RESOURCE"], + } +elif OIDC_ENV == "gcp": + PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = { + "ENVIRONMENT": "gcp", + "TOKEN_RESOURCE": os.environ["GCPOIDC_AUDIENCE"], + } + + +def interrupt_loop(): + global IS_INTERRUPTED + IS_INTERRUPTED = True + + +def with_metaclass(meta, *bases): + """Create a base class with a metaclass. + + Vendored from six: https://github.com/benjaminp/six/blob/master/six.py + """ + + # This requires a bit of explanation: the basic idea is to make a dummy + # metaclass for one level of class instantiation that replaces itself with + # the actual metaclass. + class metaclass(type): + def __new__(cls, name, this_bases, d): + # __orig_bases__ is required by PEP 560. 
+ resolved_bases = types.resolve_bases(bases) + if resolved_bases is not bases: + d["__orig_bases__"] = bases + return meta(name, resolved_bases, d) + + @classmethod + def __prepare__(cls, name, this_bases): + return meta.__prepare__(name, bases) + + return type.__new__(metaclass, "temporary_class", (), {}) + + +def parse_collection_or_database_options(options): + return parse_collection_options(options) + + +def parse_bulk_write_result(result): + upserted_ids = {str(int_idx): result.upserted_ids[int_idx] for int_idx in result.upserted_ids} + return { + "deletedCount": result.deleted_count, + "insertedCount": result.inserted_count, + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + "upsertedCount": result.upserted_count, + "upsertedIds": upserted_ids, + } + + +def parse_client_bulk_write_individual(op_type, result): + if op_type == "insert": + return {"insertedId": result.inserted_id} + if op_type == "update": + if result.upserted_id: + return { + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + "upsertedId": result.upserted_id, + } + else: + return { + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + } + if op_type == "delete": + return { + "deletedCount": result.deleted_count, + } + + +def parse_client_bulk_write_result(result): + insert_results, update_results, delete_results = {}, {}, {} + if result.has_verbose_results: + for idx, res in result.insert_results.items(): + insert_results[str(idx)] = parse_client_bulk_write_individual("insert", res) + for idx, res in result.update_results.items(): + update_results[str(idx)] = parse_client_bulk_write_individual("update", res) + for idx, res in result.delete_results.items(): + delete_results[str(idx)] = parse_client_bulk_write_individual("delete", res) + + return { + "deletedCount": result.deleted_count, + "insertedCount": result.inserted_count, + "matchedCount": result.matched_count, + "modifiedCount": 
result.modified_count, + "upsertedCount": result.upserted_count, + "insertResults": insert_results, + "updateResults": update_results, + "deleteResults": delete_results, + } + + +def parse_bulk_write_error_result(error): + write_result = BulkWriteResult(error.details, True) + return parse_bulk_write_result(write_result) + + +def parse_client_bulk_write_error_result(error): + write_result = error.partial_result + if not write_result: + return None + return parse_client_bulk_write_result(write_result) + + +class EventListenerUtil( + CMAPListener, CommandListener, ServerListener, ServerHeartbeatListener, TopologyListener +): + def __init__( + self, observe_events, ignore_commands, observe_sensitive_commands, store_events, entity_map + ): + self._event_types = {name.lower() for name in observe_events} + if observe_sensitive_commands: + self._observe_sensitive_commands = True + self._ignore_commands = set(ignore_commands) + else: + self._observe_sensitive_commands = False + self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands) + self._ignore_commands.add("configurefailpoint") + self._event_mapping = collections.defaultdict(list) + self.entity_map = entity_map + if store_events: + for i in store_events: + id = i["id"] + events = (i.lower() for i in i["events"]) + for i in events: + self._event_mapping[i].append(id) + self.entity_map[id] = [] + super().__init__() + + def get_events(self, event_type): + assert event_type in ("command", "cmap", "sdam", "all"), event_type + if event_type == "all": + return list(self.events) + if event_type == "command": + return [e for e in self.events if isinstance(e, _CommandEvent)] + if event_type == "cmap": + return [e for e in self.events if isinstance(e, (_ConnectionEvent, _PoolEvent))] + return [ + e + for e in self.events + if isinstance(e, (_ServerEvent, TopologyEvent, _ServerHeartbeatEvent)) + ] + + def add_event(self, event): + event_name = type(event).__name__.lower() + if event_name in self._event_types: + 
super().add_event(event) + for id in self._event_mapping[event_name]: + self.entity_map[id].append( + { + "name": type(event).__name__, + "observedAt": time.time(), + "description": repr(event), + } + ) + + def _command_event(self, event): + if event.command_name.lower() not in self._ignore_commands: + self.add_event(event) + + def started(self, event): + if isinstance(event, CommandStartedEvent): + if event.command == {}: + # Command is redacted. Observe only if flag is set. + if self._observe_sensitive_commands: + self._command_event(event) + else: + self._command_event(event) + else: + self.add_event(event) + + def succeeded(self, event): + if isinstance(event, CommandSucceededEvent): + if event.reply == {}: + # Command is redacted. Observe only if flag is set. + if self._observe_sensitive_commands: + self._command_event(event) + else: + self._command_event(event) + else: + self.add_event(event) + + def failed(self, event): + if isinstance(event, CommandFailedEvent): + self._command_event(event) + else: + self.add_event(event) + + def opened(self, event: Union[ServerOpeningEvent, TopologyOpenedEvent]) -> None: + self.add_event(event) + + def description_changed( + self, event: Union[ServerDescriptionChangedEvent, TopologyDescriptionChangedEvent] + ) -> None: + self.add_event(event) + + def topology_changed(self, event: TopologyDescriptionChangedEvent) -> None: + self.add_event(event) + + def closed(self, event: Union[ServerClosedEvent, TopologyClosedEvent]) -> None: + self.add_event(event) + + +binary_types = (Binary, bytes) +long_types = (Int64,) +unicode_type = str + + +BSON_TYPE_ALIAS_MAP = { + # https://mongodb.com/docs/manual/reference/operator/query/type/ + # https://pymongo.readthedocs.io/en/stable/api/bson/index.html + "double": (float,), + "string": (str,), + "object": (abc.Mapping,), + "array": (abc.MutableSequence,), + "binData": binary_types, + "undefined": (type(None),), + "objectId": (ObjectId,), + "bool": (bool,), + "date": (datetime.datetime,), + 
"null": (type(None),), + "regex": (Regex, RE_TYPE), + "dbPointer": (DBRef,), + "javascript": (unicode_type, Code), + "symbol": (unicode_type,), + "javascriptWithScope": (unicode_type, Code), + "int": (int,), + "long": (Int64,), + "decimal": (Decimal128,), + "maxKey": (MaxKey,), + "minKey": (MinKey,), +} + + +class MatchEvaluatorUtil: + """Utility class that implements methods for evaluating matches as per + the unified test format specification. + """ + + def __init__(self, test_class): + self.test = test_class + + def _operation_exists(self, spec, actual, key_to_compare): + if spec is True: + if key_to_compare is None: + assert actual is not None + else: + self.test.assertIn(key_to_compare, actual) + elif spec is False: + if key_to_compare is None: + assert actual is None + else: + self.test.assertNotIn(key_to_compare, actual) + else: + self.test.fail(f"Expected boolean value for $$exists operator, got {spec}") + + def __type_alias_to_type(self, alias): + if alias not in BSON_TYPE_ALIAS_MAP: + self.test.fail(f"Unrecognized BSON type alias {alias}") + return BSON_TYPE_ALIAS_MAP[alias] + + def _operation_type(self, spec, actual, key_to_compare): + if isinstance(spec, abc.MutableSequence): + permissible_types = tuple( + [t for alias in spec for t in self.__type_alias_to_type(alias)] + ) + else: + permissible_types = self.__type_alias_to_type(spec) + value = actual[key_to_compare] if key_to_compare else actual + self.test.assertIsInstance(value, permissible_types) + + def _operation_matchesEntity(self, spec, actual, key_to_compare): + expected_entity = self.test.entity_map[spec] + self.test.assertEqual(expected_entity, actual[key_to_compare]) + + def _operation_matchesHexBytes(self, spec, actual, key_to_compare): + expected = binascii.unhexlify(spec) + value = actual[key_to_compare] if key_to_compare else actual + self.test.assertEqual(value, expected) + + def _operation_unsetOrMatches(self, spec, actual, key_to_compare): + if key_to_compare is None and not actual: + 
# top-level document can be None when unset + return + + if key_to_compare not in actual: + # we add a dummy value for the compared key to pass map size check + actual[key_to_compare] = "dummyValue" + return + self.match_result(spec, actual[key_to_compare], in_recursive_call=True) + + def _operation_sessionLsid(self, spec, actual, key_to_compare): + expected_lsid = self.test.entity_map.get_lsid_for_session(spec) + self.test.assertEqual(expected_lsid, actual[key_to_compare]) + + def _operation_lte(self, spec, actual, key_to_compare): + if key_to_compare not in actual: + self.test.fail(f"Actual command is missing the {key_to_compare} field: {spec}") + self.test.assertLessEqual(actual[key_to_compare], spec) + + def _operation_matchAsDocument(self, spec, actual, key_to_compare): + self._match_document(spec, json_util.loads(actual[key_to_compare]), False) + + def _operation_matchAsRoot(self, spec, actual, key_to_compare): + self._match_document(spec, actual, True) + + def _evaluate_special_operation(self, opname, spec, actual, key_to_compare): + method_name = "_operation_{}".format(opname.strip("$")) + try: + method = getattr(self, method_name) + except AttributeError: + self.test.fail(f"Unsupported special matching operator {opname}") + else: + method(spec, actual, key_to_compare) + + def _evaluate_if_special_operation(self, expectation, actual, key_to_compare=None): + """Returns True if a special operation is evaluated, False + otherwise. If the ``expectation`` map contains a single key, + value pair we check it for a special operation. + If given, ``key_to_compare`` is assumed to be the key in + ``expectation`` whose corresponding value needs to be + evaluated for a possible special operation. ``key_to_compare`` + is ignored when ``expectation`` has only one key. 
+ """ + if not isinstance(expectation, abc.Mapping): + return False + + is_special_op, opname, spec = False, False, False + + if key_to_compare is not None: + if key_to_compare.startswith("$$"): + is_special_op = True + opname = key_to_compare + spec = expectation[key_to_compare] + key_to_compare = None + else: + nested = expectation[key_to_compare] + if isinstance(nested, abc.Mapping) and len(nested) == 1: + opname, spec = next(iter(nested.items())) + if opname.startswith("$$"): + is_special_op = True + elif len(expectation) == 1: + opname, spec = next(iter(expectation.items())) + if opname.startswith("$$"): + is_special_op = True + key_to_compare = None + + if is_special_op: + self._evaluate_special_operation( + opname=opname, spec=spec, actual=actual, key_to_compare=key_to_compare + ) + return True + + return False + + def _match_document(self, expectation, actual, is_root, test=False): + if self._evaluate_if_special_operation(expectation, actual): + return + + self.test.assertIsInstance(actual, abc.Mapping) + for key, value in expectation.items(): + if self._evaluate_if_special_operation(expectation, actual, key): + continue + + self.test.assertIn(key, actual) + if not self.match_result(value, actual[key], in_recursive_call=True, test=test): + return False + + if not is_root: + expected_keys = set(expectation.keys()) + for key, value in expectation.items(): + if value == {"$$exists": False}: + expected_keys.remove(key) + if test: + self.test.assertEqual(expected_keys, set(actual.keys())) + else: + return set(expected_keys).issubset(set(actual.keys())) + return True + + def match_result(self, expectation, actual, in_recursive_call=False, test=True): + if isinstance(expectation, abc.Mapping): + return self._match_document( + expectation, actual, is_root=not in_recursive_call, test=test + ) + + if isinstance(expectation, abc.MutableSequence): + self.test.assertIsInstance(actual, abc.MutableSequence) + for e, a in zip(expectation, actual): + if isinstance(e, 
abc.Mapping): + self._match_document(e, a, is_root=not in_recursive_call, test=test) + else: + self.match_result(e, a, in_recursive_call=True, test=test) + return None + + # account for flexible numerics in element-wise comparison + if isinstance(expectation, int) or isinstance(expectation, float): + if test: + self.test.assertEqual(expectation, actual) + else: + return expectation == actual + return None + else: + if test: + self.test.assertIsInstance(actual, type(expectation)) + self.test.assertEqual(expectation, actual) + else: + return isinstance(actual, type(expectation)) and expectation == actual + return None + + def match_server_description(self, actual: ServerDescription, spec: dict) -> None: + for field, expected in spec.items(): + field = camel_to_snake(field) + if field == "type": + field = "server_type_name" + self.test.assertEqual(getattr(actual, field), expected) + + def match_topology_description(self, actual: TopologyDescription, spec: dict) -> None: + for field, expected in spec.items(): + field = camel_to_snake(field) + if field == "type": + field = "topology_type_name" + self.test.assertEqual(getattr(actual, field), expected) + + def match_event_fields(self, actual: Any, spec: dict) -> None: + for field, expected in spec.items(): + if field == "command" and isinstance(actual, CommandStartedEvent): + command = spec["command"] + if command: + self.match_result(command, actual.command) + continue + if field == "reply" and isinstance(actual, CommandSucceededEvent): + reply = spec["reply"] + if reply: + self.match_result(reply, actual.reply) + continue + if field == "hasServiceId": + if spec["hasServiceId"]: + self.test.assertIsNotNone(actual.service_id) + self.test.assertIsInstance(actual.service_id, ObjectId) + else: + self.test.assertIsNone(actual.service_id) + continue + if field == "hasServerConnectionId": + if spec["hasServerConnectionId"]: + self.test.assertIsNotNone(actual.server_connection_id) + 
self.test.assertIsInstance(actual.server_connection_id, int) + else: + self.test.assertIsNone(actual.server_connection_id) + continue + if field in ("previousDescription", "newDescription"): + if isinstance(actual, ServerDescriptionChangedEvent): + self.match_server_description( + getattr(actual, camel_to_snake(field)), spec[field] + ) + continue + if isinstance(actual, TopologyDescriptionChangedEvent): + self.match_topology_description( + getattr(actual, camel_to_snake(field)), spec[field] + ) + continue + + if field == "interruptInUseConnections": + field = "interrupt_connections" + else: + field = camel_to_snake(field) + self.test.assertEqual(getattr(actual, field), expected) + + def match_event(self, expectation, actual): + name, spec = next(iter(expectation.items())) + if name == "commandStartedEvent": + self.test.assertIsInstance(actual, CommandStartedEvent) + elif name == "commandSucceededEvent": + self.test.assertIsInstance(actual, CommandSucceededEvent) + elif name == "commandFailedEvent": + self.test.assertIsInstance(actual, CommandFailedEvent) + elif name == "poolCreatedEvent": + self.test.assertIsInstance(actual, PoolCreatedEvent) + elif name == "poolReadyEvent": + self.test.assertIsInstance(actual, PoolReadyEvent) + elif name == "poolClearedEvent": + self.test.assertIsInstance(actual, PoolClearedEvent) + self.test.assertIsInstance(actual.interrupt_connections, bool) + elif name == "poolClosedEvent": + self.test.assertIsInstance(actual, PoolClosedEvent) + elif name == "connectionCreatedEvent": + self.test.assertIsInstance(actual, ConnectionCreatedEvent) + elif name == "connectionReadyEvent": + self.test.assertIsInstance(actual, ConnectionReadyEvent) + elif name == "connectionClosedEvent": + self.test.assertIsInstance(actual, ConnectionClosedEvent) + elif name == "connectionCheckOutStartedEvent": + self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent) + elif name == "connectionCheckOutFailedEvent": + self.test.assertIsInstance(actual, 
ConnectionCheckOutFailedEvent) + elif name == "connectionCheckedOutEvent": + self.test.assertIsInstance(actual, ConnectionCheckedOutEvent) + elif name == "connectionCheckedInEvent": + self.test.assertIsInstance(actual, ConnectionCheckedInEvent) + elif name == "serverDescriptionChangedEvent": + self.test.assertIsInstance(actual, ServerDescriptionChangedEvent) + elif name == "serverHeartbeatStartedEvent": + self.test.assertIsInstance(actual, ServerHeartbeatStartedEvent) + elif name == "serverHeartbeatSucceededEvent": + self.test.assertIsInstance(actual, ServerHeartbeatSucceededEvent) + elif name == "serverHeartbeatFailedEvent": + self.test.assertIsInstance(actual, ServerHeartbeatFailedEvent) + elif name == "topologyDescriptionChangedEvent": + self.test.assertIsInstance(actual, TopologyDescriptionChangedEvent) + elif name == "topologyOpeningEvent": + self.test.assertIsInstance(actual, TopologyOpenedEvent) + elif name == "topologyClosedEvent": + self.test.assertIsInstance(actual, TopologyClosedEvent) + else: + raise Exception(f"Unsupported event type {name}") + + self.match_event_fields(actual, spec) + + +def coerce_result(opname, result): + """Convert a pymongo result into the spec's result format.""" + if hasattr(result, "acknowledged") and not result.acknowledged: + return {"acknowledged": False} + if opname == "bulkWrite": + return parse_bulk_write_result(result) + if opname == "clientBulkWrite": + return parse_client_bulk_write_result(result) + if opname == "insertOne": + return {"insertedId": result.inserted_id} + if opname == "insertMany": + return dict(enumerate(result.inserted_ids)) + if opname in ("deleteOne", "deleteMany"): + return {"deletedCount": result.deleted_count} + if opname in ("updateOne", "updateMany", "replaceOne"): + value = { + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + "upsertedCount": 0 if result.upserted_id is None else 1, + } + if result.upserted_id is not None: + value["upsertedId"] = 
result.upserted_id + return value + return result diff --git a/test/utils.py b/test/utils.py index 9c78cff3a..4575a9fe1 100644 --- a/test/utils.py +++ b/test/utils.py @@ -418,153 +418,6 @@ class FunctionCallRecorder: return len(self._call_list) -class SpecTestCreator: - """Class to create test cases from specifications.""" - - def __init__(self, create_test, test_class, test_path): - """Create a TestCreator object. - - :Parameters: - - `create_test`: callback that returns a test case. The callback - must accept the following arguments - a dictionary containing the - entire test specification (the `scenario_def`), a dictionary - containing the specification for which the test case will be - generated (the `test_def`). - - `test_class`: the unittest.TestCase class in which to create the - test case. - - `test_path`: path to the directory containing the JSON files with - the test specifications. - """ - self._create_test = create_test - self._test_class = test_class - self.test_path = test_path - - def _ensure_min_max_server_version(self, scenario_def, method): - """Test modifier that enforces a version range for the server on a - test case. 
- """ - if "minServerVersion" in scenario_def: - min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) - if min_ver is not None: - method = client_context.require_version_min(*min_ver)(method) - - if "maxServerVersion" in scenario_def: - max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) - if max_ver is not None: - method = client_context.require_version_max(*max_ver)(method) - - if "serverless" in scenario_def: - serverless = scenario_def["serverless"] - if serverless == "require": - serverless_satisfied = client_context.serverless - elif serverless == "forbid": - serverless_satisfied = not client_context.serverless - else: # unset or "allow" - serverless_satisfied = True - method = unittest.skipUnless( - serverless_satisfied, "Serverless requirement not satisfied" - )(method) - - return method - - @staticmethod - def valid_topology(run_on_req): - return client_context.is_topology_type( - run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) - ) - - @staticmethod - def min_server_version(run_on_req): - version = run_on_req.get("minServerVersion") - if version: - min_ver = tuple(int(elt) for elt in version.split(".")) - return client_context.version >= min_ver - return True - - @staticmethod - def max_server_version(run_on_req): - version = run_on_req.get("maxServerVersion") - if version: - max_ver = tuple(int(elt) for elt in version.split(".")) - return client_context.version <= max_ver - return True - - @staticmethod - def valid_auth_enabled(run_on_req): - if "authEnabled" in run_on_req: - if run_on_req["authEnabled"]: - return client_context.auth_enabled - return not client_context.auth_enabled - return True - - @staticmethod - def serverless_ok(run_on_req): - serverless = run_on_req["serverless"] - if serverless == "require": - return client_context.serverless - elif serverless == "forbid": - return not client_context.serverless - else: # unset or "allow" - return True - - def 
should_run_on(self, scenario_def): - run_on = scenario_def.get("runOn", []) - if not run_on: - # Always run these tests. - return True - - for req in run_on: - if ( - self.valid_topology(req) - and self.min_server_version(req) - and self.max_server_version(req) - and self.valid_auth_enabled(req) - and self.serverless_ok(req) - ): - return True - return False - - def ensure_run_on(self, scenario_def, method): - """Test modifier that enforces a 'runOn' on a test case.""" - return client_context._require( - lambda: self.should_run_on(scenario_def), "runOn not satisfied", method - ) - - def tests(self, scenario_def): - """Allow CMAP spec test to override the location of test.""" - return scenario_def["tests"] - - def create_tests(self): - for dirpath, _, filenames in os.walk(self.test_path): - dirname = os.path.split(dirpath)[-1] - - for filename in filenames: - with open(os.path.join(dirpath, filename)) as scenario_stream: - # Use tz_aware=False to match how CodecOptions decodes - # dates. - opts = json_util.JSONOptions(tz_aware=False) - scenario_def = ScenarioDict( - json_util.loads(scenario_stream.read(), json_options=opts) - ) - - test_type = os.path.splitext(filename)[0] - - # Construct test from scenario. - for test_def in self.tests(scenario_def): - test_name = "test_{}_{}_{}".format( - dirname, - test_type.replace("-", "_").replace(".", "_"), - str(test_def["description"].replace(" ", "_").replace(".", "_")), - ) - - new_test = self._create_test(scenario_def, test_def, test_name) - new_test = self._ensure_min_max_server_version(scenario_def, new_test) - new_test = self.ensure_run_on(scenario_def, new_test) - - new_test.__name__ = test_name - setattr(self._test_class, new_test.__name__, new_test) - - def ensure_all_connected(client: MongoClient) -> None: """Ensure that the client's connection pool has socket connections to all members of a replica set. 
Raises ConfigurationError when called with a diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 06a40351c..8a061de0b 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -15,8 +15,12 @@ """Utilities for testing driver specs.""" from __future__ import annotations +import asyncio import functools +import os import threading +import unittest +from asyncio import iscoroutinefunction from collections import abc from test import IntegrationTest, client_context, client_knobs from test.utils import ( @@ -24,6 +28,7 @@ from test.utils import ( CompareType, EventListener, OvertCommandListener, + ScenarioDict, ServerAndTopologyEventListener, camel_to_snake, camel_to_snake_args, @@ -32,11 +37,12 @@ from test.utils import ( ) from typing import List -from bson import ObjectId, decode, encode +from bson import ObjectId, decode, encode, json_util from bson.binary import Binary from bson.int64 import Int64 from bson.son import SON from gridfs import GridFSBucket +from gridfs.synchronous.grid_file import GridFSBucket from pymongo.errors import BulkWriteError, OperationFailure, PyMongoError from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference @@ -83,6 +89,161 @@ class SpecRunnerThread(threading.Thread): self.stop() +class SpecTestCreator: + """Class to create test cases from specifications.""" + + def __init__(self, create_test, test_class, test_path): + """Create a TestCreator object. + + :Parameters: + - `create_test`: callback that returns a test case. The callback + must accept the following arguments - a dictionary containing the + entire test specification (the `scenario_def`), a dictionary + containing the specification for which the test case will be + generated (the `test_def`). + - `test_class`: the unittest.TestCase class in which to create the + test case. + - `test_path`: path to the directory containing the JSON files with + the test specifications. 
+ """ + self._create_test = create_test + self._test_class = test_class + self.test_path = test_path + + def _ensure_min_max_server_version(self, scenario_def, method): + """Test modifier that enforces a version range for the server on a + test case. + """ + if "minServerVersion" in scenario_def: + min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) + if min_ver is not None: + method = client_context.require_version_min(*min_ver)(method) + + if "maxServerVersion" in scenario_def: + max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) + if max_ver is not None: + method = client_context.require_version_max(*max_ver)(method) + + if "serverless" in scenario_def: + serverless = scenario_def["serverless"] + if serverless == "require": + serverless_satisfied = client_context.serverless + elif serverless == "forbid": + serverless_satisfied = not client_context.serverless + else: # unset or "allow" + serverless_satisfied = True + method = unittest.skipUnless( + serverless_satisfied, "Serverless requirement not satisfied" + )(method) + + return method + + @staticmethod + def valid_topology(run_on_req): + return client_context.is_topology_type( + run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) + ) + + @staticmethod + def min_server_version(run_on_req): + version = run_on_req.get("minServerVersion") + if version: + min_ver = tuple(int(elt) for elt in version.split(".")) + return client_context.version >= min_ver + return True + + @staticmethod + def max_server_version(run_on_req): + version = run_on_req.get("maxServerVersion") + if version: + max_ver = tuple(int(elt) for elt in version.split(".")) + return client_context.version <= max_ver + return True + + @staticmethod + def valid_auth_enabled(run_on_req): + if "authEnabled" in run_on_req: + if run_on_req["authEnabled"]: + return client_context.auth_enabled + return not client_context.auth_enabled + return True + + @staticmethod + def 
serverless_ok(run_on_req): + serverless = run_on_req["serverless"] + if serverless == "require": + return client_context.serverless + elif serverless == "forbid": + return not client_context.serverless + else: # unset or "allow" + return True + + def should_run_on(self, scenario_def): + run_on = scenario_def.get("runOn", []) + if not run_on: + # Always run these tests. + return True + + for req in run_on: + if ( + self.valid_topology(req) + and self.min_server_version(req) + and self.max_server_version(req) + and self.valid_auth_enabled(req) + and self.serverless_ok(req) + ): + return True + return False + + def ensure_run_on(self, scenario_def, method): + """Test modifier that enforces a 'runOn' on a test case.""" + + def predicate(): + return self.should_run_on(scenario_def) + + return client_context._require(predicate, "runOn not satisfied", method) + + def tests(self, scenario_def): + """Allow CMAP spec test to override the location of test.""" + return scenario_def["tests"] + + def _create_tests(self): + for dirpath, _, filenames in os.walk(self.test_path): + dirname = os.path.split(dirpath)[-1] + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100 + # Use tz_aware=False to match how CodecOptions decodes + # dates. + opts = json_util.JSONOptions(tz_aware=False) + scenario_def = ScenarioDict( + json_util.loads(scenario_stream.read(), json_options=opts) + ) + + test_type = os.path.splitext(filename)[0] + + # Construct test from scenario. 
+ for test_def in self.tests(scenario_def): + test_name = "test_{}_{}_{}".format( + dirname, + test_type.replace("-", "_").replace(".", "_"), + str(test_def["description"].replace(" ", "_").replace(".", "_")), + ) + + new_test = self._create_test(scenario_def, test_def, test_name) + new_test = self._ensure_min_max_server_version(scenario_def, new_test) + new_test = self.ensure_run_on(scenario_def, new_test) + + new_test.__name__ = test_name + setattr(self._test_class, new_test.__name__, new_test) + + def create_tests(self): + if _IS_SYNC: + self._create_tests() + else: + asyncio.run(self._create_tests()) + + class SpecRunner(IntegrationTest): mongos_clients: List knobs: client_knobs @@ -312,7 +473,10 @@ class SpecRunner(IntegrationTest): args.update(arguments) arguments = args - result = cmd(**dict(arguments)) + if not _IS_SYNC and iscoroutinefunction(cmd): + result = cmd(**dict(arguments)) + else: + result = cmd(**dict(arguments)) # Cleanup open change stream cursors. if name == "watch": self.addCleanup(result.close) @@ -583,7 +747,7 @@ class SpecRunner(IntegrationTest): read_preference=ReadPreference.PRIMARY, read_concern=ReadConcern("local"), ) - actual_data = (outcome_coll.find(sort=[("_id", 1)])).to_list() + actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list() # The expected data needs to be the left hand side here otherwise # CompareType(Binary) doesn't work. 
diff --git a/tools/synchro.py b/tools/synchro.py index 5ce83cfbe..b6812e9be 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -105,6 +105,8 @@ replacements = { "PyMongo|c|async": "PyMongo|c", "AsyncTestGridFile": "TestGridFile", "AsyncTestGridFileNoConnect": "TestGridFileNoConnect", + "AsyncTestSpec": "TestSpec", + "AsyncSpecTestCreator": "SpecTestCreator", "async_set_fail_point": "set_fail_point", "async_ensure_all_connected": "ensure_all_connected", "async_repl_set_step_down": "repl_set_step_down", @@ -188,8 +190,14 @@ converted_tests = [ "test_client.py", "test_client_bulk_write.py", "test_client_context.py", + "test_collation.py", "test_collection.py", + "test_command_logging.py", + "test_command_monitoring.py", + "test_common.py", + "test_connection_logging.py", "test_connections_survive_primary_stepdown_spec.py", + "test_crud_unified.py", "test_cursor.py", "test_database.py", "test_encryption.py", @@ -201,6 +209,7 @@ converted_tests = [ "test_retryable_writes.py", "test_session.py", "test_transactions.py", + "unified_format.py", ] sync_test_files = [