# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for testing driver specs.""" from __future__ import annotations import asyncio import functools import os import time import unittest from asyncio import iscoroutinefunction from collections import abc from test import IntegrationTest, client_context, client_knobs from test.helpers import ConcurrentRunner from test.utils_shared import ( CMAPListener, CompareType, EventListener, OvertCommandListener, ScenarioDict, ServerAndTopologyEventListener, camel_to_snake, camel_to_snake_args, parse_spec_options, prepare_spec_arguments, ) from typing import List from bson import ObjectId, decode, encode, json_util from bson.binary import Binary from bson.int64 import Int64 from bson.son import SON from gridfs import GridFSBucket from gridfs.synchronous.grid_file import GridFSBucket from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError from pymongo.lock import _cond_wait, _create_condition, _create_lock from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.results import BulkWriteResult, _WriteResult from pymongo.synchronous import client_session from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.cursor import Cursor from pymongo.write_concern import WriteConcern _IS_SYNC = True class SpecRunnerThread(ConcurrentRunner): def __init__(self, name): super().__init__(name=name) self.exc = None self.daemon = True self.cond = _create_condition(_create_lock()) self.ops = [] def schedule(self, work): self.ops.append(work) with self.cond: self.cond.notify() def stop(self): self.stopped = True with self.cond: self.cond.notify() def run(self): while not self.stopped or self.ops: if not self.ops: with self.cond: _cond_wait(self.cond, 10) if self.ops: try: work = self.ops.pop(0) work() except Exception as exc: self.exc = exc self.stop() class SpecTestCreator: """Class to create test cases from specifications.""" def __init__(self, create_test, test_class, test_path): """Create a TestCreator object. :Parameters: - `create_test`: callback that returns a test case. The callback must accept the following arguments - a dictionary containing the entire test specification (the `scenario_def`), a dictionary containing the specification for which the test case will be generated (the `test_def`). - `test_class`: the unittest.TestCase class in which to create the test case. - `test_path`: path to the directory containing the JSON files with the test specifications. """ self._create_test = create_test self._test_class = test_class self.test_path = test_path def _ensure_min_max_server_version(self, scenario_def, method): """Test modifier that enforces a version range for the server on a test case. """ if "minServerVersion" in scenario_def: min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) if min_ver is not None: method = client_context.require_version_min(*min_ver)(method) if "maxServerVersion" in scenario_def: max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) if max_ver is not None: method = client_context.require_version_max(*max_ver)(method) return method @staticmethod def valid_topology(run_on_req): return client_context.is_topology_type( run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) ) @staticmethod def min_server_version(run_on_req): version = run_on_req.get("minServerVersion") if version: min_ver = tuple(int(elt) for elt in version.split(".")) return client_context.version >= min_ver return True @staticmethod def max_server_version(run_on_req): version = run_on_req.get("maxServerVersion") if version: max_ver = tuple(int(elt) for elt in version.split(".")) return client_context.version <= max_ver return True @staticmethod def valid_auth_enabled(run_on_req): if "authEnabled" in run_on_req: if run_on_req["authEnabled"]: return client_context.auth_enabled return not client_context.auth_enabled return True def should_run_on(self, scenario_def): run_on = scenario_def.get("runOn", []) if not run_on: # Always run these tests. return True for req in run_on: if ( self.valid_topology(req) and self.min_server_version(req) and self.max_server_version(req) and self.valid_auth_enabled(req) ): return True return False def ensure_run_on(self, scenario_def, method): """Test modifier that enforces a 'runOn' on a test case.""" def predicate(): return self.should_run_on(scenario_def) return client_context._require(predicate, "runOn not satisfied", method) def tests(self, scenario_def): """Allow CMAP spec test to override the location of test.""" return scenario_def["tests"] def _create_tests(self): for dirpath, _, filenames in os.walk(self.test_path): dirname = os.path.split(dirpath)[-1] for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100 # Use tz_aware=False to match how CodecOptions decodes # dates. opts = json_util.JSONOptions(tz_aware=False) scenario_def = ScenarioDict( json_util.loads(scenario_stream.read(), json_options=opts) ) test_type = os.path.splitext(filename)[0] # Construct test from scenario. for test_def in self.tests(scenario_def): test_name = "test_{}_{}_{}".format( dirname, test_type.replace("-", "_").replace(".", "_"), str(test_def["description"].replace(" ", "_").replace(".", "_")), ) new_test = self._create_test(scenario_def, test_def, test_name) new_test = self._ensure_min_max_server_version(scenario_def, new_test) new_test = self.ensure_run_on(scenario_def, new_test) new_test.__name__ = test_name setattr(self._test_class, new_test.__name__, new_test) def create_tests(self): if _IS_SYNC: self._create_tests() else: asyncio.run(self._create_tests()) class SpecRunner(IntegrationTest): mongos_clients: List knobs: client_knobs listener: EventListener def setUp(self) -> None: super().setUp() self.mongos_clients = [] # Speed up the tests by decreasing the heartbeat frequency. self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) self.knobs.enable() self.targets = {} self.listener = None # type: ignore self.pool_listener = None self.server_listener = None self.maxDiff = None def tearDown(self) -> None: self.knobs.disable() def set_fail_point(self, command_args): clients = self.mongos_clients if self.mongos_clients else [self.client] for client in clients: self.configure_fail_point(client, command_args) def targeted_fail_point(self, session, fail_point): """Run the targetedFailPoint test operation. Enable the fail point on the session's pinned mongos. """ clients = {c.address: c for c in self.mongos_clients} client = clients[session._pinned_address] self.configure_fail_point(client, fail_point) self.addCleanup(self.set_fail_point, {"mode": "off"}) def assert_session_pinned(self, session): """Run the assertSessionPinned test operation. Assert that the given session is pinned. """ self.assertIsNotNone(session._transaction.pinned_address) def assert_session_unpinned(self, session): """Run the assertSessionUnpinned test operation. Assert that the given session is not pinned. """ self.assertIsNone(session._pinned_address) self.assertIsNone(session._transaction.pinned_address) def assert_collection_exists(self, database, collection): """Run the assertCollectionExists test operation.""" db = self.client[database] self.assertIn(collection, db.list_collection_names()) def assert_collection_not_exists(self, database, collection): """Run the assertCollectionNotExists test operation.""" db = self.client[database] self.assertNotIn(collection, db.list_collection_names()) def assert_index_exists(self, database, collection, index): """Run the assertIndexExists test operation.""" coll = self.client[database][collection] self.assertIn(index, [doc["name"] for doc in coll.list_indexes()]) def assert_index_not_exists(self, database, collection, index): """Run the assertIndexNotExists test operation.""" coll = self.client[database][collection] self.assertNotIn(index, [doc["name"] for doc in coll.list_indexes()]) def wait(self, ms): """Run the "wait" test operation.""" time.sleep(ms / 1000.0) def assertErrorLabelsContain(self, exc, expected_labels): labels = [l for l in expected_labels if exc.has_error_label(l)] self.assertEqual(labels, expected_labels) def assertErrorLabelsOmit(self, exc, omit_labels): for label in omit_labels: self.assertFalse( exc.has_error_label(label), msg=f"error labels should not contain {label}" ) def kill_all_sessions(self): clients = self.mongos_clients if self.mongos_clients else [self.client] for client in clients: try: client.admin.command("killAllSessions", []) except (OperationFailure, AutoReconnect): # "operation was interrupted" by killing the command's # own session. # On 8.0+ killAllSessions sometimes returns a network error. pass def check_command_result(self, expected_result, result): # Only compare the keys in the expected result. filtered_result = {} for key in expected_result: try: filtered_result[key] = result[key] except KeyError: pass self.assertEqual(filtered_result, expected_result) # TODO: factor the following function with test_crud.py. def check_result(self, expected_result, result): if isinstance(result, _WriteResult): for res in expected_result: prop = camel_to_snake(res) # SPEC-869: Only BulkWriteResult has upserted_count. if prop == "upserted_count" and not isinstance(result, BulkWriteResult): if result.upserted_id is not None: upserted_count = 1 else: upserted_count = 0 self.assertEqual(upserted_count, expected_result[res], prop) elif prop == "inserted_ids": # BulkWriteResult does not have inserted_ids. if isinstance(result, BulkWriteResult): self.assertEqual(len(expected_result[res]), result.inserted_count) else: # InsertManyResult may be compared to [id1] from the # crud spec or {"0": id1} from the retryable write spec. ids = expected_result[res] if isinstance(ids, dict): ids = [ids[str(i)] for i in range(len(ids))] self.assertEqual(ids, result.inserted_ids, prop) elif prop == "upserted_ids": # Convert indexes from strings to integers. ids = expected_result[res] expected_ids = {} for str_index in ids: expected_ids[int(str_index)] = ids[str_index] self.assertEqual(expected_ids, result.upserted_ids, prop) else: self.assertEqual(getattr(result, prop), expected_result[res], prop) return True else: def _helper(expected_result, result): if isinstance(expected_result, abc.Mapping): for i in expected_result.keys(): self.assertEqual(expected_result[i], result[i]) elif isinstance(expected_result, list): for i, k in zip(expected_result, result): _helper(i, k) else: self.assertEqual(expected_result, result) _helper(expected_result, result) return None def get_object_name(self, op): """Allow subclasses to override handling of 'object' Transaction spec says 'object' is required. """ return op["object"] @staticmethod def parse_options(opts): return parse_spec_options(opts) def run_operation(self, sessions, collection, operation): original_collection = collection name = camel_to_snake(operation["name"]) if name == "run_command": name = "command" elif name == "download_by_name": name = "open_download_stream_by_name" elif name == "download": name = "open_download_stream" elif name == "map_reduce": self.skipTest("PyMongo does not support mapReduce") elif name == "count": self.skipTest("PyMongo does not support count") database = collection.database collection = database.get_collection(collection.name) if "collectionOptions" in operation: collection = collection.with_options( **self.parse_options(operation["collectionOptions"]) ) object_name = self.get_object_name(operation) if object_name == "gridfsbucket": # Only create the GridFSBucket when we need it (for the gridfs # retryable reads tests). obj = GridFSBucket(database, bucket_name=collection.name) else: objects = { "client": database.client, "database": database, "collection": collection, "testRunner": self, } objects.update(sessions) obj = objects[object_name] # Combine arguments with options and handle special cases. arguments = operation.get("arguments", {}) arguments.update(arguments.pop("options", {})) self.parse_options(arguments) cmd = getattr(obj, name) with_txn_callback = functools.partial( self.run_operations, sessions, original_collection, in_with_transaction=True ) prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback) if name == "run_on_thread": args = {"sessions": sessions, "collection": collection} args.update(arguments) arguments = args if not _IS_SYNC and iscoroutinefunction(cmd): result = cmd(**dict(arguments)) else: result = cmd(**dict(arguments)) # Cleanup open change stream cursors. if name == "watch": self.addCleanup(result.close) if name == "aggregate": if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]: # Read from the primary to ensure causal consistency. out = collection.database.get_collection( arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY ) return out.find() if "download" in name: result = Binary(result.read()) if isinstance(result, Cursor) or isinstance(result, CommandCursor): return result.to_list() return result def allowable_errors(self, op): """Allow encryption spec to override expected error classes.""" return (PyMongoError,) def _run_op(self, sessions, collection, op, in_with_transaction): expected_result = op.get("result") if expect_error(op): with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context: self.run_operation(sessions, collection, op.copy()) exc = context.exception if expect_error_message(expected_result): if isinstance(exc, BulkWriteError): errmsg = str(exc.details).lower() else: errmsg = str(exc).lower() self.assertIn(expected_result["errorContains"].lower(), errmsg) if expect_error_code(expected_result): self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName")) if expect_error_labels_contain(expected_result): self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"]) if expect_error_labels_omit(expected_result): self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"]) if expect_timeout_error(expected_result): self.assertIsInstance(exc, PyMongoError) if not exc.timeout: # Re-raise the exception for better diagnostics. raise exc # Reraise the exception if we're in the with_transaction # callback. if in_with_transaction: raise context.exception else: result = self.run_operation(sessions, collection, op.copy()) if "result" in op: if op["name"] == "runCommand": self.check_command_result(expected_result, result) else: self.check_result(expected_result, result) def run_operations(self, sessions, collection, ops, in_with_transaction=False): for op in ops: self._run_op(sessions, collection, op, in_with_transaction) # TODO: factor with test_command_monitoring.py def check_events(self, test, listener, session_ids): events = listener.started_events if not len(test["expectations"]): return # Give a nicer message when there are missing or extra events cmds = decode_raw([event.command for event in events]) self.assertEqual(len(events), len(test["expectations"]), cmds) for i, expectation in enumerate(test["expectations"]): event_type = next(iter(expectation)) event = events[i] # The tests substitute 42 for any number other than 0. if event.command_name == "getMore" and event.command["getMore"]: event.command["getMore"] = Int64(42) elif event.command_name == "killCursors": event.command["cursors"] = [Int64(42)] elif event.command_name == "update": # TODO: remove this once PYTHON-1744 is done. # Add upsert and multi fields back into expectations. updates = expectation[event_type]["command"]["updates"] for update in updates: update.setdefault("upsert", False) update.setdefault("multi", False) # Replace afterClusterTime: 42 with actual afterClusterTime. expected_cmd = expectation[event_type]["command"] expected_read_concern = expected_cmd.get("readConcern") if expected_read_concern is not None: time = expected_read_concern.get("afterClusterTime") if time == 42: actual_time = event.command.get("readConcern", {}).get("afterClusterTime") if actual_time is not None: expected_read_concern["afterClusterTime"] = actual_time recovery_token = expected_cmd.get("recoveryToken") if recovery_token == 42: expected_cmd["recoveryToken"] = CompareType(dict) # Replace lsid with a name like "session0" to match test. if "lsid" in event.command: for name, lsid in session_ids.items(): if event.command["lsid"] == lsid: event.command["lsid"] = name break for attr, expected in expectation[event_type].items(): actual = getattr(event, attr) expected = wrap_types(expected) if isinstance(expected, dict): for key, val in expected.items(): if val is None: if key in actual: self.fail(f"Unexpected key [{key}] in {actual!r}") elif key not in actual: self.fail(f"Expected key [{key}] in {actual!r}") else: self.assertEqual( val, decode_raw(actual[key]), f"Key [{key}] in {actual}" ) else: self.assertEqual(actual, expected) def maybe_skip_scenario(self, test): if test.get("skipReason"): self.skipTest(test.get("skipReason")) def get_scenario_db_name(self, scenario_def): """Allow subclasses to override a test's database name.""" return scenario_def["database_name"] def get_scenario_coll_name(self, scenario_def): """Allow subclasses to override a test's collection name.""" return scenario_def["collection_name"] def get_outcome_coll_name(self, outcome, collection): """Allow subclasses to override outcome collection.""" return collection.name def run_test_ops(self, sessions, collection, test): """Added to allow retryable writes spec to override a test's operation. """ self.run_operations(sessions, collection, test["operations"]) def parse_client_options(self, opts): """Allow encryption spec to override a clientOptions parsing.""" # Convert test['clientOptions'] to dict to avoid a Jython bug using # "**" with ScenarioDict. return dict(opts) def setup_scenario(self, scenario_def): """Allow specs to override a test's setup.""" db_name = self.get_scenario_db_name(scenario_def) coll_name = self.get_scenario_coll_name(scenario_def) documents = scenario_def["data"] # Setup the collection with as few majority writes as possible. db = client_context.client.get_database(db_name) coll_exists = bool(db.list_collection_names(filter={"name": coll_name})) if coll_exists: db[coll_name].delete_many({}) # Only use majority wc only on the final write. wc = WriteConcern(w="majority") if documents: db.get_collection(coll_name, write_concern=wc).insert_many(documents) elif not coll_exists: # Ensure collection exists. db.create_collection(coll_name, write_concern=wc) def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) # Kill all sessions before and after each test to prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. self.kill_all_sessions() self.addCleanup(self.kill_all_sessions) self.setup_scenario(scenario_def) database_name = self.get_scenario_db_name(scenario_def) collection_name = self.get_scenario_coll_name(scenario_def) # SPEC-1245 workaround StaleDbVersion on distinct for c in self.mongos_clients: c[database_name][collection_name].distinct("x") # Configure the fail point before creating the client. if "failPoint" in test: fp = test["failPoint"] self.set_fail_point(fp) self.addCleanup( self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} ) listener = OvertCommandListener() pool_listener = CMAPListener() server_listener = ServerAndTopologyEventListener() # Create a new client, to avoid interference from pooled sessions. client_options = self.parse_client_options(test["clientOptions"]) # MMAPv1 does not support retryable writes. if client_options.get("retryWrites") is True and client_context.storage_engine == "mmapv1": self.skipTest("MMAPv1 does not support retryWrites=True") use_multi_mongos = test["useMultipleMongoses"] host = None if use_multi_mongos: if client_context.load_balancer: host = client_context.MULTI_MONGOS_LB_URI elif client_context.is_mongos: host = client_context.mongos_seeds() client = self.rs_client( h=host, event_listeners=[listener, pool_listener, server_listener], **client_options ) self.scenario_client = client self.listener = listener self.pool_listener = pool_listener self.server_listener = server_listener # Create session0 and session1. sessions = {} session_ids = {} for i in range(2): # Don't attempt to create sessions if they are not supported by # the running server version. if not client_context.sessions_enabled: break session_name = "session%d" % i opts = camel_to_snake_args(test["sessionOptions"][session_name]) if "default_transaction_options" in opts: txn_opts = self.parse_options(opts["default_transaction_options"]) txn_opts = client_session.TransactionOptions(**txn_opts) opts["default_transaction_options"] = txn_opts s = client.start_session(**dict(opts)) sessions[session_name] = s # Store lsid so we can access it after end_session, in check_events. session_ids[session_name] = s.session_id self.addCleanup(end_sessions, sessions) collection = client[database_name][collection_name] self.run_test_ops(sessions, collection, test) end_sessions(sessions) self.check_events(test, listener, session_ids) # Disable fail points. if "failPoint" in test: fp = test["failPoint"] self.set_fail_point({"configureFailPoint": fp["configureFailPoint"], "mode": "off"}) # Assert final state is expected. outcome = test["outcome"] expected_c = outcome.get("collection") if expected_c is not None: outcome_coll_name = self.get_outcome_coll_name(outcome, collection) # Read from the primary with local read concern to ensure causal # consistency. outcome_coll = client_context.client[collection.database.name].get_collection( outcome_coll_name, read_preference=ReadPreference.PRIMARY, read_concern=ReadConcern("local"), ) actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list() # The expected data needs to be the left hand side here otherwise # CompareType(Binary) doesn't work. self.assertEqual(wrap_types(expected_c["data"]), actual_data) def expect_any_error(op): if isinstance(op, dict): return op.get("error") return False def expect_error_message(expected_result): if isinstance(expected_result, dict): return isinstance(expected_result["errorContains"], str) return False def expect_error_code(expected_result): if isinstance(expected_result, dict): return expected_result["errorCodeName"] return False def expect_error_labels_contain(expected_result): if isinstance(expected_result, dict): return expected_result["errorLabelsContain"] return False def expect_error_labels_omit(expected_result): if isinstance(expected_result, dict): return expected_result["errorLabelsOmit"] return False def expect_timeout_error(expected_result): if isinstance(expected_result, dict): return expected_result["isTimeoutError"] return False def expect_error(op): expected_result = op.get("result") return ( expect_any_error(op) or expect_error_message(expected_result) or expect_error_code(expected_result) or expect_error_labels_contain(expected_result) or expect_error_labels_omit(expected_result) or expect_timeout_error(expected_result) ) def end_sessions(sessions): for s in sessions.values(): # Aborts the transaction if it's open. s.end_session() def decode_raw(val): """Decode RawBSONDocuments in the given container.""" if isinstance(val, (list, abc.Mapping)): return decode(encode({"v": val}))["v"] return val TYPES = { "binData": Binary, "long": Int64, "int": int, "string": str, "objectId": ObjectId, "object": dict, "array": list, } def wrap_types(val): """Support $$type assertion in command results.""" if isinstance(val, list): return [wrap_types(v) for v in val] if isinstance(val, abc.Mapping): typ = val.get("$$type") if typ: if isinstance(typ, str): types = TYPES[typ] else: types = tuple(TYPES[t] for t in typ) return CompareType(types) d = {} for key in val: d[key] = wrap_types(val[key]) return d return val