# Copyright 2018-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Execute Transactions Spec tests.""" from __future__ import annotations import asyncio import random import sys import time from io import BytesIO from unittest.mock import patch import pymongo from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket from pymongo.asynchronous.pool import PoolState from pymongo.server_selectors import writable_server_selector sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.utils_shared import ( OvertCommandListener, async_wait_until, ) from typing import List from bson import encode from bson.raw_bson import RawBSONDocument from pymongo import WriteConcern, _csot from pymongo.asynchronous import client_session from pymongo.asynchronous.client_session import TransactionOptions from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.cursor import AsyncCursor from pymongo.asynchronous.helpers import anext from pymongo.errors import ( AutoReconnect, CollectionInvalid, ConfigurationError, ConnectionFailure, ExecutionTimeout, InvalidOperation, NetworkTimeout, OperationFailure, ) from pymongo.operations import IndexModel, InsertOne from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference _IS_SYNC = False # Max number of operations to perform after a transaction to prove unpinning # occurs. Chosen so that there's a low false positive rate. With 2 mongoses, # 50 attempts yields a one in a quadrillion chance of a false positive # (1/(0.5^50)). UNPIN_TEST_MAX_ATTEMPTS = 50 class AsyncTransactionsBase(AsyncIntegrationTest): pass class TestTransactions(AsyncTransactionsBase): @async_client_context.require_transactions def test_transaction_options_validation(self): default_options = TransactionOptions() self.assertIsNone(default_options.read_concern) self.assertIsNone(default_options.write_concern) self.assertIsNone(default_options.read_preference) self.assertIsNone(default_options.max_commit_time_ms) # No error when valid options are provided. TransactionOptions( read_concern=ReadConcern(), write_concern=WriteConcern(), read_preference=ReadPreference.PRIMARY, max_commit_time_ms=10000, ) with self.assertRaisesRegex(TypeError, "read_concern must be "): TransactionOptions(read_concern={}) # type: ignore with self.assertRaisesRegex(TypeError, "write_concern must be "): TransactionOptions(write_concern={}) # type: ignore with self.assertRaisesRegex( ConfigurationError, "transactions do not support unacknowledged write concern" ): TransactionOptions(write_concern=WriteConcern(w=0)) with self.assertRaisesRegex(TypeError, "is not valid for read_preference"): TransactionOptions(read_preference={}) # type: ignore with self.assertRaisesRegex(TypeError, "max_commit_time_ms must be an integer or None"): TransactionOptions(max_commit_time_ms="10000") # type: ignore @async_client_context.require_transactions async def test_transaction_write_concern_override(self): """Test txn overrides Client/Database/Collection write_concern.""" client = await self.async_rs_client(w=0) db = client.test coll = db.test await coll.insert_one({}) async with client.start_session() as s: async with await s.start_transaction(write_concern=WriteConcern(w=1)): self.assertTrue((await coll.insert_one({}, session=s)).acknowledged) self.assertTrue((await coll.insert_many([{}, {}], session=s)).acknowledged) self.assertTrue((await coll.bulk_write([InsertOne({})], session=s)).acknowledged) self.assertTrue((await coll.replace_one({}, {}, session=s)).acknowledged) self.assertTrue( (await coll.update_one({}, {"$set": {"a": 1}}, session=s)).acknowledged ) self.assertTrue( (await coll.update_many({}, {"$set": {"a": 1}}, session=s)).acknowledged ) self.assertTrue((await coll.delete_one({}, session=s)).acknowledged) self.assertTrue((await coll.delete_many({}, session=s)).acknowledged) await coll.find_one_and_delete({}, session=s) await coll.find_one_and_replace({}, {}, session=s) await coll.find_one_and_update({}, {"$set": {"a": 1}}, session=s) unsupported_txn_writes: list = [ (client.drop_database, [db.name], {}), (db.drop_collection, ["collection"], {}), (coll.drop, [], {}), (coll.rename, ["collection2"], {}), # Drop collection2 between tests of "rename", above. (coll.database.drop_collection, ["collection2"], {}), (coll.create_indexes, [[IndexModel("a")]], {}), (coll.create_index, ["a"], {}), (coll.drop_index, ["a_1"], {}), (coll.drop_indexes, [], {}), (coll.aggregate, [[{"$out": "aggout"}]], {}), ] # Creating a collection in a transaction requires MongoDB 4.4+. if async_client_context.version < (4, 3, 4): unsupported_txn_writes.extend( [ (db.create_collection, ["collection"], {}), ] ) for op in unsupported_txn_writes: op, args, kwargs = op async with client.start_session() as s: kwargs["session"] = s await s.start_transaction(write_concern=WriteConcern(w=1)) with self.assertRaises(OperationFailure): await op(*args, **kwargs) await s.abort_transaction() @async_client_context.require_transactions @async_client_context.require_multiple_mongoses async def test_unpin_for_next_transaction(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. client = await self.async_rs_client( async_client_context.mongos_seeds(), localThresholdMS=1000 ) await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. await coll.insert_one({}) async with client.start_session() as s: # Session is pinned to Mongos. async with await s.start_transaction(): await coll.insert_one({}, session=s) addresses = set() for _ in range(UNPIN_TEST_MAX_ATTEMPTS): async with await s.start_transaction(): cursor = coll.find({}, session=s) self.assertTrue(await anext(cursor)) addresses.add(cursor.address) # Break early if we can. if len(addresses) > 1: break self.assertGreater(len(addresses), 1) @async_client_context.require_transactions @async_client_context.require_multiple_mongoses async def test_unpin_for_non_transaction_operation(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. client = await self.async_rs_client( async_client_context.mongos_seeds(), localThresholdMS=1000 ) await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. await coll.insert_one({}) async with client.start_session() as s: # Session is pinned to Mongos. async with await s.start_transaction(): await coll.insert_one({}, session=s) addresses = set() for _ in range(UNPIN_TEST_MAX_ATTEMPTS): cursor = coll.find({}, session=s) self.assertTrue(await anext(cursor)) addresses.add(cursor.address) # Break early if we can. if len(addresses) > 1: break self.assertGreater(len(addresses), 1) @async_client_context.require_transactions @async_client_context.require_version_min(4, 3, 4) async def test_create_collection(self): client = async_client_context.client db = client.pymongo_test coll = db.test_create_collection self.addAsyncCleanup(coll.drop) # Use with_transaction to avoid StaleConfig errors on sharded clusters. async def create_and_insert(session): coll2 = await db.create_collection(coll.name, session=session) self.assertEqual(coll, coll2) await coll.insert_one({}, session=session) async with client.start_session() as s: await s.with_transaction(create_and_insert) # Outside a transaction we raise CollectionInvalid on existing colls. with self.assertRaises(CollectionInvalid): await db.create_collection(coll.name) # Inside a transaction we raise the OperationFailure from create. async with client.start_session() as s: await s.start_transaction() with self.assertRaises(OperationFailure) as ctx: await db.create_collection(coll.name, session=s) self.assertEqual(ctx.exception.code, 48) # NamespaceExists @async_client_context.require_transactions async def test_gridfs_does_not_support_transactions(self): client = async_client_context.client db = client.pymongo_test gfs = AsyncGridFS(db) bucket = AsyncGridFSBucket(db) async def gridfs_find(*args, **kwargs): return await gfs.find(*args, **kwargs).next() async def gridfs_open_upload_stream(*args, **kwargs): await (await bucket.open_upload_stream(*args, **kwargs)).write(b"1") gridfs_ops = [ (gfs.put, (b"123",)), (gfs.get, (1,)), (gfs.get_version, ("name",)), (gfs.get_last_version, ("name",)), (gfs.delete, (1,)), (gfs.list, ()), (gfs.find_one, ()), (gridfs_find, ()), (gfs.exists, ()), (gridfs_open_upload_stream, ("name",)), ( bucket.upload_from_stream, ( "name", b"data", ), ), ( bucket.download_to_stream, ( 1, BytesIO(), ), ), ( bucket.download_to_stream_by_name, ( "name", BytesIO(), ), ), (bucket.delete, (1,)), (bucket.find, ()), (bucket.open_download_stream, (1,)), (bucket.open_download_stream_by_name, ("name",)), ( bucket.rename, ( 1, "new-name", ), ), ( bucket.rename_by_name, ( "new-name", "new-name2", ), ), (bucket.delete_by_name, ("new-name2",)), ] async with client.start_session() as s, await s.start_transaction(): for op, args in gridfs_ops: with self.assertRaisesRegex( InvalidOperation, "GridFS does not support multi-document transactions", ): await op(*args, session=s) # type: ignore # Require 4.2+ for large (16MB+) transactions. @async_client_context.require_version_min(4, 2) @async_client_context.require_transactions @unittest.skipIf(sys.platform == "win32", "Our Windows machines are too slow to pass this test") async def test_transaction_starts_with_batched_write(self): if "PyPy" in sys.version and async_client_context.tls: self.skipTest( "PYTHON-2937 PyPy is so slow sending large " "messages over TLS that this test fails" ) # Start a transaction with a batch of operations that needs to be # split. listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test await coll.delete_many({}) listener.reset() self.addAsyncCleanup(coll.drop) large_str = "\0" * (1 * 1024 * 1024) ops: List[InsertOne[RawBSONDocument]] = [ InsertOne(RawBSONDocument(encode({"a": large_str}))) for _ in range(48) ] async with client.start_session() as session: async with await session.start_transaction(): await coll.bulk_write(ops, session=session) # type: ignore[arg-type] # Assert commands were constructed properly. self.assertEqual( ["insert", "insert", "commitTransaction"], listener.started_command_names() ) first_cmd = listener.started_events[0].command self.assertTrue(first_cmd["startTransaction"]) lsid = first_cmd["lsid"] txn_number = first_cmd["txnNumber"] for event in listener.started_events[1:]: self.assertNotIn("startTransaction", event.command) self.assertEqual(lsid, event.command["lsid"]) self.assertEqual(txn_number, event.command["txnNumber"]) self.assertEqual(48, await coll.count_documents({})) @async_client_context.require_transactions async def test_transaction_direct_connection(self): client = await self.async_single_client() coll = client.pymongo_test.test # Make sure the collection exists. await coll.insert_one({}) self.assertEqual(client.topology_description.topology_type_name, "Single") async def find(*args, **kwargs): return coll.find(*args, **kwargs) async def find_raw_batches(*args, **kwargs): return coll.find_raw_batches(*args, **kwargs) ops = [ (coll.bulk_write, [[InsertOne[dict]({})]]), (coll.insert_one, [{}]), (coll.insert_many, [[{}, {}]]), (coll.replace_one, [{}, {}]), (coll.update_one, [{}, {"$set": {"a": 1}}]), (coll.update_many, [{}, {"$set": {"a": 1}}]), (coll.delete_one, [{}]), (coll.delete_many, [{}]), (coll.find_one_and_replace, [{}, {}]), (coll.find_one_and_update, [{}, {"$set": {"a": 1}}]), (coll.find_one_and_delete, [{}, {}]), (coll.find_one, [{}]), (coll.count_documents, [{}]), (coll.distinct, ["foo"]), (coll.aggregate, [[]]), (find, [{}]), (coll.aggregate_raw_batches, [[]]), (find_raw_batches, [{}]), (coll.database.command, ["find", coll.name]), ] for f, args in ops: async with client.start_session() as s, await s.start_transaction(): res = await f(*args, session=s) # type:ignore[operator] if isinstance(res, (AsyncCommandCursor, AsyncCursor)): await res.to_list() @async_client_context.require_transactions async def test_transaction_pool_cleared_error_labelled_transient(self): c = await self.async_single_client() with self.assertRaises(AutoReconnect) as context: async with c.start_session() as session: async with await session.start_transaction(): server = await c._select_server(writable_server_selector, session, "test") # Pause the server's pool, causing it to fail connection checkout. server.pool.state = PoolState.PAUSED async with c._checkout(server, session): pass # Verify that the TransientTransactionError label is present in the error. self.assertTrue(context.exception.has_error_label("TransientTransactionError")) class PatchSessionTimeout: """Patches the client_session's with_transaction timeout for testing.""" def __init__(self, mock_timeout): self.real_timeout = client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT self.mock_timeout = mock_timeout def __enter__(self): client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.mock_timeout return self def __exit__(self, exc_type, exc_val, exc_tb): client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout class TestTransactionsConvenientAPI(AsyncTransactionsBase): async def asyncSetUp(self) -> None: await super().asyncSetUp() self.mongos_clients = [] if async_client_context.supports_transactions(): for address in async_client_context.mongoses: self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address))) async def set_fail_point(self, command_args): clients = self.mongos_clients if self.mongos_clients else [self.client] for client in clients: await self.configure_fail_point(client, command_args) @async_client_context.require_transactions async def test_1_callback_raises_custom_error(self): class _MyException(Exception): pass async def raise_error(_): raise _MyException async with self.client.start_session() as s: with self.assertRaises(_MyException): await s.with_transaction(raise_error) @async_client_context.require_transactions async def test_2_callback_returns_value(self): async def callback(_): return "Foo" async with self.client.start_session() as s: self.assertEqual(await s.with_transaction(callback), "Foo") await self.db.test.insert_one({}) async def callback2(session): await self.db.test.insert_one({}, session=session) return "Foo" async with self.client.start_session() as s: self.assertEqual(await s.with_transaction(callback2), "Foo") @async_client_context.require_transactions @async_client_context.require_async async def test_callback_awaitable_no_coroutine(self): def callback(_): future = asyncio.Future() future.set_result("Foo") return future async with self.client.start_session() as s: self.assertEqual(await s.with_transaction(callback), "Foo") @async_client_context.require_transactions async def test_3_1_callback_not_retried_after_timeout(self): listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test async def callback(session): await coll.insert_one({}, session=session) err: dict = { "ok": 0, "errmsg": "Transaction 7819 has been aborted.", "code": 251, "codeName": "NoSuchTransaction", "errorLabels": ["TransientTransactionError"], } raise OperationFailure(err["errmsg"], err["code"], err) # Create the collection. await coll.insert_one({}) listener.reset() async with client.start_session() as s: with PatchSessionTimeout(0): with self.assertRaises(NetworkTimeout) as context: await s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"]) # Assert that the timeout error has the same labels as the error it wraps. self.assertTrue(context.exception.has_error_label("TransientTransactionError")) @async_client_context.require_test_commands @async_client_context.require_transactions async def test_3_2_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test async def callback(session): await coll.insert_one({}, session=session) # Create the collection. await coll.insert_one({}) await self.set_fail_point( { "configureFailPoint": "failCommand", "mode": {"times": 1}, "data": { "failCommands": ["commitTransaction"], "errorCode": 251, # NoSuchTransaction }, } ) self.addAsyncCleanup( self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"} ) listener.reset() async with client.start_session() as s: with PatchSessionTimeout(0): with self.assertRaises(NetworkTimeout) as context: await s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"]) # Assert that the timeout error has the same labels as the error it wraps. self.assertTrue(context.exception.has_error_label("TransientTransactionError")) @async_client_context.require_test_commands @async_client_context.require_transactions async def test_3_3_commit_not_retried_after_timeout(self): listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test async def callback(session): await coll.insert_one({}, session=session) # Create the collection. await coll.insert_one({}) await self.set_fail_point( { "configureFailPoint": "failCommand", "mode": {"times": 2}, "data": {"failCommands": ["commitTransaction"], "closeConnection": True}, } ) self.addAsyncCleanup( self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"} ) listener.reset() async with client.start_session() as s: with PatchSessionTimeout(0): with self.assertRaises(NetworkTimeout) as context: await s.with_transaction(callback) # One insert for the callback and two commits (includes the automatic # retry). self.assertEqual( listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"] ) # Assert that the timeout error has the same labels as the error it wraps. self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult")) @async_client_context.require_transactions async def test_callback_not_retried_after_csot_timeout(self): listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test async def callback(session): await coll.insert_one({}, session=session) err: dict = { "ok": 0, "errmsg": "Transaction 7819 has been aborted.", "code": 251, "codeName": "NoSuchTransaction", "errorLabels": ["TransientTransactionError"], } raise OperationFailure(err["errmsg"], err["code"], err) # Create the collection. await coll.insert_one({}) listener.reset() async with client.start_session() as s: with pymongo.timeout(1.0): with self.assertRaises(ExecutionTimeout): await s.with_transaction(callback) # At least two attempts: the original and one or more retries. inserts = len([x for x in listener.started_command_names() if x == "insert"]) aborts = len([x for x in listener.started_command_names() if x == "abortTransaction"]) self.assertGreaterEqual(inserts, 2) self.assertGreaterEqual(aborts, 2) # Tested here because this supports Motor's convenient transactions API. @async_client_context.require_transactions async def test_in_transaction_property(self): client = async_client_context.client coll = client.test.testcollection await coll.insert_one({}) self.addAsyncCleanup(coll.drop) async with client.start_session() as s: self.assertFalse(s.in_transaction) await s.start_transaction() self.assertTrue(s.in_transaction) await coll.insert_one({}, session=s) self.assertTrue(s.in_transaction) await s.commit_transaction() self.assertFalse(s.in_transaction) async with client.start_session() as s: await s.start_transaction() # commit empty transaction await s.commit_transaction() self.assertFalse(s.in_transaction) async with client.start_session() as s: await s.start_transaction() await s.abort_transaction() self.assertFalse(s.in_transaction) # Using a callback async def callback(session): self.assertTrue(session.in_transaction) async with client.start_session() as s: self.assertFalse(s.in_transaction) await s.with_transaction(callback) self.assertFalse(s.in_transaction) @async_client_context.require_test_commands @async_client_context.require_transactions async def test_4_retry_backoff_is_enforced(self): client = async_client_context.client coll = client[self.db.name].test end = start = no_backoff_time = 0 # Make random.random always return 0 (no backoff) with patch.object(random, "random", return_value=0): # set fail point to trigger transaction failure and trigger backoff await self.set_fail_point( { "configureFailPoint": "failCommand", "mode": {"times": 13}, "data": { "failCommands": ["commitTransaction"], "errorCode": 251, }, } ) self.addAsyncCleanup( self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"} ) async def callback(session): await coll.insert_one({}, session=session) start = time.monotonic() async with self.client.start_session() as s: await s.with_transaction(callback) end = time.monotonic() no_backoff_time = end - start # Make random.random always return 1 (max backoff) with patch.object(random, "random", return_value=1): # set fail point to trigger transaction failure and trigger backoff await self.set_fail_point( { "configureFailPoint": "failCommand", "mode": { "times": 13 }, # sufficiently high enough such that the time effect of backoff is noticeable "data": { "failCommands": ["commitTransaction"], "errorCode": 251, }, } ) self.addAsyncCleanup( self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"} ) start = time.monotonic() async with self.client.start_session() as s: await s.with_transaction(callback) end = time.monotonic() self.assertLess(abs(end - start - (no_backoff_time + 2.2)), 1) # sum of 13 backoffs is 2.2 class TestOptionsInsideTransactionProse(AsyncTransactionsBase): @async_client_context.require_transactions @async_client_context.require_no_standalone async def test_case_1(self): # Write concern not inherited from collection object inside transaction # Create a MongoClient running against a configured sharded/replica set/load balanced cluster. client = async_client_context.client coll = client[self.db.name].test await coll.delete_many({}) # Start a new session on the client. async with client.start_session() as s: # Start a transaction on the session. await s.start_transaction() # Instantiate a collection object in the driver with a default write concern of { w: 0 }. inner_coll = coll.with_options(write_concern=WriteConcern(w=0)) # Insert the document { n: 1 } on the instantiated collection. result = await inner_coll.insert_one({"n": 1}, session=s) # Commit the transaction. await s.commit_transaction() # End the session. # Ensure the document was inserted and no error was thrown from the transaction. assert result.inserted_id is not None if __name__ == "__main__": unittest.main()