685 lines
26 KiB
Python
685 lines
26 KiB
Python
# Copyright 2019-present MongoDB, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Utilities for testing driver specs."""
|
|
from __future__ import annotations
|
|
|
|
import functools
|
|
import threading
|
|
from collections import abc
|
|
from test import IntegrationTest, client_context, client_knobs
|
|
from test.utils import (
|
|
CMAPListener,
|
|
CompareType,
|
|
EventListener,
|
|
OvertCommandListener,
|
|
ServerAndTopologyEventListener,
|
|
camel_to_snake,
|
|
camel_to_snake_args,
|
|
parse_spec_options,
|
|
prepare_spec_arguments,
|
|
rs_client,
|
|
)
|
|
from typing import List
|
|
|
|
from bson import ObjectId, decode, encode
|
|
from bson.binary import Binary
|
|
from bson.int64 import Int64
|
|
from bson.son import SON
|
|
from gridfs import GridFSBucket
|
|
from pymongo import client_session
|
|
from pymongo.command_cursor import CommandCursor
|
|
from pymongo.cursor import Cursor
|
|
from pymongo.errors import BulkWriteError, OperationFailure, PyMongoError
|
|
from pymongo.read_concern import ReadConcern
|
|
from pymongo.read_preferences import ReadPreference
|
|
from pymongo.results import BulkWriteResult, _WriteResult
|
|
from pymongo.write_concern import WriteConcern
|
|
|
|
|
|
class SpecRunnerThread(threading.Thread):
|
|
def __init__(self, name):
|
|
super().__init__()
|
|
self.name = name
|
|
self.exc = None
|
|
self.daemon = True
|
|
self.cond = threading.Condition()
|
|
self.ops = []
|
|
self.stopped = False
|
|
|
|
def schedule(self, work):
|
|
self.ops.append(work)
|
|
with self.cond:
|
|
self.cond.notify()
|
|
|
|
def stop(self):
|
|
self.stopped = True
|
|
with self.cond:
|
|
self.cond.notify()
|
|
|
|
def run(self):
|
|
while not self.stopped or self.ops:
|
|
if not self.ops:
|
|
with self.cond:
|
|
self.cond.wait(10)
|
|
if self.ops:
|
|
try:
|
|
work = self.ops.pop(0)
|
|
work()
|
|
except Exception as exc:
|
|
self.exc = exc
|
|
self.stop()
|
|
|
|
|
|
class SpecRunner(IntegrationTest):
|
|
mongos_clients: List
|
|
knobs: client_knobs
|
|
listener: EventListener
|
|
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
super().setUpClass()
|
|
cls.mongos_clients = []
|
|
|
|
# Speed up the tests by decreasing the heartbeat frequency.
|
|
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
|
cls.knobs.enable()
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
cls.knobs.disable()
|
|
super().tearDownClass()
|
|
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.targets = {}
|
|
self.listener = None # type: ignore
|
|
self.pool_listener = None
|
|
self.server_listener = None
|
|
self.maxDiff = None
|
|
|
|
def _set_fail_point(self, client, command_args):
|
|
cmd = SON([("configureFailPoint", "failCommand")])
|
|
cmd.update(command_args)
|
|
client.admin.command(cmd)
|
|
|
|
def set_fail_point(self, command_args):
|
|
clients = self.mongos_clients if self.mongos_clients else [self.client]
|
|
for client in clients:
|
|
self._set_fail_point(client, command_args)
|
|
|
|
def targeted_fail_point(self, session, fail_point):
|
|
"""Run the targetedFailPoint test operation.
|
|
|
|
Enable the fail point on the session's pinned mongos.
|
|
"""
|
|
clients = {c.address: c for c in self.mongos_clients}
|
|
client = clients[session._pinned_address]
|
|
self._set_fail_point(client, fail_point)
|
|
self.addCleanup(self.set_fail_point, {"mode": "off"})
|
|
|
|
def assert_session_pinned(self, session):
|
|
"""Run the assertSessionPinned test operation.
|
|
|
|
Assert that the given session is pinned.
|
|
"""
|
|
self.assertIsNotNone(session._transaction.pinned_address)
|
|
|
|
def assert_session_unpinned(self, session):
|
|
"""Run the assertSessionUnpinned test operation.
|
|
|
|
Assert that the given session is not pinned.
|
|
"""
|
|
self.assertIsNone(session._pinned_address)
|
|
self.assertIsNone(session._transaction.pinned_address)
|
|
|
|
def assert_collection_exists(self, database, collection):
|
|
"""Run the assertCollectionExists test operation."""
|
|
db = self.client[database]
|
|
self.assertIn(collection, db.list_collection_names())
|
|
|
|
def assert_collection_not_exists(self, database, collection):
|
|
"""Run the assertCollectionNotExists test operation."""
|
|
db = self.client[database]
|
|
self.assertNotIn(collection, db.list_collection_names())
|
|
|
|
def assert_index_exists(self, database, collection, index):
|
|
"""Run the assertIndexExists test operation."""
|
|
coll = self.client[database][collection]
|
|
self.assertIn(index, [doc["name"] for doc in coll.list_indexes()])
|
|
|
|
def assert_index_not_exists(self, database, collection, index):
|
|
"""Run the assertIndexNotExists test operation."""
|
|
coll = self.client[database][collection]
|
|
self.assertNotIn(index, [doc["name"] for doc in coll.list_indexes()])
|
|
|
|
def assertErrorLabelsContain(self, exc, expected_labels):
|
|
labels = [l for l in expected_labels if exc.has_error_label(l)]
|
|
self.assertEqual(labels, expected_labels)
|
|
|
|
def assertErrorLabelsOmit(self, exc, omit_labels):
|
|
for label in omit_labels:
|
|
self.assertFalse(
|
|
exc.has_error_label(label), msg=f"error labels should not contain {label}"
|
|
)
|
|
|
|
def kill_all_sessions(self):
|
|
clients = self.mongos_clients if self.mongos_clients else [self.client]
|
|
for client in clients:
|
|
try:
|
|
client.admin.command("killAllSessions", [])
|
|
except OperationFailure:
|
|
# "operation was interrupted" by killing the command's
|
|
# own session.
|
|
pass
|
|
|
|
def check_command_result(self, expected_result, result):
|
|
# Only compare the keys in the expected result.
|
|
filtered_result = {}
|
|
for key in expected_result:
|
|
try:
|
|
filtered_result[key] = result[key]
|
|
except KeyError:
|
|
pass
|
|
self.assertEqual(filtered_result, expected_result)
|
|
|
|
# TODO: factor the following function with test_crud.py.
|
|
def check_result(self, expected_result, result):
|
|
if isinstance(result, _WriteResult):
|
|
for res in expected_result:
|
|
prop = camel_to_snake(res)
|
|
# SPEC-869: Only BulkWriteResult has upserted_count.
|
|
if prop == "upserted_count" and not isinstance(result, BulkWriteResult):
|
|
if result.upserted_id is not None:
|
|
upserted_count = 1
|
|
else:
|
|
upserted_count = 0
|
|
self.assertEqual(upserted_count, expected_result[res], prop)
|
|
elif prop == "inserted_ids":
|
|
# BulkWriteResult does not have inserted_ids.
|
|
if isinstance(result, BulkWriteResult):
|
|
self.assertEqual(len(expected_result[res]), result.inserted_count)
|
|
else:
|
|
# InsertManyResult may be compared to [id1] from the
|
|
# crud spec or {"0": id1} from the retryable write spec.
|
|
ids = expected_result[res]
|
|
if isinstance(ids, dict):
|
|
ids = [ids[str(i)] for i in range(len(ids))]
|
|
|
|
self.assertEqual(ids, result.inserted_ids, prop)
|
|
elif prop == "upserted_ids":
|
|
# Convert indexes from strings to integers.
|
|
ids = expected_result[res]
|
|
expected_ids = {}
|
|
for str_index in ids:
|
|
expected_ids[int(str_index)] = ids[str_index]
|
|
self.assertEqual(expected_ids, result.upserted_ids, prop)
|
|
else:
|
|
self.assertEqual(getattr(result, prop), expected_result[res], prop)
|
|
|
|
return True
|
|
else:
|
|
|
|
def _helper(expected_result, result):
|
|
if isinstance(expected_result, abc.Mapping):
|
|
for i in expected_result.keys():
|
|
self.assertEqual(expected_result[i], result[i])
|
|
|
|
elif isinstance(expected_result, list):
|
|
for i, k in zip(expected_result, result):
|
|
_helper(i, k)
|
|
else:
|
|
self.assertEqual(expected_result, result)
|
|
|
|
_helper(expected_result, result)
|
|
return None
|
|
|
|
def get_object_name(self, op):
|
|
"""Allow subclasses to override handling of 'object'
|
|
|
|
Transaction spec says 'object' is required.
|
|
"""
|
|
return op["object"]
|
|
|
|
@staticmethod
|
|
def parse_options(opts):
|
|
return parse_spec_options(opts)
|
|
|
|
def run_operation(self, sessions, collection, operation):
|
|
original_collection = collection
|
|
name = camel_to_snake(operation["name"])
|
|
if name == "run_command":
|
|
name = "command"
|
|
elif name == "download_by_name":
|
|
name = "open_download_stream_by_name"
|
|
elif name == "download":
|
|
name = "open_download_stream"
|
|
elif name == "map_reduce":
|
|
self.skipTest("PyMongo does not support mapReduce")
|
|
elif name == "count":
|
|
self.skipTest("PyMongo does not support count")
|
|
|
|
database = collection.database
|
|
collection = database.get_collection(collection.name)
|
|
if "collectionOptions" in operation:
|
|
collection = collection.with_options(
|
|
**self.parse_options(operation["collectionOptions"])
|
|
)
|
|
|
|
object_name = self.get_object_name(operation)
|
|
if object_name == "gridfsbucket":
|
|
# Only create the GridFSBucket when we need it (for the gridfs
|
|
# retryable reads tests).
|
|
obj = GridFSBucket(database, bucket_name=collection.name)
|
|
else:
|
|
objects = {
|
|
"client": database.client,
|
|
"database": database,
|
|
"collection": collection,
|
|
"testRunner": self,
|
|
}
|
|
objects.update(sessions)
|
|
obj = objects[object_name]
|
|
|
|
# Combine arguments with options and handle special cases.
|
|
arguments = operation.get("arguments", {})
|
|
arguments.update(arguments.pop("options", {}))
|
|
self.parse_options(arguments)
|
|
|
|
cmd = getattr(obj, name)
|
|
|
|
with_txn_callback = functools.partial(
|
|
self.run_operations, sessions, original_collection, in_with_transaction=True
|
|
)
|
|
prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback)
|
|
|
|
if name == "run_on_thread":
|
|
args = {"sessions": sessions, "collection": collection}
|
|
args.update(arguments)
|
|
arguments = args
|
|
|
|
result = cmd(**dict(arguments))
|
|
# Cleanup open change stream cursors.
|
|
if name == "watch":
|
|
self.addCleanup(result.close)
|
|
|
|
if name == "aggregate":
|
|
if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
|
|
# Read from the primary to ensure causal consistency.
|
|
out = collection.database.get_collection(
|
|
arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY
|
|
)
|
|
return out.find()
|
|
if "download" in name:
|
|
result = Binary(result.read())
|
|
|
|
if isinstance(result, Cursor) or isinstance(result, CommandCursor):
|
|
return list(result)
|
|
|
|
return result
|
|
|
|
def allowable_errors(self, op):
|
|
"""Allow encryption spec to override expected error classes."""
|
|
return (PyMongoError,)
|
|
|
|
def _run_op(self, sessions, collection, op, in_with_transaction):
|
|
expected_result = op.get("result")
|
|
if expect_error(op):
|
|
with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context:
|
|
self.run_operation(sessions, collection, op.copy())
|
|
exc = context.exception
|
|
if expect_error_message(expected_result):
|
|
if isinstance(exc, BulkWriteError):
|
|
errmsg = str(exc.details).lower()
|
|
else:
|
|
errmsg = str(exc).lower()
|
|
self.assertIn(expected_result["errorContains"].lower(), errmsg)
|
|
if expect_error_code(expected_result):
|
|
self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName"))
|
|
if expect_error_labels_contain(expected_result):
|
|
self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"])
|
|
if expect_error_labels_omit(expected_result):
|
|
self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"])
|
|
if expect_timeout_error(expected_result):
|
|
self.assertIsInstance(exc, PyMongoError)
|
|
if not exc.timeout:
|
|
# Re-raise the exception for better diagnostics.
|
|
raise exc
|
|
|
|
# Reraise the exception if we're in the with_transaction
|
|
# callback.
|
|
if in_with_transaction:
|
|
raise context.exception
|
|
else:
|
|
result = self.run_operation(sessions, collection, op.copy())
|
|
if "result" in op:
|
|
if op["name"] == "runCommand":
|
|
self.check_command_result(expected_result, result)
|
|
else:
|
|
self.check_result(expected_result, result)
|
|
|
|
def run_operations(self, sessions, collection, ops, in_with_transaction=False):
|
|
for op in ops:
|
|
self._run_op(sessions, collection, op, in_with_transaction)
|
|
|
|
# TODO: factor with test_command_monitoring.py
|
|
def check_events(self, test, listener, session_ids):
|
|
events = listener.started_events
|
|
if not len(test["expectations"]):
|
|
return
|
|
|
|
# Give a nicer message when there are missing or extra events
|
|
cmds = decode_raw([event.command for event in events])
|
|
self.assertEqual(len(events), len(test["expectations"]), cmds)
|
|
for i, expectation in enumerate(test["expectations"]):
|
|
event_type = next(iter(expectation))
|
|
event = events[i]
|
|
|
|
# The tests substitute 42 for any number other than 0.
|
|
if event.command_name == "getMore" and event.command["getMore"]:
|
|
event.command["getMore"] = Int64(42)
|
|
elif event.command_name == "killCursors":
|
|
event.command["cursors"] = [Int64(42)]
|
|
elif event.command_name == "update":
|
|
# TODO: remove this once PYTHON-1744 is done.
|
|
# Add upsert and multi fields back into expectations.
|
|
updates = expectation[event_type]["command"]["updates"]
|
|
for update in updates:
|
|
update.setdefault("upsert", False)
|
|
update.setdefault("multi", False)
|
|
|
|
# Replace afterClusterTime: 42 with actual afterClusterTime.
|
|
expected_cmd = expectation[event_type]["command"]
|
|
expected_read_concern = expected_cmd.get("readConcern")
|
|
if expected_read_concern is not None:
|
|
time = expected_read_concern.get("afterClusterTime")
|
|
if time == 42:
|
|
actual_time = event.command.get("readConcern", {}).get("afterClusterTime")
|
|
if actual_time is not None:
|
|
expected_read_concern["afterClusterTime"] = actual_time
|
|
|
|
recovery_token = expected_cmd.get("recoveryToken")
|
|
if recovery_token == 42:
|
|
expected_cmd["recoveryToken"] = CompareType(dict)
|
|
|
|
# Replace lsid with a name like "session0" to match test.
|
|
if "lsid" in event.command:
|
|
for name, lsid in session_ids.items():
|
|
if event.command["lsid"] == lsid:
|
|
event.command["lsid"] = name
|
|
break
|
|
|
|
for attr, expected in expectation[event_type].items():
|
|
actual = getattr(event, attr)
|
|
expected = wrap_types(expected)
|
|
if isinstance(expected, dict):
|
|
for key, val in expected.items():
|
|
if val is None:
|
|
if key in actual:
|
|
self.fail(f"Unexpected key [{key}] in {actual!r}")
|
|
elif key not in actual:
|
|
self.fail(f"Expected key [{key}] in {actual!r}")
|
|
else:
|
|
self.assertEqual(
|
|
val, decode_raw(actual[key]), f"Key [{key}] in {actual}"
|
|
)
|
|
else:
|
|
self.assertEqual(actual, expected)
|
|
|
|
def maybe_skip_scenario(self, test):
|
|
if test.get("skipReason"):
|
|
self.skipTest(test.get("skipReason"))
|
|
|
|
def get_scenario_db_name(self, scenario_def):
|
|
"""Allow subclasses to override a test's database name."""
|
|
return scenario_def["database_name"]
|
|
|
|
def get_scenario_coll_name(self, scenario_def):
|
|
"""Allow subclasses to override a test's collection name."""
|
|
return scenario_def["collection_name"]
|
|
|
|
def get_outcome_coll_name(self, outcome, collection):
|
|
"""Allow subclasses to override outcome collection."""
|
|
return collection.name
|
|
|
|
def run_test_ops(self, sessions, collection, test):
|
|
"""Added to allow retryable writes spec to override a test's
|
|
operation.
|
|
"""
|
|
self.run_operations(sessions, collection, test["operations"])
|
|
|
|
def parse_client_options(self, opts):
|
|
"""Allow encryption spec to override a clientOptions parsing."""
|
|
# Convert test['clientOptions'] to dict to avoid a Jython bug using
|
|
# "**" with ScenarioDict.
|
|
return dict(opts)
|
|
|
|
def setup_scenario(self, scenario_def):
|
|
"""Allow specs to override a test's setup."""
|
|
db_name = self.get_scenario_db_name(scenario_def)
|
|
coll_name = self.get_scenario_coll_name(scenario_def)
|
|
documents = scenario_def["data"]
|
|
|
|
# Setup the collection with as few majority writes as possible.
|
|
db = client_context.client.get_database(db_name)
|
|
coll_exists = bool(db.list_collection_names(filter={"name": coll_name}))
|
|
if coll_exists:
|
|
db[coll_name].delete_many({})
|
|
# Only use majority wc only on the final write.
|
|
wc = WriteConcern(w="majority")
|
|
if documents:
|
|
db.get_collection(coll_name, write_concern=wc).insert_many(documents)
|
|
elif not coll_exists:
|
|
# Ensure collection exists.
|
|
db.create_collection(coll_name, write_concern=wc)
|
|
|
|
def run_scenario(self, scenario_def, test):
|
|
self.maybe_skip_scenario(test)
|
|
|
|
# Kill all sessions before and after each test to prevent an open
|
|
# transaction (from a test failure) from blocking collection/database
|
|
# operations during test set up and tear down.
|
|
self.kill_all_sessions()
|
|
self.addCleanup(self.kill_all_sessions)
|
|
self.setup_scenario(scenario_def)
|
|
database_name = self.get_scenario_db_name(scenario_def)
|
|
collection_name = self.get_scenario_coll_name(scenario_def)
|
|
# SPEC-1245 workaround StaleDbVersion on distinct
|
|
for c in self.mongos_clients:
|
|
c[database_name][collection_name].distinct("x")
|
|
|
|
# Configure the fail point before creating the client.
|
|
if "failPoint" in test:
|
|
fp = test["failPoint"]
|
|
self.set_fail_point(fp)
|
|
self.addCleanup(
|
|
self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"}
|
|
)
|
|
|
|
listener = OvertCommandListener()
|
|
pool_listener = CMAPListener()
|
|
server_listener = ServerAndTopologyEventListener()
|
|
# Create a new client, to avoid interference from pooled sessions.
|
|
client_options = self.parse_client_options(test["clientOptions"])
|
|
# MMAPv1 does not support retryable writes.
|
|
if client_options.get("retryWrites") is True and client_context.storage_engine == "mmapv1":
|
|
self.skipTest("MMAPv1 does not support retryWrites=True")
|
|
use_multi_mongos = test["useMultipleMongoses"]
|
|
host = None
|
|
if use_multi_mongos:
|
|
if client_context.load_balancer or client_context.serverless:
|
|
host = client_context.MULTI_MONGOS_LB_URI
|
|
elif client_context.is_mongos:
|
|
host = client_context.mongos_seeds()
|
|
client = rs_client(
|
|
h=host, event_listeners=[listener, pool_listener, server_listener], **client_options
|
|
)
|
|
self.scenario_client = client
|
|
self.listener = listener
|
|
self.pool_listener = pool_listener
|
|
self.server_listener = server_listener
|
|
# Close the client explicitly to avoid having too many threads open.
|
|
self.addCleanup(client.close)
|
|
|
|
# Create session0 and session1.
|
|
sessions = {}
|
|
session_ids = {}
|
|
for i in range(2):
|
|
# Don't attempt to create sessions if they are not supported by
|
|
# the running server version.
|
|
if not client_context.sessions_enabled:
|
|
break
|
|
session_name = "session%d" % i
|
|
opts = camel_to_snake_args(test["sessionOptions"][session_name])
|
|
if "default_transaction_options" in opts:
|
|
txn_opts = self.parse_options(opts["default_transaction_options"])
|
|
txn_opts = client_session.TransactionOptions(**txn_opts)
|
|
opts["default_transaction_options"] = txn_opts
|
|
|
|
s = client.start_session(**dict(opts))
|
|
|
|
sessions[session_name] = s
|
|
# Store lsid so we can access it after end_session, in check_events.
|
|
session_ids[session_name] = s.session_id
|
|
|
|
self.addCleanup(end_sessions, sessions)
|
|
|
|
collection = client[database_name][collection_name]
|
|
self.run_test_ops(sessions, collection, test)
|
|
|
|
end_sessions(sessions)
|
|
|
|
self.check_events(test, listener, session_ids)
|
|
|
|
# Disable fail points.
|
|
if "failPoint" in test:
|
|
fp = test["failPoint"]
|
|
self.set_fail_point({"configureFailPoint": fp["configureFailPoint"], "mode": "off"})
|
|
|
|
# Assert final state is expected.
|
|
outcome = test["outcome"]
|
|
expected_c = outcome.get("collection")
|
|
if expected_c is not None:
|
|
outcome_coll_name = self.get_outcome_coll_name(outcome, collection)
|
|
|
|
# Read from the primary with local read concern to ensure causal
|
|
# consistency.
|
|
outcome_coll = client_context.client[collection.database.name].get_collection(
|
|
outcome_coll_name,
|
|
read_preference=ReadPreference.PRIMARY,
|
|
read_concern=ReadConcern("local"),
|
|
)
|
|
actual_data = list(outcome_coll.find(sort=[("_id", 1)]))
|
|
|
|
# The expected data needs to be the left hand side here otherwise
|
|
# CompareType(Binary) doesn't work.
|
|
self.assertEqual(wrap_types(expected_c["data"]), actual_data)
|
|
|
|
|
|
def expect_any_error(op):
|
|
if isinstance(op, dict):
|
|
return op.get("error")
|
|
|
|
return False
|
|
|
|
|
|
def expect_error_message(expected_result):
|
|
if isinstance(expected_result, dict):
|
|
return isinstance(expected_result["errorContains"], str)
|
|
|
|
return False
|
|
|
|
|
|
def expect_error_code(expected_result):
|
|
if isinstance(expected_result, dict):
|
|
return expected_result["errorCodeName"]
|
|
|
|
return False
|
|
|
|
|
|
def expect_error_labels_contain(expected_result):
|
|
if isinstance(expected_result, dict):
|
|
return expected_result["errorLabelsContain"]
|
|
|
|
return False
|
|
|
|
|
|
def expect_error_labels_omit(expected_result):
|
|
if isinstance(expected_result, dict):
|
|
return expected_result["errorLabelsOmit"]
|
|
|
|
return False
|
|
|
|
|
|
def expect_timeout_error(expected_result):
|
|
if isinstance(expected_result, dict):
|
|
return expected_result["isTimeoutError"]
|
|
|
|
return False
|
|
|
|
|
|
def expect_error(op):
|
|
expected_result = op.get("result")
|
|
return (
|
|
expect_any_error(op)
|
|
or expect_error_message(expected_result)
|
|
or expect_error_code(expected_result)
|
|
or expect_error_labels_contain(expected_result)
|
|
or expect_error_labels_omit(expected_result)
|
|
or expect_timeout_error(expected_result)
|
|
)
|
|
|
|
|
|
def end_sessions(sessions):
|
|
for s in sessions.values():
|
|
# Aborts the transaction if it's open.
|
|
s.end_session()
|
|
|
|
|
|
def decode_raw(val):
|
|
"""Decode RawBSONDocuments in the given container."""
|
|
if isinstance(val, (list, abc.Mapping)):
|
|
return decode(encode({"v": val}))["v"]
|
|
return val
|
|
|
|
|
|
TYPES = {
|
|
"binData": Binary,
|
|
"long": Int64,
|
|
"int": int,
|
|
"string": str,
|
|
"objectId": ObjectId,
|
|
"object": dict,
|
|
"array": list,
|
|
}
|
|
|
|
|
|
def wrap_types(val):
|
|
"""Support $$type assertion in command results."""
|
|
if isinstance(val, list):
|
|
return [wrap_types(v) for v in val]
|
|
if isinstance(val, abc.Mapping):
|
|
typ = val.get("$$type")
|
|
if typ:
|
|
if isinstance(typ, str):
|
|
types = TYPES[typ]
|
|
else:
|
|
types = tuple(TYPES[t] for t in typ)
|
|
return CompareType(types)
|
|
d = {}
|
|
for key in val:
|
|
d[key] = wrap_types(val[key])
|
|
return d
|
|
return val
|