680 lines
25 KiB
Python
680 lines
25 KiB
Python
# Copyright 2024-present MongoDB, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Shared utility functions and constants for the unified test format runner.
|
|
|
|
https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.md
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import binascii
|
|
import collections
|
|
import datetime
|
|
import os
|
|
import time
|
|
import types
|
|
from collections import abc
|
|
from test.helpers import (
|
|
AWS_CREDS,
|
|
AWS_CREDS_2,
|
|
AZURE_CREDS,
|
|
CA_PEM,
|
|
CLIENT_PEM,
|
|
GCP_CREDS,
|
|
KMIP_CREDS,
|
|
LOCAL_MASTER_KEY,
|
|
)
|
|
from test.utils_shared import CMAPListener, camel_to_snake, parse_collection_options
|
|
from typing import Any, Union
|
|
|
|
from bson import (
|
|
RE_TYPE,
|
|
Binary,
|
|
Code,
|
|
DBRef,
|
|
Decimal128,
|
|
Int64,
|
|
MaxKey,
|
|
MinKey,
|
|
ObjectId,
|
|
Regex,
|
|
json_util,
|
|
)
|
|
from pymongo.monitoring import (
|
|
_SENSITIVE_COMMANDS,
|
|
CommandFailedEvent,
|
|
CommandListener,
|
|
CommandStartedEvent,
|
|
CommandSucceededEvent,
|
|
ConnectionCheckedInEvent,
|
|
ConnectionCheckedOutEvent,
|
|
ConnectionCheckOutFailedEvent,
|
|
ConnectionCheckOutStartedEvent,
|
|
ConnectionClosedEvent,
|
|
ConnectionCreatedEvent,
|
|
ConnectionReadyEvent,
|
|
PoolClearedEvent,
|
|
PoolClosedEvent,
|
|
PoolCreatedEvent,
|
|
PoolReadyEvent,
|
|
ServerClosedEvent,
|
|
ServerDescriptionChangedEvent,
|
|
ServerHeartbeatFailedEvent,
|
|
ServerHeartbeatListener,
|
|
ServerHeartbeatStartedEvent,
|
|
ServerHeartbeatSucceededEvent,
|
|
ServerListener,
|
|
ServerOpeningEvent,
|
|
TopologyClosedEvent,
|
|
TopologyDescriptionChangedEvent,
|
|
TopologyEvent,
|
|
TopologyListener,
|
|
TopologyOpenedEvent,
|
|
_CommandEvent,
|
|
_ConnectionEvent,
|
|
_PoolEvent,
|
|
_ServerEvent,
|
|
_ServerHeartbeatEvent,
|
|
)
|
|
from pymongo.results import BulkWriteResult
|
|
from pymongo.server_description import ServerDescription
|
|
from pymongo.topology_description import TopologyDescription
|
|
|
|
# Value of the SKIP_CSOT_TESTS environment variable, or None when it is unset.
SKIP_CSOT_TESTS = os.environ.get("SKIP_CSOT_TESTS")

# Extended JSON options for the runner: datetimes are decoded as naive
# (tz_aware=False).
JSON_OPTS = json_util.JSONOptions(tz_aware=False)

# Module-level flag; starts out False.
IS_INTERRUPTED = False

# TLS options for the KMIP KMS provider, built from the shared test CA file
# and client certificate/key file.
KMS_TLS_OPTS = {"kmip": dict(tlsCAFile=CA_PEM, tlsCertificateKeyFile=CLIENT_PEM)}
|
|
|
|
|
|
# Build the placeholder map: each "/clientEncryptionOpts/kmsProviders/<provider>/<field>"
# path maps to that field of the provider's test credentials. A comprehension is
# used so the loop variables do not leak into (and get star-exported from) the
# module namespace.
PLACEHOLDER_MAP = {
    f"/clientEncryptionOpts/kmsProviders/{provider_name}/{key}": value
    for provider_name, provider_data in [
        ("local", {"key": LOCAL_MASTER_KEY}),
        ("local:name1", {"key": LOCAL_MASTER_KEY}),
        ("aws", AWS_CREDS),
        ("aws:name1", AWS_CREDS),
        ("aws:name2", AWS_CREDS_2),
        ("azure", AZURE_CREDS),
        ("azure:name1", AZURE_CREDS),
        ("gcp", GCP_CREDS),
        ("gcp:name1", GCP_CREDS),
        ("kmip", KMIP_CREDS),
        ("kmip:name1", KMIP_CREDS),
    ]
    for key, value in provider_data.items()
}

# The authMechanismProperties placeholder depends on the OIDC environment
# (OIDC_ENV, default "test"). For "azure" and "gcp" the TOKEN_RESOURCE is read
# from a required environment variable (a missing variable raises KeyError at
# import time). Any other OIDC_ENV value leaves the placeholder unset.
OIDC_ENV = os.environ.get("OIDC_ENV", "test")
if OIDC_ENV == "test":
    PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"ENVIRONMENT": "test"}
elif OIDC_ENV == "azure":
    PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {
        "ENVIRONMENT": "azure",
        "TOKEN_RESOURCE": os.environ["AZUREOIDC_RESOURCE"],
    }
elif OIDC_ENV == "gcp":
    PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {
        "ENVIRONMENT": "gcp",
        "TOKEN_RESOURCE": os.environ["GCPOIDC_AUDIENCE"],
    }
elif OIDC_ENV == "k8s":
    PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"ENVIRONMENT": "k8s"}
|
|
|
|
|
|
def with_metaclass(meta, *bases):
    """Return a temporary base class that gives its subclass the metaclass ``meta``.

    A class statement ``class C(with_metaclass(M, B1, B2))`` produces a class
    ``C`` whose metaclass is ``M`` and whose bases are ``(B1, B2)``.

    Vendored from six: https://github.com/benjaminp/six/blob/master/six.py
    """

    # The returned class is built with a throwaway metaclass. When the user's
    # class statement subclasses it, that throwaway metaclass builds the real
    # class with ``meta`` and the real ``bases``. The temporary class therefore
    # never appears in the final class's bases.
    class metaclass(type):
        def __new__(mcs, name, _temporary_bases, namespace):
            real_bases = types.resolve_bases(bases)
            if real_bases is not bases:
                # __orig_bases__ is required by PEP 560.
                namespace["__orig_bases__"] = bases
            return meta(name, real_bases, namespace)

        @classmethod
        def __prepare__(mcs, name, _temporary_bases):
            # Let the real metaclass create the class namespace.
            return meta.__prepare__(name, bases)

    return type.__new__(metaclass, "temporary_class", (), {})
|
|
|
|
|
|
def parse_collection_or_database_options(options):
    """Parse the options of a collection or database entity.

    Database options are handled by the same parser as collection options.
    """
    parsed_options = parse_collection_options(options)
    return parsed_options
|
|
|
|
|
|
def parse_bulk_write_result(result):
    """Convert a ``BulkWriteResult`` into the unified test format's result document.

    The keys of ``upsertedIds`` are the operation indexes rendered as strings.
    """
    upserted_by_index = {str(index): _id for index, _id in result.upserted_ids.items()}
    parsed = {}
    parsed["deletedCount"] = result.deleted_count
    parsed["insertedCount"] = result.inserted_count
    parsed["matchedCount"] = result.matched_count
    parsed["modifiedCount"] = result.modified_count
    parsed["upsertedCount"] = result.upserted_count
    parsed["upsertedIds"] = upserted_by_index
    return parsed
|
|
|
|
|
|
def parse_client_bulk_write_individual(op_type, result):
    """Convert one verbose client bulkWrite per-operation result into spec format.

    :param op_type: ``"insert"``, ``"update"`` or ``"delete"``.
    :param result: the per-operation result object for that operation.
    :return: a dict using the unified test format's field names, or ``None``
        for any other ``op_type``.
    """
    if op_type == "insert":
        return {"insertedId": result.inserted_id}
    if op_type == "update":
        parsed = {
            "matchedCount": result.matched_count,
            "modifiedCount": result.modified_count,
        }
        # Compare against None rather than testing truthiness: a falsy _id
        # (e.g. 0 or "") is still a real upserted id and must be reported.
        if result.upserted_id is not None:
            parsed["upsertedId"] = result.upserted_id
        return parsed
    if op_type == "delete":
        return {"deletedCount": result.deleted_count}
    return None
|
|
|
|
|
|
def parse_client_bulk_write_result(result):
    """Convert a client bulkWrite result into the unified test format's result document.

    Per-operation results are included only when ``result.has_verbose_results``
    is true. Each per-operation section is a separate dict keyed by the
    operation index as a string; otherwise each section is its own empty dict.
    """

    def _by_index(op_type, per_op_results):
        # Index (as a string) -> parsed per-operation result.
        return {
            str(index): parse_client_bulk_write_individual(op_type, op_result)
            for index, op_result in per_op_results.items()
        }

    if result.has_verbose_results:
        insert_results = _by_index("insert", result.insert_results)
        update_results = _by_index("update", result.update_results)
        delete_results = _by_index("delete", result.delete_results)
    else:
        # Three distinct dicts, so mutating one never affects the others.
        insert_results, update_results, delete_results = {}, {}, {}

    parsed = {}
    parsed["deletedCount"] = result.deleted_count
    parsed["insertedCount"] = result.inserted_count
    parsed["matchedCount"] = result.matched_count
    parsed["modifiedCount"] = result.modified_count
    parsed["upsertedCount"] = result.upserted_count
    parsed["insertResults"] = insert_results
    parsed["updateResults"] = update_results
    parsed["deleteResults"] = delete_results
    return parsed
|
|
|
|
|
|
def parse_bulk_write_error_result(error):
    """Convert the partial result carried by a bulk write error into spec format.

    ``error.details`` is wrapped in a ``BulkWriteResult`` (with ``True`` passed
    as the acknowledged flag) and parsed like a successful bulk write.
    """
    return parse_bulk_write_result(BulkWriteResult(error.details, True))
|
|
|
|
|
|
def parse_client_bulk_write_error_result(error):
    """Convert the partial result of a client bulkWrite error into spec format.

    Returns ``None`` when the error carries no (truthy) ``partial_result``.
    """
    partial = error.partial_result
    return parse_client_bulk_write_result(partial) if partial else None
|
|
|
|
|
|
class EventListenerUtil(
    CMAPListener, CommandListener, ServerListener, ServerHeartbeatListener, TopologyListener
):
    """Listener that records the monitoring events observed by a unified-test client.

    :param observe_events: event class names to record (case-insensitive).
    :param ignore_commands: command names whose command events are dropped
        (case-insensitive).
    :param observe_sensitive_commands: when truthy, redacted (sensitive)
        command events are recorded instead of dropped.
    :param store_events: optional list of ``{"id": ..., "events": [...]}``
        specs. Each listed event type is also appended, as a summary dict,
        to ``entity_map[id]``.
    :param entity_map: the test's entity map, written to for ``store_events``.
    """

    def __init__(
        self, observe_events, ignore_commands, observe_sensitive_commands, store_events, entity_map
    ):
        self._event_types = {name.lower() for name in observe_events}
        # _command_event compares against ``command_name.lower()`` and
        # _SENSITIVE_COMMANDS is lowercase, so normalise user-supplied names too.
        ignored = {name.lower() for name in ignore_commands}
        self._observe_sensitive_commands = bool(observe_sensitive_commands)
        if not self._observe_sensitive_commands:
            ignored |= _SENSITIVE_COMMANDS
        # configureFailPoint command events are never recorded.
        ignored.add("configurefailpoint")
        self._ignore_commands = ignored
        # Lowercased event name -> ids of the entities that store that event.
        self._event_mapping = collections.defaultdict(list)
        self.entity_map = entity_map
        for store_spec in store_events or ():
            entity_id = store_spec["id"]
            for event_name in store_spec["events"]:
                self._event_mapping[event_name.lower()].append(entity_id)
            self.entity_map[entity_id] = []
        super().__init__()

    def get_events(self, event_type):
        """Return the recorded events of one category: command, cmap, sdam or all."""
        assert event_type in ("command", "cmap", "sdam", "all"), event_type
        if event_type == "all":
            return list(self.events)
        if event_type == "command":
            return [e for e in self.events if isinstance(e, _CommandEvent)]
        if event_type == "cmap":
            return [e for e in self.events if isinstance(e, (_ConnectionEvent, _PoolEvent))]
        return [
            e
            for e in self.events
            if isinstance(e, (_ServerEvent, TopologyEvent, _ServerHeartbeatEvent))
        ]

    def add_event(self, event):
        """Record ``event`` if observed, and append a summary to any storing entities."""
        event_name = type(event).__name__.lower()
        if event_name in self._event_types:
            super().add_event(event)
        # .get() avoids inserting an empty entry for every unmapped event name.
        for entity_id in self._event_mapping.get(event_name, ()):
            self.entity_map[entity_id].append(
                {
                    "name": type(event).__name__,
                    "observedAt": time.time(),
                    "description": repr(event),
                }
            )

    def _command_event(self, event):
        # Drop events for ignored (and, by default, sensitive) commands.
        if event.command_name.lower() not in self._ignore_commands:
            self.add_event(event)

    def started(self, event):
        # Server heartbeat "started" events arrive through this same callback.
        if not isinstance(event, CommandStartedEvent):
            self.add_event(event)
            return
        # An empty command means it was redacted; record it only when
        # sensitive commands are observed.
        redacted = event.command == {}
        if not redacted or self._observe_sensitive_commands:
            self._command_event(event)

    def succeeded(self, event):
        # Server heartbeat "succeeded" events arrive through this same callback.
        if not isinstance(event, CommandSucceededEvent):
            self.add_event(event)
            return
        # An empty reply means it was redacted; record it only when
        # sensitive commands are observed.
        redacted = event.reply == {}
        if not redacted or self._observe_sensitive_commands:
            self._command_event(event)

    def failed(self, event):
        if isinstance(event, CommandFailedEvent):
            self._command_event(event)
        else:
            self.add_event(event)

    def opened(self, event: Union[ServerOpeningEvent, TopologyOpenedEvent]) -> None:
        self.add_event(event)

    def description_changed(
        self, event: Union[ServerDescriptionChangedEvent, TopologyDescriptionChangedEvent]
    ) -> None:
        self.add_event(event)

    def topology_changed(self, event: TopologyDescriptionChangedEvent) -> None:
        self.add_event(event)

    def closed(self, event: Union[ServerClosedEvent, TopologyClosedEvent]) -> None:
        self.add_event(event)
|
|
|
|
|
|
# Python types used to represent a few BSON types.
binary_types = (Binary, bytes)
long_types = (Int64,)
unicode_type = str

# BSON type alias (as used by the ``$$type`` operator) -> tuple of Python
# types an actual value may have for that alias.
# https://mongodb.com/docs/manual/reference/operator/query/type/
# https://pymongo.readthedocs.io/en/stable/api/bson/index.html
BSON_TYPE_ALIAS_MAP = dict(
    double=(float,),
    string=(str,),
    object=(abc.Mapping,),
    array=(abc.MutableSequence,),
    binData=binary_types,
    undefined=(type(None),),
    objectId=(ObjectId,),
    bool=(bool,),
    date=(datetime.datetime,),
    null=(type(None),),
    regex=(Regex, RE_TYPE),
    dbPointer=(DBRef,),
    javascript=(unicode_type, Code),
    symbol=(unicode_type,),
    javascriptWithScope=(unicode_type, Code),
    int=(int,),
    long=(Int64,),
    decimal=(Decimal128,),
    maxKey=(MaxKey,),
    minKey=(MinKey,),
)
|
|
|
|
|
|
class MatchEvaluatorUtil:
    """Utility class that implements methods for evaluating matches as per
    the unified test format specification.

    ``test_class`` is the running test case. Its ``assert*``/``fail`` methods
    report mismatches, and its ``entity_map`` resolves entity references.
    """

    def __init__(self, test_class):
        self.test = test_class

    @staticmethod
    def _field_value(actual, key_to_compare):
        """Return ``actual[key_to_compare]``, or ``actual`` itself when the key is None.

        Special operators receive ``key_to_compare=None`` when they apply to the
        whole value rather than to one field of a document.
        """
        return actual if key_to_compare is None else actual[key_to_compare]

    def _operation_exists(self, spec, actual, key_to_compare):
        """``$$exists``: assert that the value/field is present (True) or absent (False)."""
        if spec is True:
            if key_to_compare is None:
                self.test.assertIsNotNone(actual)
            else:
                self.test.assertIn(key_to_compare, actual)
        elif spec is False:
            if key_to_compare is None:
                self.test.assertIsNone(actual)
            else:
                self.test.assertNotIn(key_to_compare, actual)
        else:
            self.test.fail(f"Expected boolean value for $$exists operator, got {spec}")

    def __type_alias_to_type(self, alias):
        if alias not in BSON_TYPE_ALIAS_MAP:
            self.test.fail(f"Unrecognized BSON type alias {alias}")
        return BSON_TYPE_ALIAS_MAP[alias]

    def _operation_type(self, spec, actual, key_to_compare):
        """``$$type``: assert the value is an instance of one of the aliased BSON types."""
        if isinstance(spec, abc.MutableSequence):
            permissible_types = tuple(
                t for alias in spec for t in self.__type_alias_to_type(alias)
            )
        else:
            permissible_types = self.__type_alias_to_type(spec)
        value = self._field_value(actual, key_to_compare)
        self.test.assertIsInstance(value, permissible_types)

    def _operation_matchesEntity(self, spec, actual, key_to_compare):
        """``$$matchesEntity``: assert the value equals the named entity."""
        expected_entity = self.test.entity_map[spec]
        self.test.assertEqual(expected_entity, self._field_value(actual, key_to_compare))

    def _operation_matchesHexBytes(self, spec, actual, key_to_compare):
        """``$$matchesHexBytes``: assert the value equals the bytes encoded by hex ``spec``."""
        expected = binascii.unhexlify(spec)
        value = self._field_value(actual, key_to_compare)
        self.test.assertEqual(value, expected)

    def _operation_unsetOrMatches(self, spec, actual, key_to_compare):
        """``$$unsetOrMatches``: the value may be unset; if present it must match ``spec``."""
        if key_to_compare is None:
            # The operand applies to the whole value (e.g. a top-level
            # expectResult). An unset/empty value matches trivially; otherwise
            # the value must match ``spec`` as a root-level document.
            if not actual:
                return
            self.match_result(spec, actual)
            return

        if key_to_compare not in actual:
            # we add a dummy value for the compared key to pass map size check
            actual[key_to_compare] = "dummyValue"
            return
        self.match_result(spec, actual[key_to_compare], in_recursive_call=True)

    def _operation_sessionLsid(self, spec, actual, key_to_compare):
        """``$$sessionLsid``: assert the value equals the lsid of the named session entity."""
        expected_lsid = self.test.entity_map.get_lsid_for_session(spec)
        self.test.assertEqual(expected_lsid, self._field_value(actual, key_to_compare))

    def _operation_lte(self, spec, actual, key_to_compare):
        """``$$lte``: assert the value is less than or equal to ``spec``."""
        if key_to_compare is not None and key_to_compare not in actual:
            self.test.fail(f"Actual command is missing the {key_to_compare} field: {spec}")
        self.test.assertLessEqual(self._field_value(actual, key_to_compare), spec)

    def _operation_matchAsDocument(self, spec, actual, key_to_compare):
        """``$$matchAsDocument``: parse the JSON string field, then match it as a non-root document."""
        self._match_document(spec, json_util.loads(actual[key_to_compare]), False, test=True)

    def _operation_matchAsRoot(self, spec, actual, key_to_compare):
        """``$$matchAsRoot``: match the value as a root document (extra fields allowed)."""
        actual = self._field_value(actual, key_to_compare)
        self._match_document(spec, actual, True, test=True)

    def _evaluate_special_operation(self, opname, spec, actual, key_to_compare):
        # "$$name" is dispatched to the method "_operation_name".
        method_name = "_operation_{}".format(opname.strip("$"))
        try:
            method = getattr(self, method_name)
        except AttributeError:
            self.test.fail(f"Unsupported special matching operator {opname}")
        else:
            method(spec, actual, key_to_compare)

    def _evaluate_if_special_operation(self, expectation, actual, key_to_compare=None):
        """Returns True if a special operation is evaluated, False
        otherwise. If the ``expectation`` map contains a single key,
        value pair we check it for a special operation.
        If given, ``key_to_compare`` is assumed to be the key in
        ``expectation`` whose corresponding value needs to be
        evaluated for a possible special operation. ``key_to_compare``
        is ignored when ``expectation`` has only one key.
        """
        if not isinstance(expectation, abc.Mapping):
            return False

        is_special_op, opname, spec = False, False, False

        if key_to_compare is not None:
            if key_to_compare.startswith("$$"):
                is_special_op = True
                opname = key_to_compare
                spec = expectation[key_to_compare]
                key_to_compare = None
            else:
                nested = expectation[key_to_compare]
                if isinstance(nested, abc.Mapping) and len(nested) == 1:
                    opname, spec = next(iter(nested.items()))
                    if opname.startswith("$$"):
                        is_special_op = True
        elif len(expectation) == 1:
            opname, spec = next(iter(expectation.items()))
            if opname.startswith("$$"):
                is_special_op = True
                key_to_compare = None

        if is_special_op:
            self._evaluate_special_operation(
                opname=opname, spec=spec, actual=actual, key_to_compare=key_to_compare
            )
            return True

        return False

    def _match_document(self, expectation, actual, is_root, test=False):
        """Match a document, allowing extra actual fields only when ``is_root``."""
        if self._evaluate_if_special_operation(expectation, actual):
            return True

        self.test.assertIsInstance(actual, abc.Mapping)
        for key, value in expectation.items():
            if self._evaluate_if_special_operation(expectation, actual, key):
                continue

            self.test.assertIn(key, actual)
            if not self.match_result(value, actual[key], in_recursive_call=True, test=test):
                return False

        if not is_root:
            # Fields expected to be absent do not count towards the key set.
            expected_keys = set(expectation.keys())
            for key, value in expectation.items():
                if value == {"$$exists": False}:
                    expected_keys.remove(key)
            if test:
                self.test.assertEqual(expected_keys, set(actual.keys()))
            else:
                return set(expected_keys).issubset(set(actual.keys()))
        return True

    def match_result(self, expectation, actual, in_recursive_call=False, test=True):
        """Match ``actual`` against ``expectation``.

        When ``test`` is true, mismatches are reported through the test case's
        assertions; otherwise some mismatches are reported by returning False.
        """
        if isinstance(expectation, abc.Mapping):
            return self._match_document(
                expectation, actual, is_root=not in_recursive_call, test=test
            )

        if isinstance(expectation, abc.MutableSequence):
            self.test.assertIsInstance(actual, abc.MutableSequence)
            # Arrays must agree in length as well as element-wise; zip() alone
            # would silently ignore missing or extra actual elements.
            if test:
                self.test.assertEqual(len(expectation), len(actual))
            elif len(expectation) != len(actual):
                return False
            for e, a in zip(expectation, actual):
                if isinstance(e, abc.Mapping):
                    res = self._match_document(e, a, is_root=not in_recursive_call, test=test)
                else:
                    res = self.match_result(e, a, in_recursive_call=True, test=test)
                if not res:
                    return False
            return True

        # account for flexible numerics in element-wise comparison
        if isinstance(expectation, (int, float)):
            if test:
                self.test.assertEqual(expectation, actual)
            else:
                return expectation == actual
        else:
            if test:
                self.test.assertIsInstance(actual, type(expectation))
                self.test.assertEqual(expectation, actual)
            else:
                return isinstance(actual, type(expectation)) and expectation == actual
        return True

    def match_server_description(self, actual: ServerDescription, spec: dict) -> None:
        """Assert each camelCase field in ``spec`` equals the matching ServerDescription attribute."""
        for field, expected in spec.items():
            field = camel_to_snake(field)
            if field == "type":
                field = "server_type_name"
            self.test.assertEqual(getattr(actual, field), expected)

    def match_topology_description(self, actual: TopologyDescription, spec: dict) -> None:
        """Assert each camelCase field in ``spec`` equals the matching TopologyDescription attribute."""
        for field, expected in spec.items():
            field = camel_to_snake(field)
            if field == "type":
                field = "topology_type_name"
            self.test.assertEqual(getattr(actual, field), expected)

    def match_event_fields(self, actual: Any, spec: dict) -> None:
        """Assert the fields of an event expectation against an actual event."""
        for field, expected in spec.items():
            if field == "command" and isinstance(actual, CommandStartedEvent):
                # An empty expectation places no constraint on the command.
                if expected:
                    self.match_result(expected, actual.command)
                continue
            if field == "reply" and isinstance(actual, CommandSucceededEvent):
                if expected:
                    self.match_result(expected, actual.reply)
                continue
            if field == "hasServiceId":
                if expected:
                    self.test.assertIsNotNone(actual.service_id)
                    self.test.assertIsInstance(actual.service_id, ObjectId)
                else:
                    self.test.assertIsNone(actual.service_id)
                continue
            if field == "hasServerConnectionId":
                if expected:
                    self.test.assertIsNotNone(actual.server_connection_id)
                    self.test.assertIsInstance(actual.server_connection_id, int)
                else:
                    self.test.assertIsNone(actual.server_connection_id)
                continue
            if field in ("previousDescription", "newDescription"):
                if isinstance(actual, ServerDescriptionChangedEvent):
                    self.match_server_description(
                        getattr(actual, camel_to_snake(field)), expected
                    )
                    continue
                if isinstance(actual, TopologyDescriptionChangedEvent):
                    self.match_topology_description(
                        getattr(actual, camel_to_snake(field)), expected
                    )
                    continue

            if field == "interruptInUseConnections":
                field = "interrupt_connections"
            else:
                field = camel_to_snake(field)
            self.test.assertEqual(getattr(actual, field), expected)

    def match_event(self, expectation, actual):
        """Assert ``actual`` has the event type named by ``expectation`` and matches its fields."""
        name, spec = next(iter(expectation.items()))
        # Spec event name -> PyMongo event class. The spec's
        # "topologyOpeningEvent" corresponds to PyMongo's TopologyOpenedEvent.
        event_classes = {
            "commandStartedEvent": CommandStartedEvent,
            "commandSucceededEvent": CommandSucceededEvent,
            "commandFailedEvent": CommandFailedEvent,
            "poolCreatedEvent": PoolCreatedEvent,
            "poolReadyEvent": PoolReadyEvent,
            "poolClearedEvent": PoolClearedEvent,
            "poolClosedEvent": PoolClosedEvent,
            "connectionCreatedEvent": ConnectionCreatedEvent,
            "connectionReadyEvent": ConnectionReadyEvent,
            "connectionClosedEvent": ConnectionClosedEvent,
            "connectionCheckOutStartedEvent": ConnectionCheckOutStartedEvent,
            "connectionCheckOutFailedEvent": ConnectionCheckOutFailedEvent,
            "connectionCheckedOutEvent": ConnectionCheckedOutEvent,
            "connectionCheckedInEvent": ConnectionCheckedInEvent,
            "serverDescriptionChangedEvent": ServerDescriptionChangedEvent,
            "serverHeartbeatStartedEvent": ServerHeartbeatStartedEvent,
            "serverHeartbeatSucceededEvent": ServerHeartbeatSucceededEvent,
            "serverHeartbeatFailedEvent": ServerHeartbeatFailedEvent,
            "topologyDescriptionChangedEvent": TopologyDescriptionChangedEvent,
            "topologyOpeningEvent": TopologyOpenedEvent,
            "topologyClosedEvent": TopologyClosedEvent,
        }
        event_class = event_classes.get(name)
        if event_class is None:
            raise Exception(f"Unsupported event type {name}")
        self.test.assertIsInstance(actual, event_class)
        if name == "poolClearedEvent":
            self.test.assertIsInstance(actual.interrupt_connections, bool)

        self.match_event_fields(actual, spec)
|
|
|
|
|
|
def coerce_result(opname, result):
    """Convert a pymongo result into the spec's result format.

    An unacknowledged result becomes ``{"acknowledged": False}``. Results of
    the recognised write operations become dicts keyed by the spec's field
    names. Any other result is returned unchanged.
    """
    if not getattr(result, "acknowledged", True):
        return {"acknowledged": False}

    if opname == "bulkWrite":
        return parse_bulk_write_result(result)
    if opname == "clientBulkWrite":
        return parse_client_bulk_write_result(result)
    if opname == "insertOne":
        return {"insertedId": result.inserted_id}
    if opname == "insertMany":
        # Keys are the integer positions of the inserted documents.
        return {position: inserted_id for position, inserted_id in enumerate(result.inserted_ids)}
    if opname in ("deleteOne", "deleteMany"):
        return {"deletedCount": result.deleted_count}
    if opname in ("updateOne", "updateMany", "replaceOne"):
        upserted_id = result.upserted_id
        was_upsert = upserted_id is not None
        coerced = {
            "matchedCount": result.matched_count,
            "modifiedCount": result.modified_count,
            "upsertedCount": 1 if was_upsert else 0,
        }
        if was_upsert:
            coerced["upsertedId"] = upserted_id
        return coerced

    return result
|