From 1b3dea3f03185c5cf947e9966cf472aa1a900d5d Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 15 Jul 2024 16:45:59 -0700 Subject: [PATCH] PYTHON-4533 - Convert test/test_transactions.py to async (#1732) --- gridfs/asynchronous/grid_file.py | 2 +- pymongo/asynchronous/client_session.py | 5 +- test/asynchronous/test_transactions.py | 585 +++++++++++++++++++++ test/asynchronous/utils_spec_runner.py | 691 +++++++++++++++++++++++++ test/test_transactions.py | 35 +- test/utils_spec_runner.py | 14 +- tools/synchro.py | 4 + 7 files changed, 1311 insertions(+), 25 deletions(-) create mode 100644 test/asynchronous/test_transactions.py create mode 100644 test/asynchronous/utils_spec_runner.py diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py index 9546429a3..da9bc15b5 100644 --- a/gridfs/asynchronous/grid_file.py +++ b/gridfs/asynchronous/grid_file.py @@ -154,7 +154,7 @@ class AsyncGridFS: """ async with AsyncGridIn(self._collection, **kwargs) as grid_file: await grid_file.write(data) - return await grid_file._id + return grid_file._id async def get(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> AsyncGridOut: """Get a file from GridFS by ``"_id"``. 
diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 4f2f2d97d..9271445e5 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -144,6 +144,7 @@ from typing import ( Any, AsyncContextManager, Callable, + Coroutine, Mapping, MutableMapping, NoReturn, @@ -598,7 +599,7 @@ class AsyncClientSession: async def with_transaction( self, - callback: Callable[[AsyncClientSession], _T], + callback: Callable[[AsyncClientSession], Coroutine[Any, Any, _T]], read_concern: Optional[ReadConcern] = None, write_concern: Optional[WriteConcern] = None, read_preference: Optional[_ServerMode] = None, @@ -693,7 +694,7 @@ class AsyncClientSession: read_concern, write_concern, read_preference, max_commit_time_ms ) try: - ret = callback(self) + ret = await callback(self) except Exception as exc: if self.in_transaction: await self.abort_transaction() diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py new file mode 100644 index 000000000..db208835b --- /dev/null +++ b/test/asynchronous/test_transactions.py @@ -0,0 +1,585 @@ +# Copyright 2018-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Execute Transactions Spec tests.""" +from __future__ import annotations + +import os +import sys +from io import BytesIO + +from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket + +sys.path[0:0] = [""] + +from test.asynchronous import async_client_context, unittest +from test.asynchronous.utils_spec_runner import AsyncSpecRunner +from test.utils import ( + OvertCommandListener, + async_rs_client, + async_single_client, + wait_until, +) +from typing import List + +from bson import encode +from bson.raw_bson import RawBSONDocument +from pymongo import WriteConcern +from pymongo.asynchronous import client_session +from pymongo.asynchronous.client_session import TransactionOptions +from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.cursor import AsyncCursor +from pymongo.asynchronous.helpers import anext +from pymongo.errors import ( + CollectionInvalid, + ConfigurationError, + ConnectionFailure, + InvalidOperation, + OperationFailure, +) +from pymongo.operations import IndexModel, InsertOne +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference + +_IS_SYNC = False + +_TXN_TESTS_DEBUG = os.environ.get("TRANSACTION_TESTS_DEBUG") + +# Max number of operations to perform after a transaction to prove unpinning +# occurs. Chosen so that there's a low false positive rate. With 2 mongoses, +# 50 attempts yields a one in a quadrillion chance of a false positive +# (1/(0.5^50)). 
+UNPIN_TEST_MAX_ATTEMPTS = 50 + + +class AsyncTransactionsBase(AsyncSpecRunner): + @classmethod + async def _setup_class(cls): + await super()._setup_class() + if async_client_context.supports_transactions(): + for address in async_client_context.mongoses: + cls.mongos_clients.append(await async_single_client("{}:{}".format(*address))) + + @classmethod + async def _tearDown_class(cls): + for client in cls.mongos_clients: + await client.aclose() + await super()._tearDown_class() + + def maybe_skip_scenario(self, test): + super().maybe_skip_scenario(test) + if ( + "secondary" in self.id() + and not async_client_context.is_mongos + and not async_client_context.has_secondaries + ): + raise unittest.SkipTest("No secondaries") + + +class TestTransactions(AsyncTransactionsBase): + RUN_ON_SERVERLESS = True + + @async_client_context.require_transactions + def test_transaction_options_validation(self): + default_options = TransactionOptions() + self.assertIsNone(default_options.read_concern) + self.assertIsNone(default_options.write_concern) + self.assertIsNone(default_options.read_preference) + self.assertIsNone(default_options.max_commit_time_ms) + # No error when valid options are provided. 
+ TransactionOptions( + read_concern=ReadConcern(), + write_concern=WriteConcern(), + read_preference=ReadPreference.PRIMARY, + max_commit_time_ms=10000, + ) + with self.assertRaisesRegex(TypeError, "read_concern must be "): + TransactionOptions(read_concern={}) # type: ignore + with self.assertRaisesRegex(TypeError, "write_concern must be "): + TransactionOptions(write_concern={}) # type: ignore + with self.assertRaisesRegex( + ConfigurationError, "transactions do not support unacknowledged write concern" + ): + TransactionOptions(write_concern=WriteConcern(w=0)) + with self.assertRaisesRegex(TypeError, "is not valid for read_preference"): + TransactionOptions(read_preference={}) # type: ignore + with self.assertRaisesRegex(TypeError, "max_commit_time_ms must be an integer or None"): + TransactionOptions(max_commit_time_ms="10000") # type: ignore + + @async_client_context.require_transactions + async def test_transaction_write_concern_override(self): + """Test txn overrides Client/Database/Collection write_concern.""" + client = await async_rs_client(w=0) + self.addAsyncCleanup(client.aclose) + db = client.test + coll = db.test + await coll.insert_one({}) + async with client.start_session() as s: + async with await s.start_transaction(write_concern=WriteConcern(w=1)): + self.assertTrue((await coll.insert_one({}, session=s)).acknowledged) + self.assertTrue((await coll.insert_many([{}, {}], session=s)).acknowledged) + self.assertTrue((await coll.bulk_write([InsertOne({})], session=s)).acknowledged) + self.assertTrue((await coll.replace_one({}, {}, session=s)).acknowledged) + self.assertTrue( + (await coll.update_one({}, {"$set": {"a": 1}}, session=s)).acknowledged + ) + self.assertTrue( + (await coll.update_many({}, {"$set": {"a": 1}}, session=s)).acknowledged + ) + self.assertTrue((await coll.delete_one({}, session=s)).acknowledged) + self.assertTrue((await coll.delete_many({}, session=s)).acknowledged) + await coll.find_one_and_delete({}, session=s) + await 
coll.find_one_and_replace({}, {}, session=s) + await coll.find_one_and_update({}, {"$set": {"a": 1}}, session=s) + + unsupported_txn_writes: list = [ + (client.drop_database, [db.name], {}), + (db.drop_collection, ["collection"], {}), + (coll.drop, [], {}), + (coll.rename, ["collection2"], {}), + # Drop collection2 between tests of "rename", above. + (coll.database.drop_collection, ["collection2"], {}), + (coll.create_indexes, [[IndexModel("a")]], {}), + (coll.create_index, ["a"], {}), + (coll.drop_index, ["a_1"], {}), + (coll.drop_indexes, [], {}), + (coll.aggregate, [[{"$out": "aggout"}]], {}), + ] + # Creating a collection in a transaction requires MongoDB 4.4+. + if async_client_context.version < (4, 3, 4): + unsupported_txn_writes.extend( + [ + (db.create_collection, ["collection"], {}), + ] + ) + + for op in unsupported_txn_writes: + op, args, kwargs = op + async with client.start_session() as s: + kwargs["session"] = s + await s.start_transaction(write_concern=WriteConcern(w=1)) + with self.assertRaises(OperationFailure): + await op(*args, **kwargs) + await s.abort_transaction() + + @async_client_context.require_transactions + @async_client_context.require_multiple_mongoses + async def test_unpin_for_next_transaction(self): + # Increase localThresholdMS and wait until both nodes are discovered + # to avoid false positives. + client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000) + wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") + coll = client.test.test + # Create the collection. + await coll.insert_one({}) + self.addAsyncCleanup(client.aclose) + async with client.start_session() as s: + # Session is pinned to Mongos. 
+            async with await s.start_transaction():
+                await coll.insert_one({}, session=s)
+
+            addresses = set()
+            for _ in range(UNPIN_TEST_MAX_ATTEMPTS):
+                async with await s.start_transaction():
+                    cursor = coll.find({}, session=s)
+                    self.assertTrue(await anext(cursor))
+                    addresses.add(cursor.address)
+                # Break early if we can.
+                if len(addresses) > 1:
+                    break
+
+            self.assertGreater(len(addresses), 1)
+
+    @async_client_context.require_transactions
+    @async_client_context.require_multiple_mongoses
+    async def test_unpin_for_non_transaction_operation(self):
+        # Increase localThresholdMS and wait until both nodes are discovered
+        # to avoid false positives.
+        client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000)
+        wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
+        coll = client.test.test
+        # Create the collection.
+        await coll.insert_one({})
+        self.addAsyncCleanup(client.aclose)
+        async with client.start_session() as s:
+            # Session is pinned to Mongos.
+            async with await s.start_transaction():
+                await coll.insert_one({}, session=s)
+
+            addresses = set()
+            for _ in range(UNPIN_TEST_MAX_ATTEMPTS):
+                cursor = coll.find({}, session=s)
+                self.assertTrue(await anext(cursor))
+                addresses.add(cursor.address)
+                # Break early if we can.
+                if len(addresses) > 1:
+                    break
+
+            self.assertGreater(len(addresses), 1)
+
+    @async_client_context.require_transactions
+    @async_client_context.require_version_min(4, 3, 4)
+    async def test_create_collection(self):
+        client = async_client_context.client
+        db = client.pymongo_test
+        coll = db.test_create_collection
+        self.addAsyncCleanup(coll.drop)
+
+        # Use with_transaction to avoid StaleConfig errors on sharded clusters.
+ async def create_and_insert(session): + coll2 = await db.create_collection(coll.name, session=session) + self.assertEqual(coll, coll2) + await coll.insert_one({}, session=session) + + async with client.start_session() as s: + await s.with_transaction(create_and_insert) + + # Outside a transaction we raise CollectionInvalid on existing colls. + with self.assertRaises(CollectionInvalid): + await db.create_collection(coll.name) + + # Inside a transaction we raise the OperationFailure from create. + async with client.start_session() as s: + await s.start_transaction() + with self.assertRaises(OperationFailure) as ctx: + await db.create_collection(coll.name, session=s) + self.assertEqual(ctx.exception.code, 48) # NamespaceExists + + @async_client_context.require_transactions + async def test_gridfs_does_not_support_transactions(self): + client = async_client_context.client + db = client.pymongo_test + gfs = AsyncGridFS(db) + bucket = AsyncGridFSBucket(db) + + async def gridfs_find(*args, **kwargs): + return await gfs.find(*args, **kwargs).next() + + async def gridfs_open_upload_stream(*args, **kwargs): + await (await bucket.open_upload_stream(*args, **kwargs)).write(b"1") + + gridfs_ops = [ + (gfs.put, (b"123",)), + (gfs.get, (1,)), + (gfs.get_version, ("name",)), + (gfs.get_last_version, ("name",)), + (gfs.delete, (1,)), + (gfs.list, ()), + (gfs.find_one, ()), + (gridfs_find, ()), + (gfs.exists, ()), + (gridfs_open_upload_stream, ("name",)), + ( + bucket.upload_from_stream, + ( + "name", + b"data", + ), + ), + ( + bucket.download_to_stream, + ( + 1, + BytesIO(), + ), + ), + ( + bucket.download_to_stream_by_name, + ( + "name", + BytesIO(), + ), + ), + (bucket.delete, (1,)), + (bucket.find, ()), + (bucket.open_download_stream, (1,)), + (bucket.open_download_stream_by_name, ("name",)), + ( + bucket.rename, + ( + 1, + "new-name", + ), + ), + ] + + async with client.start_session() as s, await s.start_transaction(): + for op, args in gridfs_ops: + with 
self.assertRaisesRegex( + InvalidOperation, + "GridFS does not support multi-document transactions", + ): + await op(*args, session=s) # type: ignore + + # Require 4.2+ for large (16MB+) transactions. + @async_client_context.require_version_min(4, 2) + @async_client_context.require_transactions + @unittest.skipIf(sys.platform == "win32", "Our Windows machines are too slow to pass this test") + async def test_transaction_starts_with_batched_write(self): + if "PyPy" in sys.version and async_client_context.tls: + self.skipTest( + "PYTHON-2937 PyPy is so slow sending large " + "messages over TLS that this test fails" + ) + # Start a transaction with a batch of operations that needs to be + # split. + listener = OvertCommandListener() + client = await async_rs_client(event_listeners=[listener]) + coll = client[self.db.name].test + await coll.delete_many({}) + listener.reset() + self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(coll.drop) + large_str = "\0" * (1 * 1024 * 1024) + ops: List[InsertOne[RawBSONDocument]] = [ + InsertOne(RawBSONDocument(encode({"a": large_str}))) for _ in range(48) + ] + async with client.start_session() as session: + async with await session.start_transaction(): + await coll.bulk_write(ops, session=session) # type: ignore[arg-type] + # Assert commands were constructed properly. 
+ self.assertEqual( + ["insert", "insert", "commitTransaction"], listener.started_command_names() + ) + first_cmd = listener.started_events[0].command + self.assertTrue(first_cmd["startTransaction"]) + lsid = first_cmd["lsid"] + txn_number = first_cmd["txnNumber"] + for event in listener.started_events[1:]: + self.assertNotIn("startTransaction", event.command) + self.assertEqual(lsid, event.command["lsid"]) + self.assertEqual(txn_number, event.command["txnNumber"]) + self.assertEqual(48, await coll.count_documents({})) + + @async_client_context.require_transactions + async def test_transaction_direct_connection(self): + client = await async_single_client() + self.addAsyncCleanup(client.aclose) + coll = client.pymongo_test.test + + # Make sure the collection exists. + await coll.insert_one({}) + self.assertEqual(client.topology_description.topology_type_name, "Single") + ops = [ + (coll.bulk_write, [[InsertOne[dict]({})]]), + (coll.insert_one, [{}]), + (coll.insert_many, [[{}, {}]]), + (coll.replace_one, [{}, {}]), + (coll.update_one, [{}, {"$set": {"a": 1}}]), + (coll.update_many, [{}, {"$set": {"a": 1}}]), + (coll.delete_one, [{}]), + (coll.delete_many, [{}]), + (coll.find_one_and_replace, [{}, {}]), + (coll.find_one_and_update, [{}, {"$set": {"a": 1}}]), + (coll.find_one_and_delete, [{}, {}]), + (coll.find_one, [{}]), + (coll.count_documents, [{}]), + (coll.distinct, ["foo"]), + (coll.aggregate, [[]]), + (coll.find, [{}]), + (coll.aggregate_raw_batches, [[]]), + (coll.find_raw_batches, [{}]), + (coll.database.command, ["find", coll.name]), + ] + for f, args in ops: + async with client.start_session() as s, await s.start_transaction(): + res = await f(*args, session=s) # type:ignore[operator] + if isinstance(res, (AsyncCommandCursor, AsyncCursor)): + await res.to_list() + + +class PatchSessionTimeout: + """Patches the client_session's with_transaction timeout for testing.""" + + def __init__(self, mock_timeout): + self.real_timeout = 
client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT + self.mock_timeout = mock_timeout + + def __enter__(self): + client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.mock_timeout + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout + + +class TestTransactionsConvenientAPI(AsyncTransactionsBase): + @async_client_context.require_transactions + async def test_callback_raises_custom_error(self): + class _MyException(Exception): + pass + + async def raise_error(_): + raise _MyException + + async with self.client.start_session() as s: + with self.assertRaises(_MyException): + await s.with_transaction(raise_error) + + @async_client_context.require_transactions + async def test_callback_returns_value(self): + async def callback(_): + return "Foo" + + async with self.client.start_session() as s: + self.assertEqual(await s.with_transaction(callback), "Foo") + + await self.db.test.insert_one({}) + + async def callback2(session): + await self.db.test.insert_one({}, session=session) + return "Foo" + + async with self.client.start_session() as s: + self.assertEqual(await s.with_transaction(callback2), "Foo") + + @async_client_context.require_transactions + async def test_callback_not_retried_after_timeout(self): + listener = OvertCommandListener() + client = await async_rs_client(event_listeners=[listener]) + self.addAsyncCleanup(client.aclose) + coll = client[self.db.name].test + + async def callback(session): + await coll.insert_one({}, session=session) + err: dict = { + "ok": 0, + "errmsg": "Transaction 7819 has been aborted.", + "code": 251, + "codeName": "NoSuchTransaction", + "errorLabels": ["TransientTransactionError"], + } + raise OperationFailure(err["errmsg"], err["code"], err) + + # Create the collection. 
+ await coll.insert_one({}) + listener.reset() + async with client.start_session() as s: + with PatchSessionTimeout(0): + with self.assertRaises(OperationFailure): + await s.with_transaction(callback) + + self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"]) + + @async_client_context.require_test_commands + @async_client_context.require_transactions + async def test_callback_not_retried_after_commit_timeout(self): + listener = OvertCommandListener() + client = await async_rs_client(event_listeners=[listener]) + self.addAsyncCleanup(client.aclose) + coll = client[self.db.name].test + + async def callback(session): + await coll.insert_one({}, session=session) + + # Create the collection. + await coll.insert_one({}) + await self.set_fail_point( + { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["commitTransaction"], + "errorCode": 251, # NoSuchTransaction + }, + } + ) + self.addAsyncCleanup( + self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"} + ) + listener.reset() + + async with client.start_session() as s: + with PatchSessionTimeout(0): + with self.assertRaises(OperationFailure): + await s.with_transaction(callback) + + self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"]) + + @async_client_context.require_test_commands + @async_client_context.require_transactions + async def test_commit_not_retried_after_timeout(self): + listener = OvertCommandListener() + client = await async_rs_client(event_listeners=[listener]) + self.addAsyncCleanup(client.aclose) + coll = client[self.db.name].test + + async def callback(session): + await coll.insert_one({}, session=session) + + # Create the collection. 
+ await coll.insert_one({}) + await self.set_fail_point( + { + "configureFailPoint": "failCommand", + "mode": {"times": 2}, + "data": {"failCommands": ["commitTransaction"], "closeConnection": True}, + } + ) + self.addAsyncCleanup( + self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"} + ) + listener.reset() + + async with client.start_session() as s: + with PatchSessionTimeout(0): + with self.assertRaises(ConnectionFailure): + await s.with_transaction(callback) + + # One insert for the callback and two commits (includes the automatic + # retry). + self.assertEqual( + listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"] + ) + + # Tested here because this supports Motor's convenient transactions API. + @async_client_context.require_transactions + async def test_in_transaction_property(self): + client = async_client_context.client + coll = client.test.testcollection + await coll.insert_one({}) + self.addAsyncCleanup(coll.drop) + + async with client.start_session() as s: + self.assertFalse(s.in_transaction) + await s.start_transaction() + self.assertTrue(s.in_transaction) + await coll.insert_one({}, session=s) + self.assertTrue(s.in_transaction) + await s.commit_transaction() + self.assertFalse(s.in_transaction) + + async with client.start_session() as s: + await s.start_transaction() + # commit empty transaction + await s.commit_transaction() + self.assertFalse(s.in_transaction) + + async with client.start_session() as s: + await s.start_transaction() + await s.abort_transaction() + self.assertFalse(s.in_transaction) + + # Using a callback + async def callback(session): + self.assertTrue(session.in_transaction) + + async with client.start_session() as s: + self.assertFalse(s.in_transaction) + await s.with_transaction(callback) + self.assertFalse(s.in_transaction) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py new file 
mode 100644 index 000000000..2e256ec17 --- /dev/null +++ b/test/asynchronous/utils_spec_runner.py @@ -0,0 +1,691 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for testing driver specs.""" +from __future__ import annotations + +import functools +import threading +from collections import abc +from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs +from test.utils import ( + CMAPListener, + CompareType, + EventListener, + OvertCommandListener, + ServerAndTopologyEventListener, + async_rs_client, + camel_to_snake, + camel_to_snake_args, + parse_spec_options, + prepare_spec_arguments, +) +from typing import List + +from bson import ObjectId, decode, encode +from bson.binary import Binary +from bson.int64 import Int64 +from bson.son import SON +from gridfs import GridFSBucket +from pymongo.asynchronous import client_session +from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.cursor import AsyncCursor +from pymongo.errors import BulkWriteError, OperationFailure, PyMongoError +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference +from pymongo.results import BulkWriteResult, _WriteResult +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class SpecRunnerThread(threading.Thread): + def __init__(self, name): + super().__init__() + self.name = name + self.exc = None + self.daemon = 
True + self.cond = threading.Condition() + self.ops = [] + self.stopped = False + + def schedule(self, work): + self.ops.append(work) + with self.cond: + self.cond.notify() + + def stop(self): + self.stopped = True + with self.cond: + self.cond.notify() + + def run(self): + while not self.stopped or self.ops: + if not self.ops: + with self.cond: + self.cond.wait(10) + if self.ops: + try: + work = self.ops.pop(0) + work() + except Exception as exc: + self.exc = exc + self.stop() + + +class AsyncSpecRunner(AsyncIntegrationTest): + mongos_clients: List + knobs: client_knobs + listener: EventListener + + @classmethod + async def _setup_class(cls): + await super()._setup_class() + cls.mongos_clients = [] + + # Speed up the tests by decreasing the heartbeat frequency. + cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + cls.knobs.enable() + + @classmethod + async def _tearDown_class(cls): + cls.knobs.disable() + await super()._tearDown_class() + + def setUp(self): + super().setUp() + self.targets = {} + self.listener = None # type: ignore + self.pool_listener = None + self.server_listener = None + self.maxDiff = None + + async def _set_fail_point(self, client, command_args): + cmd = SON([("configureFailPoint", "failCommand")]) + cmd.update(command_args) + await client.admin.command(cmd) + + async def set_fail_point(self, command_args): + clients = self.mongos_clients if self.mongos_clients else [self.client] + for client in clients: + await self._set_fail_point(client, command_args) + + async def targeted_fail_point(self, session, fail_point): + """Run the targetedFailPoint test operation. + + Enable the fail point on the session's pinned mongos. 
+ """ + clients = {c.address: c for c in self.mongos_clients} + client = clients[session._pinned_address] + await self._set_fail_point(client, fail_point) + self.addAsyncCleanup(self.set_fail_point, {"mode": "off"}) + + def assert_session_pinned(self, session): + """Run the assertSessionPinned test operation. + + Assert that the given session is pinned. + """ + self.assertIsNotNone(session._transaction.pinned_address) + + def assert_session_unpinned(self, session): + """Run the assertSessionUnpinned test operation. + + Assert that the given session is not pinned. + """ + self.assertIsNone(session._pinned_address) + self.assertIsNone(session._transaction.pinned_address) + + async def assert_collection_exists(self, database, collection): + """Run the assertCollectionExists test operation.""" + db = self.client[database] + self.assertIn(collection, await db.list_collection_names()) + + async def assert_collection_not_exists(self, database, collection): + """Run the assertCollectionNotExists test operation.""" + db = self.client[database] + self.assertNotIn(collection, await db.list_collection_names()) + + async def assert_index_exists(self, database, collection, index): + """Run the assertIndexExists test operation.""" + coll = self.client[database][collection] + self.assertIn(index, [doc["name"] async for doc in await coll.list_indexes()]) + + async def assert_index_not_exists(self, database, collection, index): + """Run the assertIndexNotExists test operation.""" + coll = self.client[database][collection] + self.assertNotIn(index, [doc["name"] async for doc in await coll.list_indexes()]) + + def assertErrorLabelsContain(self, exc, expected_labels): + labels = [l for l in expected_labels if exc.has_error_label(l)] + self.assertEqual(labels, expected_labels) + + def assertErrorLabelsOmit(self, exc, omit_labels): + for label in omit_labels: + self.assertFalse( + exc.has_error_label(label), msg=f"error labels should not contain {label}" + ) + + async def 
kill_all_sessions(self): + clients = self.mongos_clients if self.mongos_clients else [self.client] + for client in clients: + try: + await client.admin.command("killAllSessions", []) + except OperationFailure: + # "operation was interrupted" by killing the command's + # own session. + pass + + def check_command_result(self, expected_result, result): + # Only compare the keys in the expected result. + filtered_result = {} + for key in expected_result: + try: + filtered_result[key] = result[key] + except KeyError: + pass + self.assertEqual(filtered_result, expected_result) + + # TODO: factor the following function with test_crud.py. + def check_result(self, expected_result, result): + if isinstance(result, _WriteResult): + for res in expected_result: + prop = camel_to_snake(res) + # SPEC-869: Only BulkWriteResult has upserted_count. + if prop == "upserted_count" and not isinstance(result, BulkWriteResult): + if result.upserted_id is not None: + upserted_count = 1 + else: + upserted_count = 0 + self.assertEqual(upserted_count, expected_result[res], prop) + elif prop == "inserted_ids": + # BulkWriteResult does not have inserted_ids. + if isinstance(result, BulkWriteResult): + self.assertEqual(len(expected_result[res]), result.inserted_count) + else: + # InsertManyResult may be compared to [id1] from the + # crud spec or {"0": id1} from the retryable write spec. + ids = expected_result[res] + if isinstance(ids, dict): + ids = [ids[str(i)] for i in range(len(ids))] + + self.assertEqual(ids, result.inserted_ids, prop) + elif prop == "upserted_ids": + # Convert indexes from strings to integers. 
async def run_operation(self, sessions, collection, operation):
    """Run one spec-test operation and return its (materialised) result.

    :param sessions: mapping of session name (e.g. ``"session0"``) to session
        object; merged into the set of objects an operation may target.
    :param collection: the collection the scenario runs against.  A fresh
        handle is obtained from its database so per-operation
        ``collectionOptions`` do not leak into the caller's object.
    :param operation: the spec operation document.  Its ``"arguments"`` dict
        is updated in place (``"options"`` are flattened into it).
    :return: the operation result.  Cursor results are drained into a list,
        download streams are read into a ``Binary``.

    Skips the test for ``mapReduce`` and ``count``.
    """
    # Local import keeps this block self-contained; only used to detect
    # coroutine results below.
    import inspect

    async def _resolve(value):
        # Driver methods on the async API are a mix of plain functions and
        # coroutine functions (and ``testRunner`` helpers are plain), so
        # await only when the call actually produced an awaitable.
        if inspect.isawaitable(value):
            return await value
        return value

    original_collection = collection
    name = camel_to_snake(operation["name"])
    if name == "run_command":
        name = "command"
    elif name == "download_by_name":
        name = "open_download_stream_by_name"
    elif name == "download":
        name = "open_download_stream"
    elif name == "map_reduce":
        self.skipTest("PyMongo does not support mapReduce")
    elif name == "count":
        self.skipTest("PyMongo does not support count")

    database = collection.database
    collection = database.get_collection(collection.name)
    if "collectionOptions" in operation:
        collection = collection.with_options(
            **self.parse_options(operation["collectionOptions"])
        )

    object_name = self.get_object_name(operation)
    if object_name == "gridfsbucket":
        # Only create the GridFSBucket when we need it (for the gridfs
        # retryable reads tests).
        obj = GridFSBucket(database, bucket_name=collection.name)
    else:
        objects = {
            "client": database.client,
            "database": database,
            "collection": collection,
            "testRunner": self,
        }
        objects.update(sessions)
        obj = objects[object_name]

    # Combine arguments with options and handle special cases.
    arguments = operation.get("arguments", {})
    arguments.update(arguments.pop("options", {}))
    self.parse_options(arguments)

    cmd = getattr(obj, name)

    # The withTransaction callback replays nested operations against the
    # caller's original collection handle.
    with_txn_callback = functools.partial(
        self.run_operations, sessions, original_collection, in_with_transaction=True
    )
    prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback)

    if name == "run_on_thread":
        args = {"sessions": sessions, "collection": collection}
        args.update(arguments)
        arguments = args

    # Previously the call result was used un-awaited, so coroutine methods
    # never actually ran.
    result = await _resolve(cmd(**dict(arguments)))

    # Cleanup open change stream cursors.
    if name == "watch":
        self.addAsyncCleanup(result.close)

    if name == "aggregate":
        if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
            # Read from the primary to ensure causal consistency.
            out = collection.database.get_collection(
                arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY
            )
            # Drain the cursor so this branch returns a list like every
            # other cursor-producing operation below.
            out_cursor = await _resolve(out.find())
            return await out_cursor.to_list()
    if "download" in name:
        # read() may be a coroutine on the async GridFS stream.
        result = Binary(await _resolve(result.read()))

    if isinstance(result, AsyncCursor) or isinstance(result, AsyncCommandCursor):
        return await result.to_list()

    return result
# TODO: factor with test_command_monitoring.py
def check_events(self, test, listener, session_ids):
    """Assert the listener's started events match ``test["expectations"]``.

    Does nothing when the test declares no expectations.  Both the captured
    events and the expectation documents are normalised in place before the
    comparison (cursor ids, default update flags, ``afterClusterTime``,
    ``recoveryToken`` and ``lsid`` placeholders).

    :param test: the spec test document.
    :param listener: object exposing ``started_events``.
    :param session_ids: mapping of session name to the lsid it used.
    """
    started = listener.started_events
    if len(test["expectations"]) == 0:
        return

    # Give a nicer message when there are missing or extra events
    decoded_cmds = decode_raw([ev.command for ev in started])
    self.assertEqual(len(started), len(test["expectations"]), decoded_cmds)

    for idx, expectation in enumerate(test["expectations"]):
        event_type = next(iter(expectation))
        event = started[idx]
        spec = expectation[event_type]
        command_name = event.command_name

        # The tests substitute 42 for any number other than 0.
        if command_name == "getMore" and event.command["getMore"]:
            event.command["getMore"] = Int64(42)
        elif command_name == "killCursors":
            event.command["cursors"] = [Int64(42)]
        elif command_name == "update":
            # TODO: remove this once PYTHON-1744 is done.
            # Add upsert and multi fields back into expectations.
            for update_doc in spec["command"]["updates"]:
                update_doc.setdefault("upsert", False)
                update_doc.setdefault("multi", False)

        # Replace afterClusterTime: 42 with actual afterClusterTime.
        expected_cmd = spec["command"]
        read_concern = expected_cmd.get("readConcern")
        if read_concern is not None and read_concern.get("afterClusterTime") == 42:
            observed_time = event.command.get("readConcern", {}).get("afterClusterTime")
            if observed_time is not None:
                read_concern["afterClusterTime"] = observed_time

        if expected_cmd.get("recoveryToken") == 42:
            expected_cmd["recoveryToken"] = CompareType(dict)

        # Replace lsid with a name like "session0" to match test.
        if "lsid" in event.command:
            for session_name, session_lsid in session_ids.items():
                if event.command["lsid"] == session_lsid:
                    event.command["lsid"] = session_name
                    break

        for attr, expected in spec.items():
            actual = getattr(event, attr)
            expected = wrap_types(expected)
            if not isinstance(expected, dict):
                self.assertEqual(actual, expected)
                continue
            for key, val in expected.items():
                if val is None:
                    # A None expectation means the key must be absent.
                    if key in actual:
                        self.fail(f"Unexpected key [{key}] in {actual!r}")
                elif key not in actual:
                    self.fail(f"Expected key [{key}] in {actual!r}")
                else:
                    self.assertEqual(
                        val, decode_raw(actual[key]), f"Key [{key}] in {actual}"
                    )
async def setup_scenario(self, scenario_def):
    """Allow specs to override a test's setup.

    Resets the scenario collection so it contains exactly
    ``scenario_def["data"]``: existing documents are deleted, then the
    fixture documents are inserted, or the collection is created empty when
    there is no data and it does not exist yet.

    :param scenario_def: the spec scenario document.
    """
    db_name = self.get_scenario_db_name(scenario_def)
    coll_name = self.get_scenario_coll_name(scenario_def)
    documents = scenario_def["data"]

    # Setup the collection with as few majority writes as possible.
    db = async_client_context.client.get_database(db_name)
    coll_exists = bool(await db.list_collection_names(filter={"name": coll_name}))
    if coll_exists:
        await db[coll_name].delete_many({})
    # Only use majority wc only on the final write.
    wc = WriteConcern(w="majority")
    if documents:
        # Must be awaited: an un-awaited insert_many never runs, leaving the
        # collection without its fixture data.
        await db.get_collection(coll_name, write_concern=wc).insert_many(documents)
    elif not coll_exists:
        # Ensure collection exists.
        await db.create_collection(coll_name, write_concern=wc)
+ if "failPoint" in test: + fp = test["failPoint"] + await self.set_fail_point(fp) + self.addAsyncCleanup( + self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} + ) + + listener = OvertCommandListener() + pool_listener = CMAPListener() + server_listener = ServerAndTopologyEventListener() + # Create a new client, to avoid interference from pooled sessions. + client_options = self.parse_client_options(test["clientOptions"]) + # MMAPv1 does not support retryable writes. + if ( + client_options.get("retryWrites") is True + and async_client_context.storage_engine == "mmapv1" + ): + self.skipTest("MMAPv1 does not support retryWrites=True") + use_multi_mongos = test["useMultipleMongoses"] + host = None + if use_multi_mongos: + if async_client_context.load_balancer or async_client_context.serverless: + host = async_client_context.MULTI_MONGOS_LB_URI + elif async_client_context.is_mongos: + host = async_client_context.mongos_seeds() + client = await async_rs_client( + h=host, event_listeners=[listener, pool_listener, server_listener], **client_options + ) + self.scenario_client = client + self.listener = listener + self.pool_listener = pool_listener + self.server_listener = server_listener + # Close the client explicitly to avoid having too many threads open. + self.addAsyncCleanup(client.aclose) + + # Create session0 and session1. + sessions = {} + session_ids = {} + for i in range(2): + # Don't attempt to create sessions if they are not supported by + # the running server version. 
+ if not async_client_context.sessions_enabled: + break + session_name = "session%d" % i + opts = camel_to_snake_args(test["sessionOptions"][session_name]) + if "default_transaction_options" in opts: + txn_opts = self.parse_options(opts["default_transaction_options"]) + txn_opts = client_session.TransactionOptions(**txn_opts) + opts["default_transaction_options"] = txn_opts + + s = client.start_session(**dict(opts)) + + sessions[session_name] = s + # Store lsid so we can access it after end_session, in check_events. + session_ids[session_name] = s.session_id + + self.addAsyncCleanup(end_sessions, sessions) + + collection = client[database_name][collection_name] + await self.run_test_ops(sessions, collection, test) + + await end_sessions(sessions) + + self.check_events(test, listener, session_ids) + + # Disable fail points. + if "failPoint" in test: + fp = test["failPoint"] + await self.set_fail_point( + {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} + ) + + # Assert final state is expected. + outcome = test["outcome"] + expected_c = outcome.get("collection") + if expected_c is not None: + outcome_coll_name = self.get_outcome_coll_name(outcome, collection) + + # Read from the primary with local read concern to ensure causal + # consistency. + outcome_coll = async_client_context.client[collection.database.name].get_collection( + outcome_coll_name, + read_preference=ReadPreference.PRIMARY, + read_concern=ReadConcern("local"), + ) + actual_data = await (await outcome_coll.find(sort=[("_id", 1)])).to_list() + + # The expected data needs to be the left hand side here otherwise + # CompareType(Binary) doesn't work. 
def expect_any_error(op):
    """Return the operation's ``"error"`` flag (falsy when absent or not a dict)."""
    if isinstance(op, dict):
        return op.get("error")

    return False


def expect_error_message(expected_result):
    """Return True when the expected result carries a string ``"errorContains"``."""
    if isinstance(expected_result, dict):
        # .get() so a plain dict without the key is "no expectation" rather
        # than a KeyError.
        return isinstance(expected_result.get("errorContains"), str)

    return False


def expect_error_code(expected_result):
    """Return the expected ``"errorCodeName"`` (falsy when none is expected)."""
    if isinstance(expected_result, dict):
        return expected_result.get("errorCodeName")

    return False


def expect_error_labels_contain(expected_result):
    """Return the expected ``"errorLabelsContain"`` list (falsy when absent)."""
    if isinstance(expected_result, dict):
        return expected_result.get("errorLabelsContain")

    return False


def expect_error_labels_omit(expected_result):
    """Return the expected ``"errorLabelsOmit"`` list (falsy when absent)."""
    if isinstance(expected_result, dict):
        return expected_result.get("errorLabelsOmit")

    return False


def expect_timeout_error(expected_result):
    """Return the expected ``"isTimeoutError"`` flag (falsy when absent)."""
    if isinstance(expected_result, dict):
        return expected_result.get("isTimeoutError")

    return False


def expect_error(op):
    """Return a truthy value when the operation is expected to raise.

    An operation expects an error when it sets ``"error"`` itself or when its
    ``"result"`` document contains any of the error-describing keys.
    """
    expected_result = op.get("result")
    return (
        expect_any_error(op)
        or expect_error_message(expected_result)
        or expect_error_code(expected_result)
        or expect_error_labels_contain(expected_result)
        or expect_error_labels_omit(expected_result)
        or expect_timeout_error(expected_result)
    )
def wrap_types(val):
    """Support $$type assertion in command results.

    Recursively walks lists and mappings.  A mapping with a truthy
    ``"$$type"`` entry is replaced by a ``CompareType`` for the named type
    (or tuple of types); any other mapping is rebuilt as a plain dict with
    wrapped values.  Everything else is returned unchanged.
    """
    if isinstance(val, list):
        return [wrap_types(item) for item in val]
    if not isinstance(val, abc.Mapping):
        return val

    type_spec = val.get("$$type")
    if type_spec:
        # A single alias maps to one type; otherwise treat it as a sequence
        # of aliases and accept any of them.
        if isinstance(type_spec, str):
            accepted = TYPES[type_spec]
        else:
            accepted = tuple(TYPES[alias] for alias in type_spec)
        return CompareType(accepted)

    return {key: wrap_types(val[key]) for key in val}
super()._setup_class() if client_context.supports_transactions(): for address in client_context.mongoses: cls.mongos_clients.append(single_client("{}:{}".format(*address))) @classmethod - def tearDownClass(cls): + def _tearDown_class(cls): for client in cls.mongos_clients: client.close() - super().tearDownClass() + super()._tearDown_class() def maybe_skip_scenario(self, test): super().maybe_skip_scenario(test) @@ -124,14 +127,14 @@ class TestTransactions(TransactionsBase): coll.insert_one({}) with client.start_session() as s: with s.start_transaction(write_concern=WriteConcern(w=1)): - self.assertTrue(coll.insert_one({}, session=s).acknowledged) - self.assertTrue(coll.insert_many([{}, {}], session=s).acknowledged) - self.assertTrue(coll.bulk_write([InsertOne({})], session=s).acknowledged) - self.assertTrue(coll.replace_one({}, {}, session=s).acknowledged) - self.assertTrue(coll.update_one({}, {"$set": {"a": 1}}, session=s).acknowledged) - self.assertTrue(coll.update_many({}, {"$set": {"a": 1}}, session=s).acknowledged) - self.assertTrue(coll.delete_one({}, session=s).acknowledged) - self.assertTrue(coll.delete_many({}, session=s).acknowledged) + self.assertTrue((coll.insert_one({}, session=s)).acknowledged) + self.assertTrue((coll.insert_many([{}, {}], session=s)).acknowledged) + self.assertTrue((coll.bulk_write([InsertOne({})], session=s)).acknowledged) + self.assertTrue((coll.replace_one({}, {}, session=s)).acknowledged) + self.assertTrue((coll.update_one({}, {"$set": {"a": 1}}, session=s)).acknowledged) + self.assertTrue((coll.update_many({}, {"$set": {"a": 1}}, session=s)).acknowledged) + self.assertTrue((coll.delete_one({}, session=s)).acknowledged) + self.assertTrue((coll.delete_many({}, session=s)).acknowledged) coll.find_one_and_delete({}, session=s) coll.find_one_and_replace({}, {}, session=s) coll.find_one_and_update({}, {"$set": {"a": 1}}, session=s) @@ -260,7 +263,7 @@ class TestTransactions(TransactionsBase): return gfs.find(*args, **kwargs).next() def 
gridfs_open_upload_stream(*args, **kwargs): - bucket.open_upload_stream(*args, **kwargs).write(b"1") + (bucket.open_upload_stream(*args, **kwargs)).write(b"1") gridfs_ops = [ (gfs.put, (b"123",)), @@ -389,7 +392,7 @@ class TestTransactions(TransactionsBase): with client.start_session() as s, s.start_transaction(): res = f(*args, session=s) # type:ignore[operator] if isinstance(res, (CommandCursor, Cursor)): - list(res) + res.to_list() class PatchSessionTimeout: diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 091533582..0b882a8bc 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -47,6 +47,8 @@ from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.cursor import Cursor from pymongo.write_concern import WriteConcern +_IS_SYNC = True + class SpecRunnerThread(threading.Thread): def __init__(self, name): @@ -88,8 +90,8 @@ class SpecRunner(IntegrationTest): listener: EventListener @classmethod - def setUpClass(cls): - super().setUpClass() + def _setup_class(cls): + super()._setup_class() cls.mongos_clients = [] # Speed up the tests by decreasing the heartbeat frequency. @@ -97,9 +99,9 @@ class SpecRunner(IntegrationTest): cls.knobs.enable() @classmethod - def tearDownClass(cls): + def _tearDown_class(cls): cls.knobs.disable() - super().tearDownClass() + super()._tearDown_class() def setUp(self): super().setUp() @@ -325,7 +327,7 @@ class SpecRunner(IntegrationTest): result = Binary(result.read()) if isinstance(result, Cursor) or isinstance(result, CommandCursor): - return list(result) + return result.to_list() return result @@ -580,7 +582,7 @@ class SpecRunner(IntegrationTest): read_preference=ReadPreference.PRIMARY, read_concern=ReadConcern("local"), ) - actual_data = list(outcome_coll.find(sort=[("_id", 1)])) + actual_data = (outcome_coll.find(sort=[("_id", 1)])).to_list() # The expected data needs to be the left hand side here otherwise # CompareType(Binary) doesn't work. 
diff --git a/tools/synchro.py b/tools/synchro.py index 0c608e4e6..761ddb2a7 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -74,6 +74,8 @@ replacements = { "IsolatedAsyncioTestCase": "TestCase", "AsyncUnitTest": "UnitTest", "AsyncMockClient": "MockClient", + "AsyncSpecRunner": "SpecRunner", + "AsyncTransactionsBase": "TransactionsBase", "async_get_pool": "get_pool", "async_is_mongos": "is_mongos", "async_rs_or_single_client": "rs_or_single_client", @@ -141,9 +143,11 @@ converted_tests = [ "__init__.py", "conftest.py", "pymongo_mocks.py", + "utils_spec_runner.py", "test_client.py", "test_collection.py", "test_database.py", + "test_transactions.py", ] sync_test_files = [