diff --git a/.evergreen/resync-specs.sh b/.evergreen/resync-specs.sh index ac6944972..dca116c2d 100755 --- a/.evergreen/resync-specs.sh +++ b/.evergreen/resync-specs.sh @@ -76,6 +76,9 @@ do atlas-data-lake-testing|data_lake) cpjson atlas-data-lake-testing/tests/ data_lake ;; + bson-binary-vector|bson_binary_vector) + cpjson bson-binary-vector/tests/ bson_binary_vector + ;; bson-corpus|bson_corpus) cpjson bson-corpus/tests/ bson_corpus ;; diff --git a/bson/binary.py b/bson/binary.py index 5fe1bacd1..47c52d489 100644 --- a/bson/binary.py +++ b/bson/binary.py @@ -13,7 +13,10 @@ # limitations under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Tuple, Type, Union +import struct +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Any, Sequence, Tuple, Type, Union from uuid import UUID """Tools for representing BSON binary data. @@ -191,21 +194,75 @@ SENSITIVE_SUBTYPE = 8 """ +VECTOR_SUBTYPE = 9 +"""**(BETA)** BSON binary subtype for densely packed vector data. + +.. versionadded:: 4.10 +""" + + USER_DEFINED_SUBTYPE = 128 """BSON binary subtype for any user defined structure. """ +class BinaryVectorDtype(Enum): + """**(BETA)** Datatypes of vector subtype. + + :param FLOAT32: (0x27) Pack list of :class:`float` as float32 + :param INT8: (0x03) Pack list of :class:`int` in [-128, 127] as signed int8 + :param PACKED_BIT: (0x10) Pack list of :class:`int` in [0, 255] as unsigned uint8 + + The `PACKED_BIT` value represents a special case where vector values themselves + can only be of two values (0 or 1) but these are packed together into groups of 8, + a byte. In Python, these are displayed as ints in range [0, 255] + + Each value is of type bytes with a length of one. + + .. versionadded:: 4.10 + """ + + INT8 = b"\x03" + FLOAT32 = b"\x27" + PACKED_BIT = b"\x10" + + +@dataclass +class BinaryVector: + """**(BETA)** Vector of numbers along with metadata for binary interoperability. + .. versionadded:: 4.10 + """ + + __slots__ = ("data", "dtype", "padding") + + def __init__(self, data: Sequence[float | int], dtype: BinaryVectorDtype, padding: int = 0): + """ + :param data: Sequence of numbers representing the mathematical vector. + :param dtype: The data type stored in binary + :param padding: The number of bits in the final byte that are to be ignored + when a vector element's size is less than a byte + and the length of the vector is not a multiple of 8. + """ + self.data = data + self.dtype = dtype + self.padding = padding + + class Binary(bytes): """Representation of BSON binary data. - This is necessary because we want to represent Python strings as - the BSON string type. We need to wrap binary data so we can tell + We want to represent Python strings as the BSON string type. + We need to wrap binary data so that we can tell the difference between what should be considered binary data and what should be considered a string when we encode to BSON. - Raises TypeError if `data` is not an instance of :class:`bytes` - or `subtype` is not an instance of :class:`int`. + **(BETA)** Subtype 9 provides a space-efficient representation of 1-dimensional vector data. + Its data is prepended with two bytes of metadata. + The first (dtype) describes its data type, such as float32 or int8. + The second (padding) prescribes the number of bits to ignore in the final byte. + This is relevant when the element size of the dtype is not a multiple of 8. + + Raises TypeError if `subtype` is not an instance of :class:`int`. Raises ValueError if `subtype` is not in [0, 256). .. note:: @@ -218,7 +275,10 @@ class Binary(bytes): to use .. versionchanged:: 3.9 - Support any bytes-like type that implements the buffer protocol. + Support any bytes-like type that implements the buffer protocol. + + .. versionchanged:: 4.10 + **(BETA)** Addition of vector subtype. """ _type_marker = 5 @@ -337,6 +397,86 @@ class Binary(bytes): f"cannot decode subtype {self.subtype} to {UUID_REPRESENTATION_NAMES[uuid_representation]}" ) + @classmethod + def from_vector( + cls: Type[Binary], + vector: list[int, float], + dtype: BinaryVectorDtype, + padding: int = 0, + ) -> Binary: + """**(BETA)** Create a BSON :class:`~bson.binary.Binary` of Vector subtype from a list of Numbers. + + To interpret the representation of the numbers, a data type must be included. + See :class:`~bson.binary.BinaryVectorDtype` for available types and descriptions. + + The dtype and padding are prepended to the binary data's value. + + :param vector: List of values + :param dtype: Data type of the values + :param padding: For fractional bytes, number of bits to ignore at end of vector. + :return: Binary packed data identified by dtype and padding. + + .. versionadded:: 4.10 + """ + if dtype == BinaryVectorDtype.INT8: # pack ints in [-128, 127] as signed int8 + format_str = "b" + if padding: + raise ValueError(f"padding does not apply to {dtype=}") + elif dtype == BinaryVectorDtype.PACKED_BIT: # pack ints in [0, 255] as unsigned uint8 + format_str = "B" + elif dtype == BinaryVectorDtype.FLOAT32: # pack floats as float32 + format_str = "f" + if padding: + raise ValueError(f"padding does not apply to {dtype=}") + else: + raise NotImplementedError("%s not yet supported" % dtype) + + metadata = struct.pack(" BinaryVector: + """**(BETA)** From the Binary, create a list of numbers, along with dtype and padding. + + :return: BinaryVector + + .. versionadded:: 4.10 + """ + + if self.subtype != VECTOR_SUBTYPE: + raise ValueError(f"Cannot decode subtype {self.subtype} as a vector.") + + position = 0 + dtype, padding = struct.unpack_from(" int: """Subtype of this binary data.""" diff --git a/doc/api/bson/binary.rst b/doc/api/bson/binary.rst index c933a687b..084fd02d5 100644 --- a/doc/api/bson/binary.rst +++ b/doc/api/bson/binary.rst @@ -21,6 +21,14 @@ .. autoclass:: UuidRepresentation :members: + .. autoclass:: BinaryVectorDtype + :members: + :show-inheritance: + + .. autoclass:: BinaryVector + :members: + + .. autoclass:: Binary(data, subtype=BINARY_SUBTYPE) :members: :show-inheritance: diff --git a/doc/async-tutorial.rst b/doc/async-tutorial.rst index caa277f9d..2ccf011d8 100644 --- a/doc/async-tutorial.rst +++ b/doc/async-tutorial.rst @@ -1,6 +1,11 @@ Async Tutorial ============== +.. warning:: This API is currently in beta, meaning the classes, methods, + and behaviors described within may change before the full release. + If you come across any bugs during your use of this API, + please file a Jira ticket in the "Python Driver" project at https://jira.mongodb.org/browse/PYTHON. + .. code-block:: pycon from pymongo import AsyncMongoClient diff --git a/doc/changelog.rst b/doc/changelog.rst index dfb3c7982..6c8b8261a 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,6 +1,24 @@ Changelog ========= +Changes in Version 4.10.0 +------------------------- + +- Added provisional **(BETA)** support for a new Binary BSON subtype (9) used for efficient storage and retrieval of vectors: + densely packed arrays of numbers, all of the same type. + This includes new methods :meth:`~bson.binary.Binary.from_vector` and :meth:`~bson.binary.Binary.as_vector`. +- Added C extension use to client metadata, for example: ``{"driver": {"name": "PyMongo|c", "version": "4.10.0"}, ...}`` +- Fixed a bug where :class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` could deadlock. +- Fixed a bug where PyMongo could fail to import on Windows if ``asyncio`` is misconfigured. + +Issues Resolved +............... + +See the `PyMongo 4.10 release notes in JIRA`_ for the list of resolved issues +in this release. + +.. _PyMongo 4.10 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=40553 + Changes in Version 4.9.0 ------------------------- diff --git a/pymongo/_version.py b/pymongo/_version.py index 5ff72d6cc..3de24a8e1 100644 --- a/pymongo/_version.py +++ b/pymongo/_version.py @@ -18,7 +18,7 @@ from __future__ import annotations import re from typing import List, Tuple, Union -__version__ = "4.10.0.dev0" +__version__ = "4.11.0.dev0" def get_version_tuple(version: str) -> Tuple[Union[int, str], ...]: diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py new file mode 100644 index 000000000..46f66af62 --- /dev/null +++ b/test/asynchronous/helpers.py @@ -0,0 +1,360 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared constants and helper methods for pymongo, bson, and gridfs test suites.""" +from __future__ import annotations + +import base64 +import gc +import multiprocessing +import os +import signal +import socket +import subprocess +import sys +import threading +import time +import traceback +import unittest +import warnings +from asyncio import iscoroutinefunction + +try: + import ipaddress + + HAVE_IPADDRESS = True +except ImportError: + HAVE_IPADDRESS = False +from functools import wraps +from typing import Any, Callable, Dict, Generator, no_type_check +from unittest import SkipTest + +from bson.son import SON +from pymongo import common, message +from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] +from pymongo.uri_parser import parse_uri + +if HAVE_SSL: + import ssl + +_IS_SYNC = False + +# Enable debug output for uncollectable objects. PyPy does not have set_debug. +if hasattr(gc, "set_debug"): + gc.set_debug( + gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0) + ) + +# The host and port of a single mongod or mongos, or the seed host +# for a replica set. +host = os.environ.get("DB_IP", "localhost") +port = int(os.environ.get("DB_PORT", 27017)) +IS_SRV = "mongodb+srv" in host + +db_user = os.environ.get("DB_USER", "user") +db_pwd = os.environ.get("DB_PASSWORD", "password") + +CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates") +CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem")) +CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem")) + +TLS_OPTIONS: Dict = {"tls": True} +if CLIENT_PEM: + TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM +if CA_PEM: + TLS_OPTIONS["tlsCAFile"] = CA_PEM + +COMPRESSORS = os.environ.get("COMPRESSORS") +MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION") +TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER")) +TEST_SERVERLESS = bool(os.environ.get("TEST_SERVERLESS")) +SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI") +MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI") + +if TEST_LOADBALANCER: + res = parse_uri(SINGLE_MONGOS_LB_URI or "") + host, port = res["nodelist"][0] + db_user = res["username"] or db_user + db_pwd = res["password"] or db_pwd +elif TEST_SERVERLESS: + TEST_LOADBALANCER = True + res = parse_uri(SINGLE_MONGOS_LB_URI or "") + host, port = res["nodelist"][0] + db_user = res["username"] or db_user + db_pwd = res["password"] or db_pwd + TLS_OPTIONS = {"tls": True} + # Spec says serverless tests must be run with compression. + COMPRESSORS = COMPRESSORS or "zlib" + + +# Shared KMS data. +LOCAL_MASTER_KEY = base64.b64decode( + b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ" + b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk" +) +AWS_CREDS = { + "accessKeyId": os.environ.get("FLE_AWS_KEY", ""), + "secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""), +} +AWS_CREDS_2 = { + "accessKeyId": os.environ.get("FLE_AWS_KEY2", ""), + "secretAccessKey": os.environ.get("FLE_AWS_SECRET2", ""), +} +AZURE_CREDS = { + "tenantId": os.environ.get("FLE_AZURE_TENANTID", ""), + "clientId": os.environ.get("FLE_AZURE_CLIENTID", ""), + "clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""), +} +GCP_CREDS = { + "email": os.environ.get("FLE_GCP_EMAIL", ""), + "privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""), +} +KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")} + +# Ensure Evergreen metadata doesn't result in truncation +os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000") + + +def is_server_resolvable(): + """Returns True if 'server' is resolvable.""" + socket_timeout = socket.getdefaulttimeout() + socket.setdefaulttimeout(1) + try: + try: + socket.gethostbyname("server") + return True + except OSError: + return False + finally: + socket.setdefaulttimeout(socket_timeout) + + +def _create_user(authdb, user, pwd=None, roles=None, **kwargs): + cmd = SON([("createUser", user)]) + # X509 doesn't use a password + if pwd: + cmd["pwd"] = pwd + cmd["roles"] = roles or ["root"] + cmd.update(**kwargs) + return authdb.command(cmd) + + +class client_knobs: + def __init__( + self, + heartbeat_frequency=None, + min_heartbeat_interval=None, + kill_cursor_frequency=None, + events_queue_frequency=None, + ): + self.heartbeat_frequency = heartbeat_frequency + self.min_heartbeat_interval = min_heartbeat_interval + self.kill_cursor_frequency = kill_cursor_frequency + self.events_queue_frequency = events_queue_frequency + + self.old_heartbeat_frequency = None + self.old_min_heartbeat_interval = None + self.old_kill_cursor_frequency = None + self.old_events_queue_frequency = None + self._enabled = False + self._stack = None + + def enable(self): + self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY + self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL + self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY + self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY + + if self.heartbeat_frequency is not None: + common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency + + if self.min_heartbeat_interval is not None: + common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval + + if self.kill_cursor_frequency is not None: + common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency + + if self.events_queue_frequency is not None: + common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency + self._enabled = True + # Store the allocation traceback to catch non-disabled client_knobs. + self._stack = "".join(traceback.format_stack()) + + def __enter__(self): + self.enable() + + @no_type_check + def disable(self): + common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency + common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval + common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency + common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency + self._enabled = False + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disable() + + def __call__(self, func): + def make_wrapper(f): + @wraps(f) + async def wrap(*args, **kwargs): + with self: + return await f(*args, **kwargs) + + return wrap + + return make_wrapper(func) + + def __del__(self): + if self._enabled: + msg = ( + "ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, " + "MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, " + "EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format( + common.HEARTBEAT_FREQUENCY, + common.MIN_HEARTBEAT_INTERVAL, + common.KILL_CURSOR_FREQUENCY, + common.EVENTS_QUEUE_FREQUENCY, + self._stack, + ) + ) + self.disable() + raise Exception(msg) + + +def _all_users(db): + return {u["user"] for u in db.command("usersInfo").get("users", [])} + + +def sanitize_cmd(cmd): + cp = cmd.copy() + cp.pop("$clusterTime", None) + cp.pop("$db", None) + cp.pop("$readPreference", None) + cp.pop("lsid", None) + if MONGODB_API_VERSION: + # Stable API parameters + cp.pop("apiVersion", None) + # OP_MSG encoding may move the payload type one field to the + # end of the command. Do the same here. + name = next(iter(cp)) + try: + identifier = message._FIELD_MAP[name] + docs = cp.pop(identifier) + cp[identifier] = docs + except KeyError: + pass + return cp + + +def sanitize_reply(reply): + cp = reply.copy() + cp.pop("$clusterTime", None) + cp.pop("operationTime", None) + return cp + + +def print_thread_tracebacks() -> None: + """Print all Python thread tracebacks.""" + for thread_id, frame in sys._current_frames().items(): + sys.stderr.write(f"\n--- Traceback for thread {thread_id} ---\n") + traceback.print_stack(frame, file=sys.stderr) + + +def print_thread_stacks(pid: int) -> None: + """Print all C-level thread stacks for a given process id.""" + if sys.platform == "darwin": + cmd = ["lldb", "--attach-pid", f"{pid}", "--batch", "--one-line", '"thread backtrace all"'] + else: + cmd = ["gdb", f"--pid={pid}", "--batch", '--eval-command="thread apply all bt"'] + + try: + res = subprocess.run( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8" + ) + except Exception as exc: + sys.stderr.write(f"Could not print C-level thread stacks because {cmd[0]} failed: {exc}") + else: + sys.stderr.write(res.stdout) + + +# Global knobs to speed up the test suite. +global_knobs = client_knobs(events_queue_frequency=0.05) + + +def _get_executors(topology): + executors = [] + for server in topology._servers.values(): + # Some MockMonitor do not have an _executor. + if hasattr(server._monitor, "_executor"): + executors.append(server._monitor._executor) + if hasattr(server._monitor, "_rtt_monitor"): + executors.append(server._monitor._rtt_monitor._executor) + executors.append(topology._Topology__events_executor) + if topology._srv_monitor: + executors.append(topology._srv_monitor._executor) + + return [e for e in executors if e is not None] + + +def print_running_topology(topology): + running = [e for e in _get_executors(topology) if not e._stopped] + if running: + print( + "WARNING: found Topology with running threads:\n" + f" Threads: {running}\n" + f" Topology: {topology}\n" + f" Creation traceback:\n{topology._settings._stack}" + ) + + +def test_cases(suite): + """Iterator over all TestCases within a TestSuite.""" + for suite_or_case in suite._tests: + if isinstance(suite_or_case, unittest.TestCase): + # unittest.TestCase + yield suite_or_case + else: + # unittest.TestSuite + yield from test_cases(suite_or_case) + + +# Helper method to workaround https://bugs.python.org/issue21724 +def clear_warning_registry(): + """Clear the __warningregistry__ for all modules.""" + for _, module in list(sys.modules.items()): + if hasattr(module, "__warningregistry__"): + module.__warningregistry__ = {} # type:ignore[attr-defined] + + +class SystemCertsPatcher: + def __init__(self, ca_certs): + if ( + ssl.OPENSSL_VERSION.lower().startswith("libressl") + and sys.platform == "darwin" + and not _ssl.IS_PYOPENSSL + ): + raise SkipTest( + "LibreSSL on OSX doesn't support setting CA certificates " + "using SSL_CERT_FILE environment variable." + ) + self.original_certs = os.environ.get("SSL_CERT_FILE") + # Tell OpenSSL where CA certificates live. + os.environ["SSL_CERT_FILE"] = ca_certs + + def disable(self): + if self.original_certs is None: + os.environ.pop("SSL_CERT_FILE") + else: + os.environ["SSL_CERT_FILE"] = self.original_certs diff --git a/test/asynchronous/test_retryable_reads.py b/test/asynchronous/test_retryable_reads.py new file mode 100644 index 000000000..b2d86f5d8 --- /dev/null +++ b/test/asynchronous/test_retryable_reads.py @@ -0,0 +1,191 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test retryable reads spec.""" +from __future__ import annotations + +import os +import pprint +import sys +import threading + +from pymongo.errors import AutoReconnect + +sys.path[0:0] = [""] + +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + async_client_context, + client_knobs, + unittest, +) +from test.utils import ( + CMAPListener, + OvertCommandListener, + async_set_fail_point, +) + +from pymongo.monitoring import ( + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutFailedReason, + PoolClearedEvent, +) + +_IS_SYNC = False + + +class TestClientOptions(AsyncPyMongoTestCase): + async def test_default(self): + client = self.simple_client(connect=False) + self.assertEqual(client.options.retry_reads, True) + + async def test_kwargs(self): + client = self.simple_client(retryReads=True, connect=False) + self.assertEqual(client.options.retry_reads, True) + client = self.simple_client(retryReads=False, connect=False) + self.assertEqual(client.options.retry_reads, False) + + async def test_uri(self): + client = self.simple_client("mongodb://h/?retryReads=true", connect=False) + self.assertEqual(client.options.retry_reads, True) + client = self.simple_client("mongodb://h/?retryReads=false", connect=False) + self.assertEqual(client.options.retry_reads, False) + + +class FindThread(threading.Thread): + def __init__(self, collection): + super().__init__() + self.daemon = True + self.collection = collection + self.passed = False + + async def run(self): + await self.collection.find_one({}) + self.passed = True + + +class TestPoolPausedError(AsyncIntegrationTest): + # Pools don't get paused in load balanced mode. + RUN_ON_LOAD_BALANCER = False + RUN_ON_SERVERLESS = False + + @async_client_context.require_sync + @async_client_context.require_failCommand_blockConnection + @client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05) + async def test_pool_paused_error_is_retryable(self): + if "PyPy" in sys.version: + # Tracked in PYTHON-3519 + self.skipTest("Test is flakey on PyPy") + cmap_listener = CMAPListener() + cmd_listener = OvertCommandListener() + client = await self.async_rs_or_single_client( + maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener] + ) + for _ in range(10): + cmap_listener.reset() + cmd_listener.reset() + threads = [FindThread(client.pymongo_test.test) for _ in range(2)] + fail_command = { + "mode": {"times": 1}, + "data": { + "failCommands": ["find"], + "blockConnection": True, + "blockTimeMS": 1000, + "errorCode": 91, + }, + } + async with self.fail_point(fail_command): + for thread in threads: + thread.start() + for thread in threads: + thread.join() + for thread in threads: + self.assertTrue(thread.passed) + + # It's possible that SDAM can rediscover the server and mark the + # pool ready before the thread in the wait queue has a chance + # to run. Repeat the test until the thread actually encounters + # a PoolClearedError. + if cmap_listener.event_count(ConnectionCheckOutFailedEvent): + break + + # Via CMAP monitoring, assert that the first check out succeeds. + cmap_events = cmap_listener.events_by_type( + (ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent, PoolClearedEvent) + ) + msg = pprint.pformat(cmap_listener.events) + self.assertIsInstance(cmap_events[0], ConnectionCheckedOutEvent, msg) + self.assertIsInstance(cmap_events[1], PoolClearedEvent, msg) + self.assertIsInstance(cmap_events[2], ConnectionCheckOutFailedEvent, msg) + self.assertEqual(cmap_events[2].reason, ConnectionCheckOutFailedReason.CONN_ERROR, msg) + self.assertIsInstance(cmap_events[3], ConnectionCheckedOutEvent, msg) + + # Connection check out failures are not reflected in command + # monitoring because we only publish command events _after_ checking + # out a connection. + started = cmd_listener.started_events + msg = pprint.pformat(cmd_listener.results) + self.assertEqual(3, len(started), msg) + succeeded = cmd_listener.succeeded_events + self.assertEqual(2, len(succeeded), msg) + failed = cmd_listener.failed_events + self.assertEqual(1, len(failed), msg) + + +class TestRetryableReads(AsyncIntegrationTest): + @async_client_context.require_multiple_mongoses + @async_client_context.require_failCommand_fail_point + async def test_retryable_reads_in_sharded_cluster_multiple_available(self): + fail_command = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find"], + "closeConnection": True, + "appName": "retryableReadTest", + }, + } + + mongos_clients = [] + + for mongos in async_client_context.mongos_seeds().split(","): + client = await self.async_rs_or_single_client(mongos) + await async_set_fail_point(client, fail_command) + mongos_clients.append(client) + + listener = OvertCommandListener() + client = await self.async_rs_or_single_client( + async_client_context.mongos_seeds(), + appName="retryableReadTest", + event_listeners=[listener], + retryReads=True, + ) + + async with self.fail_point(fail_command): + with self.assertRaises(AutoReconnect): + await client.t.t.find_one({}) + + # Disable failpoints on each mongos + for client in mongos_clients: + fail_command["mode"] = "off" + await async_set_fail_point(client, fail_command) + + self.assertEqual(len(listener.failed_events), 2) + self.assertEqual(len(listener.succeeded_events), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_retryable_writes.py b/test/asynchronous/test_retryable_writes.py new file mode 100644 index 000000000..accbbd003 --- /dev/null +++ b/test/asynchronous/test_retryable_writes.py @@ -0,0 +1,694 @@ +# Copyright 2017 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test retryable writes.""" +from __future__ import annotations + +import asyncio +import copy +import pprint +import sys +import threading + +sys.path[0:0] = [""] + +from test.asynchronous import ( + AsyncIntegrationTest, + SkipTest, + async_client_context, + unittest, +) +from test.asynchronous.helpers import client_knobs +from test.utils import ( + CMAPListener, + DeprecationFilter, + EventListener, + OvertCommandListener, + async_set_fail_point, +) +from test.version import Version + +from bson.codec_options import DEFAULT_CODEC_OPTIONS +from bson.int64 import Int64 +from bson.raw_bson import RawBSONDocument +from bson.son import SON +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.errors import ( + AutoReconnect, + ConnectionFailure, + OperationFailure, + ServerSelectionTimeoutError, + WriteConcernError, +) +from pymongo.monitoring import ( + CommandSucceededEvent, + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutFailedReason, + PoolClearedEvent, +) +from pymongo.operations import ( + DeleteMany, + DeleteOne, + InsertOne, + ReplaceOne, + UpdateMany, + UpdateOne, +) +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class InsertEventListener(EventListener): + def succeeded(self, event: CommandSucceededEvent) -> None: + super().succeeded(event) + if ( + event.command_name == "insert" + and event.reply.get("writeConcernError", {}).get("code", None) == 91 + ): + async_client_context.client.admin.command( + { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "errorCode": 10107, + "errorLabels": ["RetryableWriteError", "NoWritesPerformed"], + "failCommands": ["insert"], + }, + } + ) + + +def retryable_single_statement_ops(coll): + return [ + (coll.bulk_write, [[InsertOne({}), InsertOne({})]], {}), + (coll.bulk_write, [[InsertOne({}), InsertOne({})]], {"ordered": False}), + (coll.bulk_write, [[ReplaceOne({}, {"a1": 1})]], {}), + (coll.bulk_write, [[ReplaceOne({}, {"a2": 1}), ReplaceOne({}, {"a3": 1})]], {}), + ( + coll.bulk_write, + [[UpdateOne({}, {"$set": {"a4": 1}}), UpdateOne({}, {"$set": {"a5": 1}})]], + {}, + ), + (coll.bulk_write, [[DeleteOne({})]], {}), + (coll.bulk_write, [[DeleteOne({}), DeleteOne({})]], {}), + (coll.insert_one, [{}], {}), + (coll.insert_many, [[{}, {}]], {}), + (coll.replace_one, [{}, {"a6": 1}], {}), + (coll.update_one, [{}, {"$set": {"a7": 1}}], {}), + (coll.delete_one, [{}], {}), + (coll.find_one_and_replace, [{}, {"a8": 1}], {}), + (coll.find_one_and_update, [{}, {"$set": {"a9": 1}}], {}), + (coll.find_one_and_delete, [{}, {"a10": 1}], {}), + ] + + +def non_retryable_single_statement_ops(coll): + return [ + ( + coll.bulk_write, + [[UpdateOne({}, {"$set": {"a": 1}}), UpdateMany({}, {"$set": {"a": 1}})]], + {}, + ), + (coll.bulk_write, [[DeleteOne({}), DeleteMany({})]], {}), + (coll.update_many, [{}, {"$set": {"a": 1}}], {}), + (coll.delete_many, [{}], {}), + ] + + +class IgnoreDeprecationsTest(AsyncIntegrationTest): + RUN_ON_LOAD_BALANCER = True + RUN_ON_SERVERLESS = True + deprecation_filter: DeprecationFilter + + @classmethod + async def _setup_class(cls): + await super()._setup_class() + cls.deprecation_filter = DeprecationFilter() + + @classmethod + async def _tearDown_class(cls): + cls.deprecation_filter.stop() + await super()._tearDown_class() + + +class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): + knobs: client_knobs + + @classmethod + async def _setup_class(cls): + await super()._setup_class() + # Speed up the tests by decreasing the heartbeat frequency. + cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + cls.knobs.enable() + cls.client = await cls.unmanaged_async_rs_or_single_client(retryWrites=True) + cls.db = cls.client.pymongo_test + + @classmethod + async def _tearDown_class(cls): + cls.knobs.disable() + await cls.client.close() + await super()._tearDown_class() + + @async_client_context.require_no_standalone + async def test_actionable_error_message(self): + if async_client_context.storage_engine != "mmapv1": + raise SkipTest("This cluster is not running MMAPv1") + + expected_msg = ( + "This MongoDB deployment does not support retryable " + "writes. Please add retryWrites=false to your " + "connection string." + ) + for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test): + with self.assertRaisesRegex(OperationFailure, expected_msg): + await method(*args, **kwargs) + + +class TestRetryableWrites(IgnoreDeprecationsTest): + listener: OvertCommandListener + knobs: client_knobs + + @classmethod + @async_client_context.require_no_mmap + async def _setup_class(cls): + await super()._setup_class() + # Speed up the tests by decreasing the heartbeat frequency. + cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + cls.knobs.enable() + cls.listener = OvertCommandListener() + cls.client = await cls.unmanaged_async_rs_or_single_client( + retryWrites=True, event_listeners=[cls.listener] + ) + cls.db = cls.client.pymongo_test + + @classmethod + async def _tearDown_class(cls): + cls.knobs.disable() + await cls.client.close() + await super()._tearDown_class() + + async def asyncSetUp(self): + if async_client_context.is_rs and async_client_context.test_commands_enabled: + await self.client.admin.command( + SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")]) + ) + + async def asyncTearDown(self): + if async_client_context.is_rs and async_client_context.test_commands_enabled: + await self.client.admin.command( + SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) + ) + + async def test_supported_single_statement_no_retry(self): + listener = OvertCommandListener() + client = await self.async_rs_or_single_client(retryWrites=False, event_listeners=[listener]) + for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" + listener.reset() + await method(*args, **kwargs) + for event in listener.started_events: + self.assertNotIn( + "txnNumber", + event.command, + f"{msg} sent txnNumber with {event.command_name}", + ) + + @async_client_context.require_no_standalone + async def test_supported_single_statement_supported_cluster(self): + for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test): + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" + self.listener.reset() + await method(*args, **kwargs) + commands_started = self.listener.started_events + self.assertEqual(len(self.listener.succeeded_events), 1, msg) + first_attempt = commands_started[0] + self.assertIn( + "lsid", + first_attempt.command, + f"{msg} sent no lsid with {first_attempt.command_name}", + ) + initial_session_id = first_attempt.command["lsid"] + self.assertIn( + "txnNumber", + first_attempt.command, + f"{msg} sent no txnNumber with {first_attempt.command_name}", + ) + + # There should be no retry when the failpoint is not active. + if async_client_context.is_mongos or not async_client_context.test_commands_enabled: + self.assertEqual(len(commands_started), 1) + continue + + initial_transaction_id = first_attempt.command["txnNumber"] + retry_attempt = commands_started[1] + self.assertIn( + "lsid", + retry_attempt.command, + f"{msg} sent no lsid with {first_attempt.command_name}", + ) + self.assertEqual(retry_attempt.command["lsid"], initial_session_id, msg) + self.assertIn( + "txnNumber", + retry_attempt.command, + f"{msg} sent no txnNumber with {first_attempt.command_name}", + ) + self.assertEqual(retry_attempt.command["txnNumber"], initial_transaction_id, msg) + + async def test_supported_single_statement_unsupported_cluster(self): + if async_client_context.is_rs or async_client_context.is_mongos: + raise SkipTest("This cluster supports retryable writes") + + for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test): + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" + self.listener.reset() + await method(*args, **kwargs) + + for event in self.listener.started_events: + self.assertNotIn( + "txnNumber", + event.command, + f"{msg} sent txnNumber with {event.command_name}", + ) + + async def test_unsupported_single_statement(self): + coll = self.db.retryable_write_test + await coll.insert_many([{}, {}]) + coll_w0 = coll.with_options(write_concern=WriteConcern(w=0)) + for method, args, kwargs in non_retryable_single_statement_ops( + coll + ) + retryable_single_statement_ops(coll_w0): + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" + self.listener.reset() + await method(*args, **kwargs) + started_events = self.listener.started_events + self.assertEqual(len(self.listener.succeeded_events), len(started_events), msg) + self.assertEqual(len(self.listener.failed_events), 0, msg) + for event in started_events: + self.assertNotIn( + "txnNumber", + event.command, + f"{msg} sent txnNumber with {event.command_name}", + ) + + async def test_server_selection_timeout_not_retried(self): + """A ServerSelectionTimeoutError is not retried.""" + listener = OvertCommandListener() + client = self.simple_client( + "somedomainthatdoesntexist.org", + serverSelectionTimeoutMS=1, + retryWrites=True, + event_listeners=[listener], + ) + for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" + listener.reset() + with self.assertRaises(ServerSelectionTimeoutError, msg=msg): + await method(*args, **kwargs) + self.assertEqual(len(listener.started_events), 0, msg) + + @async_client_context.require_replica_set + @async_client_context.require_test_commands + async def test_retry_timeout_raises_original_error(self): + """A ServerSelectionTimeoutError on the retry attempt raises the + original error. + """ + listener = OvertCommandListener() + client = await self.async_rs_or_single_client(retryWrites=True, event_listeners=[listener]) + topology = client._topology + select_server = topology.select_server + + def mock_select_server(*args, **kwargs): + server = select_server(*args, **kwargs) + + def raise_error(*args, **kwargs): + raise ServerSelectionTimeoutError("No primary available for writes") + + # Raise ServerSelectionTimeout on the retry attempt. + topology.select_server = raise_error + return server + + for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" + listener.reset() + topology.select_server = mock_select_server + with self.assertRaises(ConnectionFailure, msg=msg): + await method(*args, **kwargs) + self.assertEqual(len(listener.started_events), 1, msg) + + @async_client_context.require_replica_set + @async_client_context.require_test_commands + async def test_batch_splitting(self): + """Test retry succeeds after failures during batch splitting.""" + large = "s" * 1024 * 1024 * 15 + coll = self.db.retryable_write_test + await coll.delete_many({}) + self.listener.reset() + bulk_result = await coll.bulk_write( + [ + InsertOne({"_id": 1, "l": large}), + InsertOne({"_id": 2, "l": large}), + InsertOne({"_id": 3, "l": large}), + UpdateOne({"_id": 1, "l": large}, {"$unset": {"l": 1}, "$inc": {"count": 1}}), + UpdateOne({"_id": 2, "l": large}, {"$set": {"foo": "bar"}}), + DeleteOne({"l": large}), + DeleteOne({"l": large}), + ] + ) + # Each command should fail and be retried. + # With OP_MSG 3 inserts are one batch. 2 updates another. + # 2 deletes a third. + self.assertEqual(len(self.listener.started_events), 6) + self.assertEqual(await coll.find_one(), {"_id": 1, "count": 1}) + # Assert the final result + expected_result = { + "writeErrors": [], + "writeConcernErrors": [], + "nInserted": 3, + "nUpserted": 0, + "nMatched": 2, + "nModified": 2, + "nRemoved": 2, + "upserted": [], + } + self.assertEqual(bulk_result.bulk_api_result, expected_result) + + @async_client_context.require_replica_set + @async_client_context.require_test_commands + async def test_batch_splitting_retry_fails(self): + """Test retry fails during batch splitting.""" + large = "s" * 1024 * 1024 * 15 + coll = self.db.retryable_write_test + await coll.delete_many({}) + await self.client.admin.command( + SON( + [ + ("configureFailPoint", "onPrimaryTransactionalWrite"), + ("mode", {"skip": 3}), # The number of _documents_ to skip. + ("data", {"failBeforeCommitExceptionCode": 1}), + ] + ) + ) + self.listener.reset() + async with self.client.start_session() as session: + initial_txn = session._transaction_id + try: + await coll.bulk_write( + [ + InsertOne({"_id": 1, "l": large}), + InsertOne({"_id": 2, "l": large}), + InsertOne({"_id": 3, "l": large}), + InsertOne({"_id": 4, "l": large}), + ], + session=session, + ) + except ConnectionFailure: + pass + else: + self.fail("bulk_write should have failed") + + started = self.listener.started_events + self.assertEqual(len(started), 3) + self.assertEqual(len(self.listener.succeeded_events), 1) + expected_txn = Int64(initial_txn + 1) + self.assertEqual(started[0].command["txnNumber"], expected_txn) + self.assertEqual(started[0].command["lsid"], session.session_id) + expected_txn = Int64(initial_txn + 2) + self.assertEqual(started[1].command["txnNumber"], expected_txn) + self.assertEqual(started[1].command["lsid"], session.session_id) + started[1].command.pop("$clusterTime") + started[2].command.pop("$clusterTime") + self.assertEqual(started[1].command, started[2].command) + final_txn = session._transaction_id + self.assertEqual(final_txn, expected_txn) + self.assertEqual(await coll.find_one(projection={"_id": True}), {"_id": 1}) + + @async_client_context.require_multiple_mongoses + @async_client_context.require_failCommand_fail_point + async def test_retryable_writes_in_sharded_cluster_multiple_available(self): + fail_command = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["insert"], + "closeConnection": True, + "appName": "retryableWriteTest", + }, + } + + mongos_clients = [] + + for mongos in async_client_context.mongos_seeds().split(","): + client = await self.async_rs_or_single_client(mongos) + await async_set_fail_point(client, fail_command) + mongos_clients.append(client) + + listener = OvertCommandListener() + client = await self.async_rs_or_single_client( + async_client_context.mongos_seeds(), + appName="retryableWriteTest", + event_listeners=[listener], + retryWrites=True, + ) + + with self.assertRaises(AutoReconnect): + await client.t.t.insert_one({"x": 1}) + + # Disable failpoints on each mongos + for client in mongos_clients: + fail_command["mode"] = "off" + await async_set_fail_point(client, fail_command) + + self.assertEqual(len(listener.failed_events), 2) + self.assertEqual(len(listener.succeeded_events), 0) + + +class TestWriteConcernError(AsyncIntegrationTest): + RUN_ON_LOAD_BALANCER = True + RUN_ON_SERVERLESS = True + fail_insert: dict + + @classmethod + @async_client_context.require_replica_set + @async_client_context.require_no_mmap + @async_client_context.require_failCommand_fail_point + async def _setup_class(cls): + await super()._setup_class() + cls.fail_insert = { + "configureFailPoint": "failCommand", + "mode": {"times": 2}, + "data": { + "failCommands": ["insert"], + "writeConcernError": {"code": 91, "errmsg": "Replication is being shut down"}, + }, + } + + @async_client_context.require_version_min(4, 0) + @client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05) + async def test_RetryableWriteError_error_label(self): + listener = OvertCommandListener() + client = await self.async_rs_or_single_client(retryWrites=True, event_listeners=[listener]) + + # Ensure collection exists. + await client.pymongo_test.testcoll.insert_one({}) + + async with self.fail_point(self.fail_insert): + with self.assertRaises(WriteConcernError) as cm: + await client.pymongo_test.testcoll.insert_one({}) + self.assertTrue(cm.exception.has_error_label("RetryableWriteError")) + + if async_client_context.version >= Version(4, 4): + # In MongoDB 4.4+ we rely on the server returning the error label. + self.assertIn("RetryableWriteError", listener.succeeded_events[-1].reply["errorLabels"]) + + @async_client_context.require_version_min(4, 4) + async def test_RetryableWriteError_error_label_RawBSONDocument(self): + # using RawBSONDocument should not cause errorLabel parsing to fail + async with self.fail_point(self.fail_insert): + async with self.client.start_session() as s: + s._start_retryable_write() + result = await self.client.pymongo_test.command( + "insert", + "testcoll", + documents=[{"_id": 1}], + txnNumber=s._transaction_id, + session=s, + codec_options=DEFAULT_CODEC_OPTIONS.with_options( + document_class=RawBSONDocument + ), + ) + + self.assertIn("writeConcernError", result) + self.assertIn("RetryableWriteError", result["errorLabels"]) + + +class InsertThread(threading.Thread): + def __init__(self, collection): + super().__init__() + self.daemon = True + self.collection = collection + self.passed = False + + async def run(self): + await self.collection.insert_one({}) + self.passed = True + + +class TestPoolPausedError(AsyncIntegrationTest): + # Pools don't get paused in load balanced mode. + RUN_ON_LOAD_BALANCER = False + RUN_ON_SERVERLESS = False + + @async_client_context.require_sync + @async_client_context.require_failCommand_blockConnection + @async_client_context.require_retryable_writes + @client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05) + async def test_pool_paused_error_is_retryable(self): + cmap_listener = CMAPListener() + cmd_listener = OvertCommandListener() + client = await self.async_rs_or_single_client( + maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener] + ) + for _ in range(10): + cmap_listener.reset() + cmd_listener.reset() + threads = [InsertThread(client.pymongo_test.test) for _ in range(2)] + fail_command = { + "mode": {"times": 1}, + "data": { + "failCommands": ["insert"], + "blockConnection": True, + "blockTimeMS": 1000, + "errorCode": 91, + "errorLabels": ["RetryableWriteError"], + }, + } + async with self.fail_point(fail_command): + for thread in threads: + thread.start() + for thread in threads: + thread.join() + for thread in threads: + self.assertTrue(thread.passed) + # It's possible that SDAM can rediscover the server and mark the + # pool ready before the thread in the wait queue has a chance + # to run. Repeat the test until the thread actually encounters + # a PoolClearedError. + if cmap_listener.event_count(ConnectionCheckOutFailedEvent): + break + + # Via CMAP monitoring, assert that the first check out succeeds. + cmap_events = cmap_listener.events_by_type( + (ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent, PoolClearedEvent) + ) + msg = pprint.pformat(cmap_listener.events) + self.assertIsInstance(cmap_events[0], ConnectionCheckedOutEvent, msg) + self.assertIsInstance(cmap_events[1], PoolClearedEvent, msg) + self.assertIsInstance(cmap_events[2], ConnectionCheckOutFailedEvent, msg) + self.assertEqual(cmap_events[2].reason, ConnectionCheckOutFailedReason.CONN_ERROR, msg) + self.assertIsInstance(cmap_events[3], ConnectionCheckedOutEvent, msg) + + # Connection check out failures are not reflected in command + # monitoring because we only publish command events _after_ checking + # out a connection. + started = cmd_listener.started_events + msg = pprint.pformat(cmd_listener.results) + self.assertEqual(3, len(started), msg) + succeeded = cmd_listener.succeeded_events + self.assertEqual(2, len(succeeded), msg) + failed = cmd_listener.failed_events + self.assertEqual(1, len(failed), msg) + + @async_client_context.require_sync + @async_client_context.require_failCommand_fail_point + @async_client_context.require_replica_set + @async_client_context.require_version_min( + 6, 0, 0 + ) # the spec requires that this prose test only be run on 6.0+ + @client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05) + async def test_returns_original_error_code( + self, + ): + cmd_listener = InsertEventListener() + client = await self.async_rs_or_single_client( + retryWrites=True, event_listeners=[cmd_listener] + ) + await client.test.test.drop() + cmd_listener.reset() + await client.admin.command( + { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "writeConcernError": { + "code": 91, + "errorLabels": ["RetryableWriteError"], + }, + "failCommands": ["insert"], + }, + } + ) + with self.assertRaises(WriteConcernError) as exc: + await client.test.test.insert_one({"_id": 1}) + self.assertEqual(exc.exception.code, 91) + await client.admin.command( + { + "configureFailPoint": "failCommand", + "mode": "off", + } + ) + + +# TODO: Make this a real integration test where we stepdown the primary. +class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest): + @async_client_context.require_replica_set + @async_client_context.require_no_mmap + async def test_increment_transaction_id_without_sending_command(self): + """Test that the txnNumber field is properly incremented, even when + the first attempt fails before sending the command. + """ + listener = OvertCommandListener() + client = await self.async_rs_or_single_client(retryWrites=True, event_listeners=[listener]) + topology = client._topology + select_server = topology.select_server + + def raise_connection_err_select_server(*args, **kwargs): + # Raise ConnectionFailure on the first attempt and perform + # normal selection on the retry attempt. + topology.select_server = select_server + raise ConnectionFailure("Connection refused") + + for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): + listener.reset() + topology.select_server = raise_connection_err_select_server + async with client.start_session() as session: + kwargs = copy.deepcopy(kwargs) + kwargs["session"] = session + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" + initial_txn_id = session._transaction_id + + # Each operation should fail on the first attempt and succeed + # on the second. + await method(*args, **kwargs) + self.assertEqual(len(listener.started_events), 1, msg) + retry_cmd = listener.started_events[0].command + sent_txn_id = retry_cmd["txnNumber"] + final_txn_id = session._transaction_id + self.assertEqual(Int64(initial_txn_id + 1), sent_txn_id, msg) + self.assertEqual(sent_txn_id, final_txn_id, msg) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/bson_binary_vector/float32.json b/test/bson_binary_vector/float32.json new file mode 100644 index 000000000..bbbe00b75 --- /dev/null +++ b/test/bson_binary_vector/float32.json @@ -0,0 +1,42 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype FLOAT32", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector FLOAT32", + "valid": true, + "vector": [127.0, 7.0], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1C00000005766563746F72000A0000000927000000FE420000E04000" + }, + { + "description": "Empty Vector FLOAT32", + "valid": true, + "vector": [], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009270000" + }, + { + "description": "Infinity Vector FLOAT32", + "valid": true, + "vector": ["-inf", 0.0, "inf"], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "2000000005766563746F72000E000000092700000080FF000000000000807F00" + }, + { + "description": "FLOAT32 with padding", + "valid": false, + "vector": [127.0, 7.0], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 3 + } + ] +} + diff --git a/test/bson_binary_vector/int8.json b/test/bson_binary_vector/int8.json new file mode 100644 index 000000000..7529721e5 --- /dev/null +++ b/test/bson_binary_vector/int8.json @@ -0,0 +1,57 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype INT8", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector INT8", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0, + "canonical_bson": "1600000005766563746F7200040000000903007F0700" + }, + { + "description": "Empty Vector INT8", + "valid": true, + "vector": [], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009030000" + }, + { + "description": "Overflow Vector INT8", + "valid": false, + "vector": [128], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + }, + { + "description": "Underflow Vector INT8", + "valid": false, + "vector": [-129], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + }, + { + "description": "INT8 with padding", + "valid": false, + "vector": [127, 7], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 3 + }, + { + "description": "INT8 with float inputs", + "valid": false, + "vector": [127.77, 7.77], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + } + ] +} + diff --git a/test/bson_binary_vector/packed_bit.json b/test/bson_binary_vector/packed_bit.json new file mode 100644 index 000000000..a41cd593f --- /dev/null +++ b/test/bson_binary_vector/packed_bit.json @@ -0,0 +1,50 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype PACKED_BIT", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector PACKED_BIT", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0, + "canonical_bson": "1600000005766563746F7200040000000910007F0700" + }, + { + "description": "Empty Vector PACKED_BIT", + "valid": true, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009100000" + }, + { + "description": "PACKED_BIT with padding", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 3, + "canonical_bson": "1600000005766563746F7200040000000910037F0700" + }, + { + "description": "Overflow Vector PACKED_BIT", + "valid": false, + "vector": [256], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Underflow Vector PACKED_BIT", + "valid": false, + "vector": [-1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + } + ] +} + diff --git a/test/bson_corpus/binary.json b/test/bson_corpus/binary.json index 20aaef743..0e0056f3a 100644 --- a/test/bson_corpus/binary.json +++ b/test/bson_corpus/binary.json @@ -74,6 +74,36 @@ "description": "$type query operator (conflicts with legacy $binary form with $type field)", "canonical_bson": "180000000378001000000010247479706500020000000000", "canonical_extjson": "{\"x\" : { \"$type\" : {\"$numberInt\": \"2\"}}}" + }, + { + "description": "subtype 0x09 Vector FLOAT32", + "canonical_bson": "170000000578000A0000000927000000FE420000E04000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwAAAP5CAADgQA==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector INT8", + "canonical_bson": "11000000057800040000000903007F0700", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"AwB/Bw==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector PACKED_BIT", + "canonical_bson": "11000000057800040000000910007F0700", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAB/Bw==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) FLOAT32", + "canonical_bson": "0F0000000578000200000009270000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwA=\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) INT8", + "canonical_bson": "0F0000000578000200000009030000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"AwA=\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) PACKED_BIT", + "canonical_bson": "0F0000000578000200000009100000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAA=\", \"subType\": \"09\"}}}" } ], "decodeErrors": [ diff --git a/test/helpers.py b/test/helpers.py index b38b2e298..bf6186d1a 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -28,6 +28,7 @@ import time import traceback import unittest import warnings +from asyncio import iscoroutinefunction try: import ipaddress @@ -47,6 +48,8 @@ from pymongo.uri_parser import parse_uri if HAVE_SSL: import ssl +_IS_SYNC = True + # Enable debug output for uncollectable objects. PyPy does not have set_debug. if hasattr(gc, "set_debug"): gc.set_debug( diff --git a/test/test_bson.py b/test/test_bson.py index a0190ef2d..96aa897d1 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -49,8 +49,9 @@ from bson import ( decode_iter, encode, is_valid, + json_util, ) -from bson.binary import USER_DEFINED_SUBTYPE, Binary, UuidRepresentation +from bson.binary import USER_DEFINED_SUBTYPE, Binary, BinaryVectorDtype, UuidRepresentation from bson.code import Code from bson.codec_options import CodecOptions, DatetimeConversion from bson.datetime_ms import _DATETIME_ERROR_SUGGESTION @@ -148,6 +149,9 @@ class TestBSON(unittest.TestCase): helper({"a binary": Binary(b"test", 128)}) helper({"a binary": Binary(b"test", 254)}) helper({"another binary": Binary(b"test", 2)}) + helper({"binary packed bit vector": Binary(b"\x10\x00\x7f\x07", 9)}) + helper({"binary int8 vector": Binary(b"\x03\x00\x7f\x07", 9)}) + helper({"binary float32 vector": Binary(b"'\x00\x00\x00\xfeB\x00\x00\xe0@", 9)}) helper(SON([("test dst", datetime.datetime(1993, 4, 4, 2))])) helper(SON([("test negative dst", datetime.datetime(1, 1, 1, 1, 1, 1))])) helper({"big float": float(10000000000)}) @@ -447,6 +451,20 @@ class TestBSON(unittest.TestCase): encode({"test": Binary(b"test", 128)}), b"\x14\x00\x00\x00\x05\x74\x65\x73\x74\x00\x04\x00\x00\x00\x80\x74\x65\x73\x74\x00", ) + self.assertEqual( + encode({"vector_int8": Binary.from_vector([-128, -1, 127], BinaryVectorDtype.INT8)}), + b"\x1c\x00\x00\x00\x05vector_int8\x00\x05\x00\x00\x00\t\x03\x00\x80\xff\x7f\x00", + ) + self.assertEqual( + encode({"vector_bool": Binary.from_vector([1, 127], BinaryVectorDtype.PACKED_BIT)}), + b"\x1b\x00\x00\x00\x05vector_bool\x00\x04\x00\x00\x00\t\x10\x00\x01\x7f\x00", + ) + self.assertEqual( + encode( + {"vector_float32": Binary.from_vector([-1.1, 1.1e10], BinaryVectorDtype.FLOAT32)} + ), + b"$\x00\x00\x00\x05vector_float32\x00\n\x00\x00\x00\t'\x00\xcd\xcc\x8c\xbf\xac\xe9#P\x00", + ) self.assertEqual(encode({"test": None}), b"\x0B\x00\x00\x00\x0A\x74\x65\x73\x74\x00\x00") self.assertEqual( encode({"date": datetime.datetime(2007, 1, 8, 0, 30, 11)}), @@ -711,9 +729,66 @@ class TestBSON(unittest.TestCase): transformed = bin.as_uuid(UuidRepresentation.PYTHON_LEGACY) self.assertEqual(id, transformed) - # The C extension was segfaulting on unicode RegExs, so we have this test - # that doesn't really test anything but the lack of a segfault. + def test_vector(self): + """Tests of subtype 9""" + # We start with valid cases, across the 3 dtypes implemented. + # Work with a simple vector that can be interpreted as int8, float32, or ubyte + list_vector = [127, 7] + # As INT8, vector has length 2 + binary_vector = Binary.from_vector(list_vector, BinaryVectorDtype.INT8) + vector = binary_vector.as_vector() + assert vector.data == list_vector + # test encoding roundtrip + assert {"vector": binary_vector} == decode(encode({"vector": binary_vector})) + # test json roundtrip + assert binary_vector == json_util.loads(json_util.dumps(binary_vector)) + + # For vectors of bits, aka PACKED_BIT type, vector has length 8 * 2 + packed_bit_binary = Binary.from_vector(list_vector, BinaryVectorDtype.PACKED_BIT) + packed_bit_vec = packed_bit_binary.as_vector() + assert packed_bit_vec.data == list_vector + + # A padding parameter permits vectors of length that aren't divisible by 8 + # The following ignores the last 3 bits in list_vector, + # hence it's length is 8 * len(list_vector) - padding + padding = 3 + padded_vec = Binary.from_vector(list_vector, BinaryVectorDtype.PACKED_BIT, padding=padding) + assert padded_vec.as_vector().data == list_vector + # To visualize how this looks as a binary vector.. + uncompressed = "" + for val in list_vector: + uncompressed += format(val, "08b") + assert uncompressed[:-padding] == "0111111100000" + + # It is worthwhile explicitly showing the values encoded to BSON + padded_doc = {"padded_vec": padded_vec} + assert ( + encode(padded_doc) + == b"\x1a\x00\x00\x00\x05padded_vec\x00\x04\x00\x00\x00\t\x10\x03\x7f\x07\x00" + ) + # and dumped to json + assert ( + json_util.dumps(padded_doc) + == '{"padded_vec": {"$binary": {"base64": "EAN/Bw==", "subType": "09"}}}' + ) + + # FLOAT32 is also implemented + float_binary = Binary.from_vector(list_vector, BinaryVectorDtype.FLOAT32) + assert all(isinstance(d, float) for d in float_binary.as_vector().data) + + # Now some invalid cases + for x in [-1, 257]: + try: + Binary.from_vector([x], BinaryVectorDtype.PACKED_BIT) + except Exception as exc: + self.assertTrue(isinstance(exc, struct.error)) + else: + self.fail("Failed to raise an exception.") + def test_unicode_regex(self): + """Tests we do not get a segfault for C extension on unicode RegExs. + This had been happening. + """ regex = re.compile("revisi\xf3n") decode(encode({"regex": regex})) diff --git a/test/test_bson_binary_vector.py b/test/test_bson_binary_vector.py new file mode 100644 index 000000000..00c82bbb6 --- /dev/null +++ b/test/test_bson_binary_vector.py @@ -0,0 +1,105 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import binascii +import codecs +import json +import struct +from pathlib import Path +from test import unittest + +from bson import decode, encode +from bson.binary import Binary, BinaryVectorDtype + +_TEST_PATH = Path(__file__).parent / "bson_binary_vector" + + +class TestBSONBinaryVector(unittest.TestCase): + """Runs Binary Vector subtype tests. + + Follows the style of the BSON corpus specification tests. + Tests are automatically generated on import + from json files in _TEST_PATH via `create_tests`. + The actual tests are defined in the inner function `run_test` + of the test generator `create_test`.""" + + +def create_test(case_spec): + """Create standard test given specification in json. + + We use the naming convention expected (exp) and observed (obj) + to differentiate what is in the json (expected or suffix _exp) + from what is produced by the API (observed or suffix _obs) + """ + test_key = case_spec.get("test_key") + + def run_test(self): + for test_case in case_spec.get("tests", []): + description = test_case["description"] + vector_exp = test_case["vector"] + dtype_hex_exp = test_case["dtype_hex"] + dtype_alias_exp = test_case.get("dtype_alias") + padding_exp = test_case.get("padding", 0) + canonical_bson_exp = test_case.get("canonical_bson") + # Convert dtype hex string into bytes + dtype_exp = BinaryVectorDtype(int(dtype_hex_exp, 16).to_bytes(1, byteorder="little")) + + if test_case["valid"]: + # Convert bson string to bytes + cB_exp = binascii.unhexlify(canonical_bson_exp.encode("utf8")) + decoded_doc = decode(cB_exp) + binary_obs = decoded_doc[test_key] + # Handle special float cases like '-inf' + if dtype_exp in [BinaryVectorDtype.FLOAT32]: + vector_exp = [float(x) for x in vector_exp] + + # Test round-tripping canonical bson. + self.assertEqual(encode(decoded_doc), cB_exp, description) + + # Test BSON to Binary Vector + vector_obs = binary_obs.as_vector() + self.assertEqual(vector_obs.dtype, dtype_exp, description) + if dtype_alias_exp: + self.assertEqual( + vector_obs.dtype, BinaryVectorDtype[dtype_alias_exp], description + ) + self.assertEqual(vector_obs.data, vector_exp, description) + self.assertEqual(vector_obs.padding, padding_exp, description) + + # Test Binary Vector to BSON + vector_exp = Binary.from_vector(vector_exp, dtype_exp, padding_exp) + cB_obs = binascii.hexlify(encode({test_key: vector_exp})).decode().upper() + self.assertEqual(cB_obs, canonical_bson_exp, description) + + else: + with self.assertRaises((struct.error, ValueError), msg=description): + Binary.from_vector(vector_exp, dtype_exp, padding_exp) + + return run_test + + +def create_tests(): + for filename in _TEST_PATH.glob("*.json"): + with codecs.open(str(filename), encoding="utf-8") as test_file: + test_method = create_test(json.load(test_file)) + setattr(TestBSONBinaryVector, "test_" + filename.stem, test_method) + + +create_tests() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index b4fafe465..d4951db5e 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -44,8 +44,7 @@ from pymongo.monitoring import ( PoolClearedEvent, ) -# Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "retryable_reads", "legacy") +_IS_SYNC = True class TestClientOptions(PyMongoTestCase): @@ -83,6 +82,7 @@ class TestPoolPausedError(IntegrationTest): RUN_ON_LOAD_BALANCER = False RUN_ON_SERVERLESS = False + @client_context.require_sync @client_context.require_failCommand_blockConnection @client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05) def test_pool_paused_error_is_retryable(self): @@ -94,7 +94,6 @@ class TestPoolPausedError(IntegrationTest): client = self.rs_or_single_client( maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener] ) - self.addCleanup(client.close) for _ in range(10): cmap_listener.reset() cmd_listener.reset() @@ -165,7 +164,6 @@ class TestRetryableReads(IntegrationTest): for mongos in client_context.mongos_seeds().split(","): client = self.rs_or_single_client(mongos) set_fail_point(client, fail_command) - self.addCleanup(client.close) mongos_clients.append(client) listener = OvertCommandListener() diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 89454ad23..5df6c41f7 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -15,6 +15,7 @@ """Test retryable writes.""" from __future__ import annotations +import asyncio import copy import pprint import sys @@ -22,7 +23,13 @@ import threading sys.path[0:0] = [""] -from test import IntegrationTest, SkipTest, client_context, client_knobs, unittest +from test import ( + IntegrationTest, + SkipTest, + client_context, + unittest, +) +from test.helpers import client_knobs from test.utils import ( CMAPListener, DeprecationFilter, @@ -61,6 +68,8 @@ from pymongo.operations import ( from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern +_IS_SYNC = True + class InsertEventListener(EventListener): def succeeded(self, event: CommandSucceededEvent) -> None: @@ -125,22 +134,22 @@ class IgnoreDeprecationsTest(IntegrationTest): deprecation_filter: DeprecationFilter @classmethod - def setUpClass(cls): - super().setUpClass() + def _setup_class(cls): + super()._setup_class() cls.deprecation_filter = DeprecationFilter() @classmethod - def tearDownClass(cls): + def _tearDown_class(cls): cls.deprecation_filter.stop() - super().tearDownClass() + super()._tearDown_class() class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): knobs: client_knobs @classmethod - def setUpClass(cls): - super().setUpClass() + def _setup_class(cls): + super()._setup_class() # Speed up the tests by decreasing the heartbeat frequency. cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() @@ -148,10 +157,10 @@ class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): cls.db = cls.client.pymongo_test @classmethod - def tearDownClass(cls): + def _tearDown_class(cls): cls.knobs.disable() cls.client.close() - super().tearDownClass() + super()._tearDown_class() @client_context.require_no_standalone def test_actionable_error_message(self): @@ -174,8 +183,8 @@ class TestRetryableWrites(IgnoreDeprecationsTest): @classmethod @client_context.require_no_mmap - def setUpClass(cls): - super().setUpClass() + def _setup_class(cls): + super()._setup_class() # Speed up the tests by decreasing the heartbeat frequency. cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() @@ -186,10 +195,10 @@ class TestRetryableWrites(IgnoreDeprecationsTest): cls.db = cls.client.pymongo_test @classmethod - def tearDownClass(cls): + def _tearDown_class(cls): cls.knobs.disable() cls.client.close() - super().tearDownClass() + super()._tearDown_class() def setUp(self): if client_context.is_rs and client_context.test_commands_enabled: @@ -206,7 +215,6 @@ class TestRetryableWrites(IgnoreDeprecationsTest): def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() client = self.rs_or_single_client(retryWrites=False, event_listeners=[listener]) - self.addCleanup(client.close) for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" listener.reset() @@ -319,7 +327,6 @@ class TestRetryableWrites(IgnoreDeprecationsTest): """ listener = OvertCommandListener() client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener]) - self.addCleanup(client.close) topology = client._topology select_server = topology.select_server @@ -446,7 +453,6 @@ class TestRetryableWrites(IgnoreDeprecationsTest): for mongos in client_context.mongos_seeds().split(","): client = self.rs_or_single_client(mongos) set_fail_point(client, fail_command) - self.addCleanup(client.close) mongos_clients.append(client) listener = OvertCommandListener() @@ -478,8 +484,8 @@ class TestWriteConcernError(IntegrationTest): @client_context.require_replica_set @client_context.require_no_mmap @client_context.require_failCommand_fail_point - def setUpClass(cls): - super().setUpClass() + def _setup_class(cls): + super()._setup_class() cls.fail_insert = { "configureFailPoint": "failCommand", "mode": {"times": 2}, @@ -494,7 +500,6 @@ class TestWriteConcernError(IntegrationTest): def test_RetryableWriteError_error_label(self): listener = OvertCommandListener() client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener]) - self.addCleanup(client.close) # Ensure collection exists. client.pymongo_test.testcoll.insert_one({}) @@ -546,6 +551,7 @@ class TestPoolPausedError(IntegrationTest): RUN_ON_LOAD_BALANCER = False RUN_ON_SERVERLESS = False + @client_context.require_sync @client_context.require_failCommand_blockConnection @client_context.require_retryable_writes @client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05) @@ -555,7 +561,6 @@ class TestPoolPausedError(IntegrationTest): client = self.rs_or_single_client( maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener] ) - self.addCleanup(client.close) for _ in range(10): cmap_listener.reset() cmd_listener.reset() @@ -606,6 +611,7 @@ class TestPoolPausedError(IntegrationTest): failed = cmd_listener.failed_events self.assertEqual(1, len(failed), msg) + @client_context.require_sync @client_context.require_failCommand_fail_point @client_context.require_replica_set @client_context.require_version_min( @@ -618,7 +624,6 @@ class TestPoolPausedError(IntegrationTest): cmd_listener = InsertEventListener() client = self.rs_or_single_client(retryWrites=True, event_listeners=[cmd_listener]) client.test.test.drop() - self.addCleanup(client.close) cmd_listener.reset() client.admin.command( { @@ -654,7 +659,6 @@ class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest): """ listener = OvertCommandListener() client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener]) - self.addCleanup(client.close) topology = client._topology select_server = topology.select_server diff --git a/test/utils.py b/test/utils.py index 6eefd1c7e..961503489 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1157,3 +1157,9 @@ def set_fail_point(client, command_args): cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args) client.admin.command(cmd) + + +async def async_set_fail_point(client, command_args): + cmd = SON([("configureFailPoint", "failCommand")]) + cmd.update(command_args) + await client.admin.command(cmd) diff --git a/tools/synchro.py b/tools/synchro.py index e0c194f96..3333b0de2 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -104,6 +104,7 @@ replacements = { "PyMongo|c|async": "PyMongo|c", "AsyncTestGridFile": "TestGridFile", "AsyncTestGridFileNoConnect": "TestGridFileNoConnect", + "async_set_fail_point": "set_fail_point", } docstring_replacements: dict[tuple[str, str], str] = { @@ -173,6 +174,7 @@ sync_gridfs_files = [ converted_tests = [ "__init__.py", "conftest.py", + "helpers.py", "pymongo_mocks.py", "utils_spec_runner.py", "qcheck.py", @@ -191,6 +193,8 @@ converted_tests = [ "test_logger.py", "test_monitoring.py", "test_raw_bson.py", + "test_retryable_reads.py", + "test_retryable_writes.py", "test_session.py", "test_transactions.py", ]