# Copyright 2018-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Execute Transactions Spec tests."""

import os
import sys
from io import BytesIO

sys.path[0:0] = [""]

from test import client_context, unittest
from test.utils import (
    OvertCommandListener,
    TestCreator,
    rs_client,
    single_client,
    wait_until,
)
from test.utils_spec_runner import SpecRunner

from bson import encode
from bson.raw_bson import RawBSONDocument
from gridfs import GridFS, GridFSBucket
from pymongo import WriteConcern, client_session
from pymongo.client_session import TransactionOptions
from pymongo.command_cursor import CommandCursor
from pymongo.cursor import Cursor
from pymongo.errors import (
    CollectionInvalid,
    ConfigurationError,
    ConnectionFailure,
    InvalidOperation,
    OperationFailure,
)
from pymongo.operations import IndexModel, InsertOne
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference

# Location of JSON test specifications.
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "transactions", "legacy")

_TXN_TESTS_DEBUG = os.environ.get("TRANSACTION_TESTS_DEBUG")

# Max number of operations to perform after a transaction to prove unpinning
# occurs. Chosen so that there's a low false positive rate. With 2 mongoses,
# 50 attempts yields a one in a quadrillion chance of a false positive
# (1/(0.5^50)).
UNPIN_TEST_MAX_ATTEMPTS = 50


class TransactionsBase(SpecRunner):
    """Shared fixture logic for the transaction spec test classes."""

    @classmethod
    def setUpClass(cls):
        """Open one single-host client per mongos when transactions are supported.

        NOTE(review): ``cls.mongos_clients`` is not defined in this file; it is
        presumably initialised by the parent class's ``setUpClass`` -- confirm.
        """
        super().setUpClass()
        if client_context.supports_transactions():
            for host, port in client_context.mongoses:
                cls.mongos_clients.append(single_client(f"{host}:{port}"))

    @classmethod
    def tearDownClass(cls):
        """Close every per-mongos client, then run the parent teardown."""
        for client in cls.mongos_clients:
            client.close()
        super().tearDownClass()

    def maybe_skip_scenario(self, test):
        """Skip scenarios whose test id mentions "secondary" when none are available.

        The parent's skip logic runs first. The extra skip applies only when the
        deployment is not a mongos and has no secondaries.
        """
        super().maybe_skip_scenario(test)
        if (
            "secondary" in self.id()
            and not client_context.is_mongos
            and not client_context.has_secondaries
        ):
            raise unittest.SkipTest("No secondaries")


class TestTransactions(TransactionsBase):
    """Hand-written transaction tests plus generated legacy spec tests."""

    RUN_ON_SERVERLESS = True

    @client_context.require_transactions
    def test_transaction_options_validation(self):
        """TransactionOptions defaults to None and rejects invalid arguments."""
        default_options = TransactionOptions()
        self.assertIsNone(default_options.read_concern)
        self.assertIsNone(default_options.write_concern)
        self.assertIsNone(default_options.read_preference)
        self.assertIsNone(default_options.max_commit_time_ms)
        # No error when valid options are provided.
        TransactionOptions(
            read_concern=ReadConcern(),
            write_concern=WriteConcern(),
            read_preference=ReadPreference.PRIMARY,
            max_commit_time_ms=10000,
        )
        with self.assertRaisesRegex(TypeError, "read_concern must be "):
            TransactionOptions(read_concern={})  # type: ignore
        with self.assertRaisesRegex(TypeError, "write_concern must be "):
            TransactionOptions(write_concern={})  # type: ignore
        with self.assertRaisesRegex(
            ConfigurationError, "transactions do not support unacknowledged write concern"
        ):
            TransactionOptions(write_concern=WriteConcern(w=0))
        with self.assertRaisesRegex(TypeError, "is not valid for read_preference"):
            TransactionOptions(read_preference={})  # type: ignore
        with self.assertRaisesRegex(TypeError, "max_commit_time_ms must be an integer or None"):
            TransactionOptions(max_commit_time_ms="10000")  # type: ignore

    @client_context.require_transactions
    def test_transaction_write_concern_override(self):
        """Test txn overrides Client/Database/Collection write_concern."""
        # The client defaults to w=0, so any acknowledged result below must
        # come from the transaction's own w=1 write concern.
        client = rs_client(w=0)
        self.addCleanup(client.close)
        db = client.test
        coll = db.test
        coll.insert_one({})
        with client.start_session() as s:
            with s.start_transaction(write_concern=WriteConcern(w=1)):
                self.assertTrue(coll.insert_one({}, session=s).acknowledged)
                self.assertTrue(coll.insert_many([{}, {}], session=s).acknowledged)
                self.assertTrue(coll.bulk_write([InsertOne({})], session=s).acknowledged)
                self.assertTrue(coll.replace_one({}, {}, session=s).acknowledged)
                self.assertTrue(coll.update_one({}, {"$set": {"a": 1}}, session=s).acknowledged)
                self.assertTrue(coll.update_many({}, {"$set": {"a": 1}}, session=s).acknowledged)
                self.assertTrue(coll.delete_one({}, session=s).acknowledged)
                self.assertTrue(coll.delete_many({}, session=s).acknowledged)
                coll.find_one_and_delete({}, session=s)
                coll.find_one_and_replace({}, {}, session=s)
                coll.find_one_and_update({}, {"$set": {"a": 1}}, session=s)

        # Each entry is (callable, positional args, keyword args).
        unsupported_txn_writes: list = [
            (client.drop_database, [db.name], {}),
            (db.drop_collection, ["collection"], {}),
            (coll.drop, [], {}),
            (coll.rename, ["collection2"], {}),
            # Drop collection2 between tests of "rename", above.
            (coll.database.drop_collection, ["collection2"], {}),
            (coll.create_indexes, [[IndexModel("a")]], {}),
            (coll.create_index, ["a"], {}),
            (coll.drop_index, ["a_1"], {}),
            (coll.drop_indexes, [], {}),
            (coll.aggregate, [[{"$out": "aggout"}]], {}),
        ]
        # Creating a collection in a transaction requires MongoDB 4.4+.
        if client_context.version < (4, 3, 4):
            unsupported_txn_writes.extend(
                [
                    (db.create_collection, ["collection"], {}),
                ]
            )

        for op, args, kwargs in unsupported_txn_writes:
            with client.start_session() as s:
                s.start_transaction(write_concern=WriteConcern(w=1))
                with self.assertRaises(OperationFailure):
                    # Pass the session explicitly instead of mutating the
                    # kwargs dict stored in the table above.
                    op(*args, session=s, **kwargs)
                s.abort_transaction()

    @client_context.require_transactions
    @client_context.require_multiple_mongoses
    def test_unpin_for_next_transaction(self):
        """Later transactions on one session reach more than one mongos."""
        # Increase localThresholdMS and wait until both nodes are discovered
        # to avoid false positives.
        client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000)
        # Register cleanup immediately so the client is closed even if
        # discovery or the initial insert fails.
        self.addCleanup(client.close)
        wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
        coll = client.test.test
        # Create the collection.
        coll.insert_one({})
        with client.start_session() as s:
            # Session is pinned to Mongos.
            with s.start_transaction():
                coll.insert_one({}, session=s)

            addresses = set()
            for _ in range(UNPIN_TEST_MAX_ATTEMPTS):
                with s.start_transaction():
                    cursor = coll.find({}, session=s)
                    self.assertTrue(next(cursor))
                    addresses.add(cursor.address)
                # Break early if we can.
                if len(addresses) > 1:
                    break

            self.assertGreater(len(addresses), 1)

    @client_context.require_transactions
    @client_context.require_multiple_mongoses
    def test_unpin_for_non_transaction_operation(self):
        """Non-transaction reads after a transaction reach more than one mongos."""
        # Increase localThresholdMS and wait until both nodes are discovered
        # to avoid false positives.
        client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000)
        # Register cleanup immediately so the client is closed even if
        # discovery or the initial insert fails.
        self.addCleanup(client.close)
        wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
        coll = client.test.test
        # Create the collection.
        coll.insert_one({})
        with client.start_session() as s:
            # Session is pinned to Mongos.
            with s.start_transaction():
                coll.insert_one({}, session=s)

            addresses = set()
            for _ in range(UNPIN_TEST_MAX_ATTEMPTS):
                cursor = coll.find({}, session=s)
                self.assertTrue(next(cursor))
                addresses.add(cursor.address)
                # Break early if we can.
                if len(addresses) > 1:
                    break

            self.assertGreater(len(addresses), 1)

    @client_context.require_transactions
    @client_context.require_version_min(4, 3, 4)
    def test_create_collection(self):
        """create_collection works in a transaction and errors on an existing name."""
        client = client_context.client
        db = client.pymongo_test
        coll = db.test_create_collection
        self.addCleanup(coll.drop)

        # Use with_transaction to avoid StaleConfig errors on sharded clusters.
        def create_and_insert(session):
            coll2 = db.create_collection(coll.name, session=session)
            self.assertEqual(coll, coll2)
            coll.insert_one({}, session=session)

        with client.start_session() as s:
            s.with_transaction(create_and_insert)

        # Outside a transaction we raise CollectionInvalid on existing colls.
        with self.assertRaises(CollectionInvalid):
            db.create_collection(coll.name)

        # Inside a transaction we raise the OperationFailure from create.
        with client.start_session() as s:
            s.start_transaction()
            with self.assertRaises(OperationFailure) as ctx:
                db.create_collection(coll.name, session=s)
            self.assertEqual(ctx.exception.code, 48)  # NamespaceExists

    @client_context.require_transactions
    def test_gridfs_does_not_support_transactions(self):
        """Every listed GridFS/GridFSBucket call raises InvalidOperation in a transaction."""
        client = client_context.client
        db = client.pymongo_test
        gfs = GridFS(db)
        bucket = GridFSBucket(db)

        def gridfs_find(*args, **kwargs):
            # Advance the cursor returned by find().
            return gfs.find(*args, **kwargs).next()

        def gridfs_open_upload_stream(*args, **kwargs):
            # Write one byte to the stream returned by open_upload_stream().
            bucket.open_upload_stream(*args, **kwargs).write(b"1")

        # Each entry is (callable, positional args); session= is added below.
        gridfs_ops = [
            (gfs.put, (b"123",)),
            (gfs.get, (1,)),
            (gfs.get_version, ("name",)),
            (gfs.get_last_version, ("name",)),
            (gfs.delete, (1,)),
            (gfs.list, ()),
            (gfs.find_one, ()),
            (gridfs_find, ()),
            (gfs.exists, ()),
            (gridfs_open_upload_stream, ("name",)),
            (
                bucket.upload_from_stream,
                (
                    "name",
                    b"data",
                ),
            ),
            (
                bucket.download_to_stream,
                (
                    1,
                    BytesIO(),
                ),
            ),
            (
                bucket.download_to_stream_by_name,
                (
                    "name",
                    BytesIO(),
                ),
            ),
            (bucket.delete, (1,)),
            (bucket.find, ()),
            (bucket.open_download_stream, (1,)),
            (bucket.open_download_stream_by_name, ("name",)),
            (
                bucket.rename,
                (
                    1,
                    "new-name",
                ),
            ),
        ]

        with client.start_session() as s, s.start_transaction():
            for op, args in gridfs_ops:
                with self.assertRaisesRegex(
                    InvalidOperation,
                    "GridFS does not support multi-document transactions",
                ):
                    op(*args, session=s)  # type: ignore

    # Require 4.2+ for large (16MB+) transactions.
    @client_context.require_version_min(4, 2)
    @client_context.require_transactions
    @unittest.skipIf(sys.platform == "win32", "Our Windows machines are too slow to pass this test")
    def test_transaction_starts_with_batched_write(self):
        """A bulk write split into two inserts still runs as a single transaction."""
        if "PyPy" in sys.version and client_context.tls:
            self.skipTest(
                "PYTHON-2937 PyPy is so slow sending large "
                "messages over TLS that this test fails"
            )
        # Start a transaction with a batch of operations that needs to be
        # split.
        listener = OvertCommandListener()
        client = rs_client(event_listeners=[listener])
        # Register cleanups before touching the server so they run even if
        # the setup below fails. Cleanups run LIFO: drop first, then close.
        self.addCleanup(client.close)
        coll = client[self.db.name].test
        self.addCleanup(coll.drop)
        coll.delete_many({})
        listener.reset()
        large_str = "\0" * (1 * 1024 * 1024)
        ops = [InsertOne(RawBSONDocument(encode({"a": large_str}))) for _ in range(48)]
        with client.start_session() as session:
            with session.start_transaction():
                coll.bulk_write(ops, session=session)
        # Assert commands were constructed properly.
        self.assertEqual(
            ["insert", "insert", "commitTransaction"], listener.started_command_names()
        )
        first_cmd = listener.results["started"][0].command
        self.assertTrue(first_cmd["startTransaction"])
        lsid = first_cmd["lsid"]
        txn_number = first_cmd["txnNumber"]
        # Only the first command may start the transaction; the rest must
        # reuse the same session id and transaction number.
        for event in listener.results["started"][1:]:
            self.assertNotIn("startTransaction", event.command)
            self.assertEqual(lsid, event.command["lsid"])
            self.assertEqual(txn_number, event.command["txnNumber"])
        self.assertEqual(48, coll.count_documents({}))

    @client_context.require_transactions
    def test_transaction_direct_connection(self):
        """Reads and writes run inside a transaction on a Single-topology client."""
        client = single_client()
        self.addCleanup(client.close)
        coll = client.pymongo_test.test

        # Make sure the collection exists.
        coll.insert_one({})
        self.assertEqual(client.topology_description.topology_type_name, "Single")
        # Each entry is (callable, positional args); session= is added below.
        ops = [
            (coll.bulk_write, [[InsertOne({})]]),
            (coll.insert_one, [{}]),
            (coll.insert_many, [[{}, {}]]),
            (coll.replace_one, [{}, {}]),
            (coll.update_one, [{}, {"$set": {"a": 1}}]),
            (coll.update_many, [{}, {"$set": {"a": 1}}]),
            (coll.delete_one, [{}]),
            (coll.delete_many, [{}]),
            (coll.find_one_and_replace, [{}, {}]),
            (coll.find_one_and_update, [{}, {"$set": {"a": 1}}]),
            (coll.find_one_and_delete, [{}, {}]),
            (coll.find_one, [{}]),
            (coll.count_documents, [{}]),
            (coll.distinct, ["foo"]),
            (coll.aggregate, [[]]),
            (coll.find, [{}]),
            (coll.aggregate_raw_batches, [[]]),
            (coll.find_raw_batches, [{}]),
            (coll.database.command, ["find", coll.name]),
        ]
        for f, args in ops:
            with client.start_session() as s, s.start_transaction():
                res = f(*args, session=s)
                # Exhaust cursors while the transaction is still open.
                if isinstance(res, (CommandCursor, Cursor)):
                    list(res)


class PatchSessionTimeout:
    """Patches the client_session's with_transaction timeout for testing."""

    def __init__(self, mock_timeout):
        """Record the current module-level limit and the value to install."""
        self.real_timeout = client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT
        self.mock_timeout = mock_timeout

    def __enter__(self):
        """Install the mock timeout and return this patcher."""
        client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.mock_timeout
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the limit recorded at construction; exceptions propagate."""
        client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout


class TestTransactionsConvenientAPI(TransactionsBase):
    """Tests for ``with_transaction`` plus generated convenient-API spec tests."""

    TEST_PATH = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "transactions-convenient-api"
    )

    @client_context.require_transactions
    def test_callback_raises_custom_error(self):
        """An exception raised by the callback propagates out of with_transaction."""

        class _MyException(Exception):
            pass

        def raise_error(_):
            raise _MyException()

        with self.client.start_session() as s:
            with self.assertRaises(_MyException):
                s.with_transaction(raise_error)

    @client_context.require_transactions
    def test_callback_returns_value(self):
        """with_transaction returns whatever the callback returns."""

        def callback(_):
            return "Foo"

        with self.client.start_session() as s:
            self.assertEqual(s.with_transaction(callback), "Foo")

        self.db.test.insert_one({})

        def callback2(session):
            self.db.test.insert_one({}, session=session)
            return "Foo"

        with self.client.start_session() as s:
            self.assertEqual(s.with_transaction(callback2), "Foo")

    @client_context.require_transactions
    def test_callback_not_retried_after_timeout(self):
        """With a zero time limit, a transient callback error is not retried."""
        listener = OvertCommandListener()
        client = rs_client(event_listeners=[listener])
        self.addCleanup(client.close)
        coll = client[self.db.name].test

        def callback(session):
            coll.insert_one({}, session=session)
            # Raise an error carrying the TransientTransactionError label.
            err: dict = {
                "ok": 0,
                "errmsg": "Transaction 7819 has been aborted.",
                "code": 251,
                "codeName": "NoSuchTransaction",
                "errorLabels": ["TransientTransactionError"],
            }
            raise OperationFailure(err["errmsg"], err["code"], err)

        # Create the collection.
        coll.insert_one({})
        listener.results.clear()
        with client.start_session() as s:
            with PatchSessionTimeout(0):
                with self.assertRaises(OperationFailure):
                    s.with_transaction(callback)

        self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])

    @client_context.require_test_commands
    @client_context.require_transactions
    def test_callback_not_retried_after_commit_timeout(self):
        """With a zero time limit, a failed commit does not re-run the callback."""
        listener = OvertCommandListener()
        client = rs_client(event_listeners=[listener])
        self.addCleanup(client.close)
        coll = client[self.db.name].test

        def callback(session):
            coll.insert_one({}, session=session)

        # Create the collection.
        coll.insert_one({})
        self.set_fail_point(
            {
                "configureFailPoint": "failCommand",
                "mode": {"times": 1},
                "data": {
                    "failCommands": ["commitTransaction"],
                    "errorCode": 251,  # NoSuchTransaction
                },
            }
        )
        self.addCleanup(self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"})
        listener.results.clear()

        with client.start_session() as s:
            with PatchSessionTimeout(0):
                with self.assertRaises(OperationFailure):
                    s.with_transaction(callback)

        self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])

    @client_context.require_test_commands
    @client_context.require_transactions
    def test_commit_not_retried_after_timeout(self):
        """With a zero time limit, a failing commit is sent exactly twice."""
        listener = OvertCommandListener()
        client = rs_client(event_listeners=[listener])
        self.addCleanup(client.close)
        coll = client[self.db.name].test

        def callback(session):
            coll.insert_one({}, session=session)

        # Create the collection.
        coll.insert_one({})
        self.set_fail_point(
            {
                "configureFailPoint": "failCommand",
                "mode": {"times": 2},
                "data": {"failCommands": ["commitTransaction"], "closeConnection": True},
            }
        )
        self.addCleanup(self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"})
        listener.results.clear()

        with client.start_session() as s:
            with PatchSessionTimeout(0):
                with self.assertRaises(ConnectionFailure):
                    s.with_transaction(callback)

        # One insert for the callback and two commits (includes the automatic
        # retry).
        self.assertEqual(
            listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
        )

    # Tested here because this supports Motor's convenient transactions API.
    @client_context.require_transactions
    def test_in_transaction_property(self):
        """``in_transaction`` is True only between start and commit/abort."""
        client = client_context.client
        coll = client.test.testcollection
        coll.insert_one({})
        self.addCleanup(coll.drop)

        with client.start_session() as s:
            self.assertFalse(s.in_transaction)
            s.start_transaction()
            self.assertTrue(s.in_transaction)
            coll.insert_one({}, session=s)
            self.assertTrue(s.in_transaction)
            s.commit_transaction()
            self.assertFalse(s.in_transaction)

        with client.start_session() as s:
            s.start_transaction()
            # commit empty transaction
            s.commit_transaction()
            self.assertFalse(s.in_transaction)

        with client.start_session() as s:
            s.start_transaction()
            s.abort_transaction()
            self.assertFalse(s.in_transaction)

        # Using a callback
        def callback(session):
            self.assertTrue(session.in_transaction)

        with client.start_session() as s:
            self.assertFalse(s.in_transaction)
            s.with_transaction(callback)
            self.assertFalse(s.in_transaction)


def create_test(scenario_def, test, name):
    """Build a test method that runs one spec scenario.

    ``name`` is accepted but not used by this factory.
    """

    @client_context.require_test_commands
    @client_context.require_transactions
    def run_scenario(self):
        self.run_scenario(scenario_def, test)

    return run_scenario


test_creator = TestCreator(create_test, TestTransactions, TEST_PATH)
test_creator.create_tests()


TestCreator(
    create_test, TestTransactionsConvenientAPI, TestTransactionsConvenientAPI.TEST_PATH
).create_tests()


if __name__ == "__main__":
    unittest.main()