Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
7da2c35c9a
@ -76,6 +76,9 @@ do
|
||||
atlas-data-lake-testing|data_lake)
|
||||
cpjson atlas-data-lake-testing/tests/ data_lake
|
||||
;;
|
||||
bson-binary-vector|bson_binary_vector)
|
||||
cpjson bson-binary-vector/tests/ bson_binary_vector
|
||||
;;
|
||||
bson-corpus|bson_corpus)
|
||||
cpjson bson-corpus/tests/ bson_corpus
|
||||
;;
|
||||
|
||||
152
bson/binary.py
152
bson/binary.py
@ -13,7 +13,10 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Tuple, Type, Union
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Sequence, Tuple, Type, Union
|
||||
from uuid import UUID
|
||||
|
||||
"""Tools for representing BSON binary data.
|
||||
@ -191,21 +194,75 @@ SENSITIVE_SUBTYPE = 8
|
||||
"""
|
||||
|
||||
|
||||
VECTOR_SUBTYPE = 9
|
||||
"""**(BETA)** BSON binary subtype for densely packed vector data.
|
||||
|
||||
.. versionadded:: 4.10
|
||||
"""
|
||||
|
||||
|
||||
USER_DEFINED_SUBTYPE = 128
|
||||
"""BSON binary subtype for any user defined structure.
|
||||
"""
|
||||
|
||||
|
||||
class BinaryVectorDtype(Enum):
|
||||
"""**(BETA)** Datatypes of vector subtype.
|
||||
|
||||
:param FLOAT32: (0x27) Pack list of :class:`float` as float32
|
||||
:param INT8: (0x03) Pack list of :class:`int` in [-128, 127] as signed int8
|
||||
:param PACKED_BIT: (0x10) Pack list of :class:`int` in [0, 255] as unsigned uint8
|
||||
|
||||
The `PACKED_BIT` value represents a special case where vector values themselves
|
||||
can only be of two values (0 or 1) but these are packed together into groups of 8,
|
||||
a byte. In Python, these are displayed as ints in range [0, 255]
|
||||
|
||||
Each value is of type bytes with a length of one.
|
||||
|
||||
.. versionadded:: 4.10
|
||||
"""
|
||||
|
||||
INT8 = b"\x03"
|
||||
FLOAT32 = b"\x27"
|
||||
PACKED_BIT = b"\x10"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BinaryVector:
|
||||
"""**(BETA)** Vector of numbers along with metadata for binary interoperability.
|
||||
.. versionadded:: 4.10
|
||||
"""
|
||||
|
||||
__slots__ = ("data", "dtype", "padding")
|
||||
|
||||
def __init__(self, data: Sequence[float | int], dtype: BinaryVectorDtype, padding: int = 0):
|
||||
"""
|
||||
:param data: Sequence of numbers representing the mathematical vector.
|
||||
:param dtype: The data type stored in binary
|
||||
:param padding: The number of bits in the final byte that are to be ignored
|
||||
when a vector element's size is less than a byte
|
||||
and the length of the vector is not a multiple of 8.
|
||||
"""
|
||||
self.data = data
|
||||
self.dtype = dtype
|
||||
self.padding = padding
|
||||
|
||||
|
||||
class Binary(bytes):
|
||||
"""Representation of BSON binary data.
|
||||
|
||||
This is necessary because we want to represent Python strings as
|
||||
the BSON string type. We need to wrap binary data so we can tell
|
||||
We want to represent Python strings as the BSON string type.
|
||||
We need to wrap binary data so that we can tell
|
||||
the difference between what should be considered binary data and
|
||||
what should be considered a string when we encode to BSON.
|
||||
|
||||
Raises TypeError if `data` is not an instance of :class:`bytes`
|
||||
or `subtype` is not an instance of :class:`int`.
|
||||
**(BETA)** Subtype 9 provides a space-efficient representation of 1-dimensional vector data.
|
||||
Its data is prepended with two bytes of metadata.
|
||||
The first (dtype) describes its data type, such as float32 or int8.
|
||||
The second (padding) prescribes the number of bits to ignore in the final byte.
|
||||
This is relevant when the element size of the dtype is not a multiple of 8.
|
||||
|
||||
Raises TypeError if `subtype` is not an instance of :class:`int`.
|
||||
Raises ValueError if `subtype` is not in [0, 256).
|
||||
|
||||
.. note::
|
||||
@ -218,7 +275,10 @@ class Binary(bytes):
|
||||
to use
|
||||
|
||||
.. versionchanged:: 3.9
|
||||
Support any bytes-like type that implements the buffer protocol.
|
||||
Support any bytes-like type that implements the buffer protocol.
|
||||
|
||||
.. versionchanged:: 4.10
|
||||
**(BETA)** Addition of vector subtype.
|
||||
"""
|
||||
|
||||
_type_marker = 5
|
||||
@ -337,6 +397,86 @@ class Binary(bytes):
|
||||
f"cannot decode subtype {self.subtype} to {UUID_REPRESENTATION_NAMES[uuid_representation]}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_vector(
|
||||
cls: Type[Binary],
|
||||
vector: list[int, float],
|
||||
dtype: BinaryVectorDtype,
|
||||
padding: int = 0,
|
||||
) -> Binary:
|
||||
"""**(BETA)** Create a BSON :class:`~bson.binary.Binary` of Vector subtype from a list of Numbers.
|
||||
|
||||
To interpret the representation of the numbers, a data type must be included.
|
||||
See :class:`~bson.binary.BinaryVectorDtype` for available types and descriptions.
|
||||
|
||||
The dtype and padding are prepended to the binary data's value.
|
||||
|
||||
:param vector: List of values
|
||||
:param dtype: Data type of the values
|
||||
:param padding: For fractional bytes, number of bits to ignore at end of vector.
|
||||
:return: Binary packed data identified by dtype and padding.
|
||||
|
||||
.. versionadded:: 4.10
|
||||
"""
|
||||
if dtype == BinaryVectorDtype.INT8: # pack ints in [-128, 127] as signed int8
|
||||
format_str = "b"
|
||||
if padding:
|
||||
raise ValueError(f"padding does not apply to {dtype=}")
|
||||
elif dtype == BinaryVectorDtype.PACKED_BIT: # pack ints in [0, 255] as unsigned uint8
|
||||
format_str = "B"
|
||||
elif dtype == BinaryVectorDtype.FLOAT32: # pack floats as float32
|
||||
format_str = "f"
|
||||
if padding:
|
||||
raise ValueError(f"padding does not apply to {dtype=}")
|
||||
else:
|
||||
raise NotImplementedError("%s not yet supported" % dtype)
|
||||
|
||||
metadata = struct.pack("<sB", dtype.value, padding)
|
||||
data = struct.pack(f"{len(vector)}{format_str}", *vector)
|
||||
return cls(metadata + data, subtype=VECTOR_SUBTYPE)
|
||||
|
||||
def as_vector(self) -> BinaryVector:
|
||||
"""**(BETA)** From the Binary, create a list of numbers, along with dtype and padding.
|
||||
|
||||
:return: BinaryVector
|
||||
|
||||
.. versionadded:: 4.10
|
||||
"""
|
||||
|
||||
if self.subtype != VECTOR_SUBTYPE:
|
||||
raise ValueError(f"Cannot decode subtype {self.subtype} as a vector.")
|
||||
|
||||
position = 0
|
||||
dtype, padding = struct.unpack_from("<sB", self, position)
|
||||
position += 2
|
||||
dtype = BinaryVectorDtype(dtype)
|
||||
n_values = len(self) - position
|
||||
|
||||
if dtype == BinaryVectorDtype.INT8:
|
||||
dtype_format = "b"
|
||||
format_string = f"{n_values}{dtype_format}"
|
||||
vector = list(struct.unpack_from(format_string, self, position))
|
||||
return BinaryVector(vector, dtype, padding)
|
||||
|
||||
elif dtype == BinaryVectorDtype.FLOAT32:
|
||||
n_bytes = len(self) - position
|
||||
n_values = n_bytes // 4
|
||||
if n_bytes % 4:
|
||||
raise ValueError(
|
||||
"Corrupt data. N bytes for a float32 vector must be a multiple of 4."
|
||||
)
|
||||
vector = list(struct.unpack_from(f"{n_values}f", self, position))
|
||||
return BinaryVector(vector, dtype, padding)
|
||||
|
||||
elif dtype == BinaryVectorDtype.PACKED_BIT:
|
||||
# data packed as uint8
|
||||
dtype_format = "B"
|
||||
unpacked_uint8s = list(struct.unpack_from(f"{n_values}{dtype_format}", self, position))
|
||||
return BinaryVector(unpacked_uint8s, dtype, padding)
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Binary Vector dtype %s not yet supported" % dtype.name)
|
||||
|
||||
@property
|
||||
def subtype(self) -> int:
|
||||
"""Subtype of this binary data."""
|
||||
|
||||
@ -21,6 +21,14 @@
|
||||
.. autoclass:: UuidRepresentation
|
||||
:members:
|
||||
|
||||
.. autoclass:: BinaryVectorDtype
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: BinaryVector
|
||||
:members:
|
||||
|
||||
|
||||
.. autoclass:: Binary(data, subtype=BINARY_SUBTYPE)
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
@ -1,6 +1,11 @@
|
||||
Async Tutorial
|
||||
==============
|
||||
|
||||
.. warning:: This API is currently in beta, meaning the classes, methods,
|
||||
and behaviors described within may change before the full release.
|
||||
If you come across any bugs during your use of this API,
|
||||
please file a Jira ticket in the "Python Driver" project at https://jira.mongodb.org/browse/PYTHON.
|
||||
|
||||
.. code-block:: pycon
|
||||
|
||||
from pymongo import AsyncMongoClient
|
||||
|
||||
@ -1,6 +1,24 @@
|
||||
Changelog
|
||||
=========
|
||||
|
||||
Changes in Version 4.10.0
|
||||
-------------------------
|
||||
|
||||
- Added provisional **(BETA)** support for a new Binary BSON subtype (9) used for efficient storage and retrieval of vectors:
|
||||
densely packed arrays of numbers, all of the same type.
|
||||
This includes new methods :meth:`~bson.binary.Binary.from_vector` and :meth:`~bson.binary.Binary.as_vector`.
|
||||
- Added C extension use to client metadata, for example: ``{"driver": {"name": "PyMongo|c", "version": "4.10.0"}, ...}``
|
||||
- Fixed a bug where :class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` could deadlock.
|
||||
- Fixed a bug where PyMongo could fail to import on Windows if ``asyncio`` is misconfigured.
|
||||
|
||||
Issues Resolved
|
||||
...............
|
||||
|
||||
See the `PyMongo 4.10 release notes in JIRA`_ for the list of resolved issues
|
||||
in this release.
|
||||
|
||||
.. _PyMongo 4.10 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=40553
|
||||
|
||||
Changes in Version 4.9.0
|
||||
-------------------------
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
import re
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
__version__ = "4.10.0.dev0"
|
||||
__version__ = "4.11.0.dev0"
|
||||
|
||||
|
||||
def get_version_tuple(version: str) -> Tuple[Union[int, str], ...]:
|
||||
|
||||
360
test/asynchronous/helpers.py
Normal file
360
test/asynchronous/helpers.py
Normal file
@ -0,0 +1,360 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import gc
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
HAVE_IPADDRESS = True
|
||||
except ImportError:
|
||||
HAVE_IPADDRESS = False
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, Generator, no_type_check
|
||||
from unittest import SkipTest
|
||||
|
||||
from bson.son import SON
|
||||
from pymongo import common, message
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
if HAVE_SSL:
|
||||
import ssl
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Enable debug output for uncollectable objects. PyPy does not have set_debug.
|
||||
if hasattr(gc, "set_debug"):
|
||||
gc.set_debug(
|
||||
gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0)
|
||||
)
|
||||
|
||||
# The host and port of a single mongod or mongos, or the seed host
|
||||
# for a replica set.
|
||||
host = os.environ.get("DB_IP", "localhost")
|
||||
port = int(os.environ.get("DB_PORT", 27017))
|
||||
IS_SRV = "mongodb+srv" in host
|
||||
|
||||
db_user = os.environ.get("DB_USER", "user")
|
||||
db_pwd = os.environ.get("DB_PASSWORD", "password")
|
||||
|
||||
CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates")
|
||||
CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem"))
|
||||
CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem"))
|
||||
|
||||
TLS_OPTIONS: Dict = {"tls": True}
|
||||
if CLIENT_PEM:
|
||||
TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM
|
||||
if CA_PEM:
|
||||
TLS_OPTIONS["tlsCAFile"] = CA_PEM
|
||||
|
||||
COMPRESSORS = os.environ.get("COMPRESSORS")
|
||||
MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION")
|
||||
TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER"))
|
||||
TEST_SERVERLESS = bool(os.environ.get("TEST_SERVERLESS"))
|
||||
SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI")
|
||||
MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI")
|
||||
|
||||
if TEST_LOADBALANCER:
|
||||
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
|
||||
host, port = res["nodelist"][0]
|
||||
db_user = res["username"] or db_user
|
||||
db_pwd = res["password"] or db_pwd
|
||||
elif TEST_SERVERLESS:
|
||||
TEST_LOADBALANCER = True
|
||||
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
|
||||
host, port = res["nodelist"][0]
|
||||
db_user = res["username"] or db_user
|
||||
db_pwd = res["password"] or db_pwd
|
||||
TLS_OPTIONS = {"tls": True}
|
||||
# Spec says serverless tests must be run with compression.
|
||||
COMPRESSORS = COMPRESSORS or "zlib"
|
||||
|
||||
|
||||
# Shared KMS data.
|
||||
LOCAL_MASTER_KEY = base64.b64decode(
|
||||
b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ"
|
||||
b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
|
||||
)
|
||||
AWS_CREDS = {
|
||||
"accessKeyId": os.environ.get("FLE_AWS_KEY", ""),
|
||||
"secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""),
|
||||
}
|
||||
AWS_CREDS_2 = {
|
||||
"accessKeyId": os.environ.get("FLE_AWS_KEY2", ""),
|
||||
"secretAccessKey": os.environ.get("FLE_AWS_SECRET2", ""),
|
||||
}
|
||||
AZURE_CREDS = {
|
||||
"tenantId": os.environ.get("FLE_AZURE_TENANTID", ""),
|
||||
"clientId": os.environ.get("FLE_AZURE_CLIENTID", ""),
|
||||
"clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""),
|
||||
}
|
||||
GCP_CREDS = {
|
||||
"email": os.environ.get("FLE_GCP_EMAIL", ""),
|
||||
"privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""),
|
||||
}
|
||||
KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")}
|
||||
|
||||
# Ensure Evergreen metadata doesn't result in truncation
|
||||
os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000")
|
||||
|
||||
|
||||
def is_server_resolvable():
|
||||
"""Returns True if 'server' is resolvable."""
|
||||
socket_timeout = socket.getdefaulttimeout()
|
||||
socket.setdefaulttimeout(1)
|
||||
try:
|
||||
try:
|
||||
socket.gethostbyname("server")
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
finally:
|
||||
socket.setdefaulttimeout(socket_timeout)
|
||||
|
||||
|
||||
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
|
||||
cmd = SON([("createUser", user)])
|
||||
# X509 doesn't use a password
|
||||
if pwd:
|
||||
cmd["pwd"] = pwd
|
||||
cmd["roles"] = roles or ["root"]
|
||||
cmd.update(**kwargs)
|
||||
return authdb.command(cmd)
|
||||
|
||||
|
||||
class client_knobs:
|
||||
def __init__(
|
||||
self,
|
||||
heartbeat_frequency=None,
|
||||
min_heartbeat_interval=None,
|
||||
kill_cursor_frequency=None,
|
||||
events_queue_frequency=None,
|
||||
):
|
||||
self.heartbeat_frequency = heartbeat_frequency
|
||||
self.min_heartbeat_interval = min_heartbeat_interval
|
||||
self.kill_cursor_frequency = kill_cursor_frequency
|
||||
self.events_queue_frequency = events_queue_frequency
|
||||
|
||||
self.old_heartbeat_frequency = None
|
||||
self.old_min_heartbeat_interval = None
|
||||
self.old_kill_cursor_frequency = None
|
||||
self.old_events_queue_frequency = None
|
||||
self._enabled = False
|
||||
self._stack = None
|
||||
|
||||
def enable(self):
|
||||
self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
|
||||
self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL
|
||||
self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY
|
||||
self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY
|
||||
|
||||
if self.heartbeat_frequency is not None:
|
||||
common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency
|
||||
|
||||
if self.min_heartbeat_interval is not None:
|
||||
common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval
|
||||
|
||||
if self.kill_cursor_frequency is not None:
|
||||
common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency
|
||||
|
||||
if self.events_queue_frequency is not None:
|
||||
common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency
|
||||
self._enabled = True
|
||||
# Store the allocation traceback to catch non-disabled client_knobs.
|
||||
self._stack = "".join(traceback.format_stack())
|
||||
|
||||
def __enter__(self):
|
||||
self.enable()
|
||||
|
||||
@no_type_check
|
||||
def disable(self):
|
||||
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
|
||||
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
|
||||
common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency
|
||||
common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency
|
||||
self._enabled = False
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.disable()
|
||||
|
||||
def __call__(self, func):
|
||||
def make_wrapper(f):
|
||||
@wraps(f)
|
||||
async def wrap(*args, **kwargs):
|
||||
with self:
|
||||
return await f(*args, **kwargs)
|
||||
|
||||
return wrap
|
||||
|
||||
return make_wrapper(func)
|
||||
|
||||
def __del__(self):
|
||||
if self._enabled:
|
||||
msg = (
|
||||
"ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, "
|
||||
"MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, "
|
||||
"EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format(
|
||||
common.HEARTBEAT_FREQUENCY,
|
||||
common.MIN_HEARTBEAT_INTERVAL,
|
||||
common.KILL_CURSOR_FREQUENCY,
|
||||
common.EVENTS_QUEUE_FREQUENCY,
|
||||
self._stack,
|
||||
)
|
||||
)
|
||||
self.disable()
|
||||
raise Exception(msg)
|
||||
|
||||
|
||||
def _all_users(db):
|
||||
return {u["user"] for u in db.command("usersInfo").get("users", [])}
|
||||
|
||||
|
||||
def sanitize_cmd(cmd):
|
||||
cp = cmd.copy()
|
||||
cp.pop("$clusterTime", None)
|
||||
cp.pop("$db", None)
|
||||
cp.pop("$readPreference", None)
|
||||
cp.pop("lsid", None)
|
||||
if MONGODB_API_VERSION:
|
||||
# Stable API parameters
|
||||
cp.pop("apiVersion", None)
|
||||
# OP_MSG encoding may move the payload type one field to the
|
||||
# end of the command. Do the same here.
|
||||
name = next(iter(cp))
|
||||
try:
|
||||
identifier = message._FIELD_MAP[name]
|
||||
docs = cp.pop(identifier)
|
||||
cp[identifier] = docs
|
||||
except KeyError:
|
||||
pass
|
||||
return cp
|
||||
|
||||
|
||||
def sanitize_reply(reply):
|
||||
cp = reply.copy()
|
||||
cp.pop("$clusterTime", None)
|
||||
cp.pop("operationTime", None)
|
||||
return cp
|
||||
|
||||
|
||||
def print_thread_tracebacks() -> None:
|
||||
"""Print all Python thread tracebacks."""
|
||||
for thread_id, frame in sys._current_frames().items():
|
||||
sys.stderr.write(f"\n--- Traceback for thread {thread_id} ---\n")
|
||||
traceback.print_stack(frame, file=sys.stderr)
|
||||
|
||||
|
||||
def print_thread_stacks(pid: int) -> None:
|
||||
"""Print all C-level thread stacks for a given process id."""
|
||||
if sys.platform == "darwin":
|
||||
cmd = ["lldb", "--attach-pid", f"{pid}", "--batch", "--one-line", '"thread backtrace all"']
|
||||
else:
|
||||
cmd = ["gdb", f"--pid={pid}", "--batch", '--eval-command="thread apply all bt"']
|
||||
|
||||
try:
|
||||
res = subprocess.run(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
|
||||
)
|
||||
except Exception as exc:
|
||||
sys.stderr.write(f"Could not print C-level thread stacks because {cmd[0]} failed: {exc}")
|
||||
else:
|
||||
sys.stderr.write(res.stdout)
|
||||
|
||||
|
||||
# Global knobs to speed up the test suite.
|
||||
global_knobs = client_knobs(events_queue_frequency=0.05)
|
||||
|
||||
|
||||
def _get_executors(topology):
|
||||
executors = []
|
||||
for server in topology._servers.values():
|
||||
# Some MockMonitor do not have an _executor.
|
||||
if hasattr(server._monitor, "_executor"):
|
||||
executors.append(server._monitor._executor)
|
||||
if hasattr(server._monitor, "_rtt_monitor"):
|
||||
executors.append(server._monitor._rtt_monitor._executor)
|
||||
executors.append(topology._Topology__events_executor)
|
||||
if topology._srv_monitor:
|
||||
executors.append(topology._srv_monitor._executor)
|
||||
|
||||
return [e for e in executors if e is not None]
|
||||
|
||||
|
||||
def print_running_topology(topology):
|
||||
running = [e for e in _get_executors(topology) if not e._stopped]
|
||||
if running:
|
||||
print(
|
||||
"WARNING: found Topology with running threads:\n"
|
||||
f" Threads: {running}\n"
|
||||
f" Topology: {topology}\n"
|
||||
f" Creation traceback:\n{topology._settings._stack}"
|
||||
)
|
||||
|
||||
|
||||
def test_cases(suite):
|
||||
"""Iterator over all TestCases within a TestSuite."""
|
||||
for suite_or_case in suite._tests:
|
||||
if isinstance(suite_or_case, unittest.TestCase):
|
||||
# unittest.TestCase
|
||||
yield suite_or_case
|
||||
else:
|
||||
# unittest.TestSuite
|
||||
yield from test_cases(suite_or_case)
|
||||
|
||||
|
||||
# Helper method to workaround https://bugs.python.org/issue21724
|
||||
def clear_warning_registry():
|
||||
"""Clear the __warningregistry__ for all modules."""
|
||||
for _, module in list(sys.modules.items()):
|
||||
if hasattr(module, "__warningregistry__"):
|
||||
module.__warningregistry__ = {} # type:ignore[attr-defined]
|
||||
|
||||
|
||||
class SystemCertsPatcher:
|
||||
def __init__(self, ca_certs):
|
||||
if (
|
||||
ssl.OPENSSL_VERSION.lower().startswith("libressl")
|
||||
and sys.platform == "darwin"
|
||||
and not _ssl.IS_PYOPENSSL
|
||||
):
|
||||
raise SkipTest(
|
||||
"LibreSSL on OSX doesn't support setting CA certificates "
|
||||
"using SSL_CERT_FILE environment variable."
|
||||
)
|
||||
self.original_certs = os.environ.get("SSL_CERT_FILE")
|
||||
# Tell OpenSSL where CA certificates live.
|
||||
os.environ["SSL_CERT_FILE"] = ca_certs
|
||||
|
||||
def disable(self):
|
||||
if self.original_certs is None:
|
||||
os.environ.pop("SSL_CERT_FILE")
|
||||
else:
|
||||
os.environ["SSL_CERT_FILE"] = self.original_certs
|
||||
191
test/asynchronous/test_retryable_reads.py
Normal file
191
test/asynchronous/test_retryable_reads.py
Normal file
@ -0,0 +1,191 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test retryable reads spec."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pprint
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from pymongo.errors import AutoReconnect
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import (
|
||||
AsyncIntegrationTest,
|
||||
AsyncPyMongoTestCase,
|
||||
async_client_context,
|
||||
client_knobs,
|
||||
unittest,
|
||||
)
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
OvertCommandListener,
|
||||
async_set_fail_point,
|
||||
)
|
||||
|
||||
from pymongo.monitoring import (
|
||||
ConnectionCheckedOutEvent,
|
||||
ConnectionCheckOutFailedEvent,
|
||||
ConnectionCheckOutFailedReason,
|
||||
PoolClearedEvent,
|
||||
)
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TestClientOptions(AsyncPyMongoTestCase):
|
||||
async def test_default(self):
|
||||
client = self.simple_client(connect=False)
|
||||
self.assertEqual(client.options.retry_reads, True)
|
||||
|
||||
async def test_kwargs(self):
|
||||
client = self.simple_client(retryReads=True, connect=False)
|
||||
self.assertEqual(client.options.retry_reads, True)
|
||||
client = self.simple_client(retryReads=False, connect=False)
|
||||
self.assertEqual(client.options.retry_reads, False)
|
||||
|
||||
async def test_uri(self):
|
||||
client = self.simple_client("mongodb://h/?retryReads=true", connect=False)
|
||||
self.assertEqual(client.options.retry_reads, True)
|
||||
client = self.simple_client("mongodb://h/?retryReads=false", connect=False)
|
||||
self.assertEqual(client.options.retry_reads, False)
|
||||
|
||||
|
||||
class FindThread(threading.Thread):
|
||||
def __init__(self, collection):
|
||||
super().__init__()
|
||||
self.daemon = True
|
||||
self.collection = collection
|
||||
self.passed = False
|
||||
|
||||
async def run(self):
|
||||
await self.collection.find_one({})
|
||||
self.passed = True
|
||||
|
||||
|
||||
class TestPoolPausedError(AsyncIntegrationTest):
|
||||
# Pools don't get paused in load balanced mode.
|
||||
RUN_ON_LOAD_BALANCER = False
|
||||
RUN_ON_SERVERLESS = False
|
||||
|
||||
@async_client_context.require_sync
|
||||
@async_client_context.require_failCommand_blockConnection
|
||||
@client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05)
|
||||
async def test_pool_paused_error_is_retryable(self):
|
||||
if "PyPy" in sys.version:
|
||||
# Tracked in PYTHON-3519
|
||||
self.skipTest("Test is flakey on PyPy")
|
||||
cmap_listener = CMAPListener()
|
||||
cmd_listener = OvertCommandListener()
|
||||
client = await self.async_rs_or_single_client(
|
||||
maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]
|
||||
)
|
||||
for _ in range(10):
|
||||
cmap_listener.reset()
|
||||
cmd_listener.reset()
|
||||
threads = [FindThread(client.pymongo_test.test) for _ in range(2)]
|
||||
fail_command = {
|
||||
"mode": {"times": 1},
|
||||
"data": {
|
||||
"failCommands": ["find"],
|
||||
"blockConnection": True,
|
||||
"blockTimeMS": 1000,
|
||||
"errorCode": 91,
|
||||
},
|
||||
}
|
||||
async with self.fail_point(fail_command):
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
for thread in threads:
|
||||
self.assertTrue(thread.passed)
|
||||
|
||||
# It's possible that SDAM can rediscover the server and mark the
|
||||
# pool ready before the thread in the wait queue has a chance
|
||||
# to run. Repeat the test until the thread actually encounters
|
||||
# a PoolClearedError.
|
||||
if cmap_listener.event_count(ConnectionCheckOutFailedEvent):
|
||||
break
|
||||
|
||||
# Via CMAP monitoring, assert that the first check out succeeds.
|
||||
cmap_events = cmap_listener.events_by_type(
|
||||
(ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent, PoolClearedEvent)
|
||||
)
|
||||
msg = pprint.pformat(cmap_listener.events)
|
||||
self.assertIsInstance(cmap_events[0], ConnectionCheckedOutEvent, msg)
|
||||
self.assertIsInstance(cmap_events[1], PoolClearedEvent, msg)
|
||||
self.assertIsInstance(cmap_events[2], ConnectionCheckOutFailedEvent, msg)
|
||||
self.assertEqual(cmap_events[2].reason, ConnectionCheckOutFailedReason.CONN_ERROR, msg)
|
||||
self.assertIsInstance(cmap_events[3], ConnectionCheckedOutEvent, msg)
|
||||
|
||||
# Connection check out failures are not reflected in command
|
||||
# monitoring because we only publish command events _after_ checking
|
||||
# out a connection.
|
||||
started = cmd_listener.started_events
|
||||
msg = pprint.pformat(cmd_listener.results)
|
||||
self.assertEqual(3, len(started), msg)
|
||||
succeeded = cmd_listener.succeeded_events
|
||||
self.assertEqual(2, len(succeeded), msg)
|
||||
failed = cmd_listener.failed_events
|
||||
self.assertEqual(1, len(failed), msg)
|
||||
|
||||
|
||||
class TestRetryableReads(AsyncIntegrationTest):
|
||||
@async_client_context.require_multiple_mongoses
|
||||
@async_client_context.require_failCommand_fail_point
|
||||
async def test_retryable_reads_in_sharded_cluster_multiple_available(self):
|
||||
fail_command = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 1},
|
||||
"data": {
|
||||
"failCommands": ["find"],
|
||||
"closeConnection": True,
|
||||
"appName": "retryableReadTest",
|
||||
},
|
||||
}
|
||||
|
||||
mongos_clients = []
|
||||
|
||||
for mongos in async_client_context.mongos_seeds().split(","):
|
||||
client = await self.async_rs_or_single_client(mongos)
|
||||
await async_set_fail_point(client, fail_command)
|
||||
mongos_clients.append(client)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = await self.async_rs_or_single_client(
|
||||
async_client_context.mongos_seeds(),
|
||||
appName="retryableReadTest",
|
||||
event_listeners=[listener],
|
||||
retryReads=True,
|
||||
)
|
||||
|
||||
async with self.fail_point(fail_command):
|
||||
with self.assertRaises(AutoReconnect):
|
||||
await client.t.t.find_one({})
|
||||
|
||||
# Disable failpoints on each mongos
|
||||
for client in mongos_clients:
|
||||
fail_command["mode"] = "off"
|
||||
await async_set_fail_point(client, fail_command)
|
||||
|
||||
self.assertEqual(len(listener.failed_events), 2)
|
||||
self.assertEqual(len(listener.succeeded_events), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
694
test/asynchronous/test_retryable_writes.py
Normal file
694
test/asynchronous/test_retryable_writes.py
Normal file
@ -0,0 +1,694 @@
|
||||
# Copyright 2017 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test retryable writes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import pprint
|
||||
import sys
|
||||
import threading
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import (
|
||||
AsyncIntegrationTest,
|
||||
SkipTest,
|
||||
async_client_context,
|
||||
unittest,
|
||||
)
|
||||
from test.asynchronous.helpers import client_knobs
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
DeprecationFilter,
|
||||
EventListener,
|
||||
OvertCommandListener,
|
||||
async_set_fail_point,
|
||||
)
|
||||
from test.version import Version
|
||||
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS
|
||||
from bson.int64 import Int64
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.son import SON
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.errors import (
|
||||
AutoReconnect,
|
||||
ConnectionFailure,
|
||||
OperationFailure,
|
||||
ServerSelectionTimeoutError,
|
||||
WriteConcernError,
|
||||
)
|
||||
from pymongo.monitoring import (
|
||||
CommandSucceededEvent,
|
||||
ConnectionCheckedOutEvent,
|
||||
ConnectionCheckOutFailedEvent,
|
||||
ConnectionCheckOutFailedReason,
|
||||
PoolClearedEvent,
|
||||
)
|
||||
from pymongo.operations import (
|
||||
DeleteMany,
|
||||
DeleteOne,
|
||||
InsertOne,
|
||||
ReplaceOne,
|
||||
UpdateMany,
|
||||
UpdateOne,
|
||||
)
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class InsertEventListener(EventListener):
|
||||
def succeeded(self, event: CommandSucceededEvent) -> None:
|
||||
super().succeeded(event)
|
||||
if (
|
||||
event.command_name == "insert"
|
||||
and event.reply.get("writeConcernError", {}).get("code", None) == 91
|
||||
):
|
||||
async_client_context.client.admin.command(
|
||||
{
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 1},
|
||||
"data": {
|
||||
"errorCode": 10107,
|
||||
"errorLabels": ["RetryableWriteError", "NoWritesPerformed"],
|
||||
"failCommands": ["insert"],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def retryable_single_statement_ops(coll):
|
||||
return [
|
||||
(coll.bulk_write, [[InsertOne({}), InsertOne({})]], {}),
|
||||
(coll.bulk_write, [[InsertOne({}), InsertOne({})]], {"ordered": False}),
|
||||
(coll.bulk_write, [[ReplaceOne({}, {"a1": 1})]], {}),
|
||||
(coll.bulk_write, [[ReplaceOne({}, {"a2": 1}), ReplaceOne({}, {"a3": 1})]], {}),
|
||||
(
|
||||
coll.bulk_write,
|
||||
[[UpdateOne({}, {"$set": {"a4": 1}}), UpdateOne({}, {"$set": {"a5": 1}})]],
|
||||
{},
|
||||
),
|
||||
(coll.bulk_write, [[DeleteOne({})]], {}),
|
||||
(coll.bulk_write, [[DeleteOne({}), DeleteOne({})]], {}),
|
||||
(coll.insert_one, [{}], {}),
|
||||
(coll.insert_many, [[{}, {}]], {}),
|
||||
(coll.replace_one, [{}, {"a6": 1}], {}),
|
||||
(coll.update_one, [{}, {"$set": {"a7": 1}}], {}),
|
||||
(coll.delete_one, [{}], {}),
|
||||
(coll.find_one_and_replace, [{}, {"a8": 1}], {}),
|
||||
(coll.find_one_and_update, [{}, {"$set": {"a9": 1}}], {}),
|
||||
(coll.find_one_and_delete, [{}, {"a10": 1}], {}),
|
||||
]
|
||||
|
||||
|
||||
def non_retryable_single_statement_ops(coll):
|
||||
return [
|
||||
(
|
||||
coll.bulk_write,
|
||||
[[UpdateOne({}, {"$set": {"a": 1}}), UpdateMany({}, {"$set": {"a": 1}})]],
|
||||
{},
|
||||
),
|
||||
(coll.bulk_write, [[DeleteOne({}), DeleteMany({})]], {}),
|
||||
(coll.update_many, [{}, {"$set": {"a": 1}}], {}),
|
||||
(coll.delete_many, [{}], {}),
|
||||
]
|
||||
|
||||
|
||||
class IgnoreDeprecationsTest(AsyncIntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
RUN_ON_SERVERLESS = True
|
||||
deprecation_filter: DeprecationFilter
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.deprecation_filter = DeprecationFilter()
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.deprecation_filter.stop()
|
||||
await super()._tearDown_class()
|
||||
|
||||
|
||||
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
|
||||
knobs: client_knobs
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(retryWrites=True)
|
||||
cls.db = cls.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
@async_client_context.require_no_standalone
|
||||
async def test_actionable_error_message(self):
|
||||
if async_client_context.storage_engine != "mmapv1":
|
||||
raise SkipTest("This cluster is not running MMAPv1")
|
||||
|
||||
expected_msg = (
|
||||
"This MongoDB deployment does not support retryable "
|
||||
"writes. Please add retryWrites=false to your "
|
||||
"connection string."
|
||||
)
|
||||
for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test):
|
||||
with self.assertRaisesRegex(OperationFailure, expected_msg):
|
||||
await method(*args, **kwargs)
|
||||
|
||||
|
||||
class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
listener: OvertCommandListener
|
||||
knobs: client_knobs
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_no_mmap
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
||||
retryWrites=True, event_listeners=[cls.listener]
|
||||
)
|
||||
cls.db = cls.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
if async_client_context.is_rs and async_client_context.test_commands_enabled:
|
||||
await self.client.admin.command(
|
||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
|
||||
)
|
||||
|
||||
async def asyncTearDown(self):
|
||||
if async_client_context.is_rs and async_client_context.test_commands_enabled:
|
||||
await self.client.admin.command(
|
||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
|
||||
)
|
||||
|
||||
async def test_supported_single_statement_no_retry(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await self.async_rs_or_single_client(retryWrites=False, event_listeners=[listener])
|
||||
for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test):
|
||||
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
|
||||
listener.reset()
|
||||
await method(*args, **kwargs)
|
||||
for event in listener.started_events:
|
||||
self.assertNotIn(
|
||||
"txnNumber",
|
||||
event.command,
|
||||
f"{msg} sent txnNumber with {event.command_name}",
|
||||
)
|
||||
|
||||
@async_client_context.require_no_standalone
|
||||
async def test_supported_single_statement_supported_cluster(self):
|
||||
for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test):
|
||||
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
|
||||
self.listener.reset()
|
||||
await method(*args, **kwargs)
|
||||
commands_started = self.listener.started_events
|
||||
self.assertEqual(len(self.listener.succeeded_events), 1, msg)
|
||||
first_attempt = commands_started[0]
|
||||
self.assertIn(
|
||||
"lsid",
|
||||
first_attempt.command,
|
||||
f"{msg} sent no lsid with {first_attempt.command_name}",
|
||||
)
|
||||
initial_session_id = first_attempt.command["lsid"]
|
||||
self.assertIn(
|
||||
"txnNumber",
|
||||
first_attempt.command,
|
||||
f"{msg} sent no txnNumber with {first_attempt.command_name}",
|
||||
)
|
||||
|
||||
# There should be no retry when the failpoint is not active.
|
||||
if async_client_context.is_mongos or not async_client_context.test_commands_enabled:
|
||||
self.assertEqual(len(commands_started), 1)
|
||||
continue
|
||||
|
||||
initial_transaction_id = first_attempt.command["txnNumber"]
|
||||
retry_attempt = commands_started[1]
|
||||
self.assertIn(
|
||||
"lsid",
|
||||
retry_attempt.command,
|
||||
f"{msg} sent no lsid with {first_attempt.command_name}",
|
||||
)
|
||||
self.assertEqual(retry_attempt.command["lsid"], initial_session_id, msg)
|
||||
self.assertIn(
|
||||
"txnNumber",
|
||||
retry_attempt.command,
|
||||
f"{msg} sent no txnNumber with {first_attempt.command_name}",
|
||||
)
|
||||
self.assertEqual(retry_attempt.command["txnNumber"], initial_transaction_id, msg)
|
||||
|
||||
async def test_supported_single_statement_unsupported_cluster(self):
|
||||
if async_client_context.is_rs or async_client_context.is_mongos:
|
||||
raise SkipTest("This cluster supports retryable writes")
|
||||
|
||||
for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test):
|
||||
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
|
||||
self.listener.reset()
|
||||
await method(*args, **kwargs)
|
||||
|
||||
for event in self.listener.started_events:
|
||||
self.assertNotIn(
|
||||
"txnNumber",
|
||||
event.command,
|
||||
f"{msg} sent txnNumber with {event.command_name}",
|
||||
)
|
||||
|
||||
async def test_unsupported_single_statement(self):
|
||||
coll = self.db.retryable_write_test
|
||||
await coll.insert_many([{}, {}])
|
||||
coll_w0 = coll.with_options(write_concern=WriteConcern(w=0))
|
||||
for method, args, kwargs in non_retryable_single_statement_ops(
|
||||
coll
|
||||
) + retryable_single_statement_ops(coll_w0):
|
||||
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
|
||||
self.listener.reset()
|
||||
await method(*args, **kwargs)
|
||||
started_events = self.listener.started_events
|
||||
self.assertEqual(len(self.listener.succeeded_events), len(started_events), msg)
|
||||
self.assertEqual(len(self.listener.failed_events), 0, msg)
|
||||
for event in started_events:
|
||||
self.assertNotIn(
|
||||
"txnNumber",
|
||||
event.command,
|
||||
f"{msg} sent txnNumber with {event.command_name}",
|
||||
)
|
||||
|
||||
async def test_server_selection_timeout_not_retried(self):
|
||||
"""A ServerSelectionTimeoutError is not retried."""
|
||||
listener = OvertCommandListener()
|
||||
client = self.simple_client(
|
||||
"somedomainthatdoesntexist.org",
|
||||
serverSelectionTimeoutMS=1,
|
||||
retryWrites=True,
|
||||
event_listeners=[listener],
|
||||
)
|
||||
for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test):
|
||||
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
|
||||
listener.reset()
|
||||
with self.assertRaises(ServerSelectionTimeoutError, msg=msg):
|
||||
await method(*args, **kwargs)
|
||||
self.assertEqual(len(listener.started_events), 0, msg)
|
||||
|
||||
@async_client_context.require_replica_set
|
||||
@async_client_context.require_test_commands
|
||||
async def test_retry_timeout_raises_original_error(self):
|
||||
"""A ServerSelectionTimeoutError on the retry attempt raises the
|
||||
original error.
|
||||
"""
|
||||
listener = OvertCommandListener()
|
||||
client = await self.async_rs_or_single_client(retryWrites=True, event_listeners=[listener])
|
||||
topology = client._topology
|
||||
select_server = topology.select_server
|
||||
|
||||
def mock_select_server(*args, **kwargs):
|
||||
server = select_server(*args, **kwargs)
|
||||
|
||||
def raise_error(*args, **kwargs):
|
||||
raise ServerSelectionTimeoutError("No primary available for writes")
|
||||
|
||||
# Raise ServerSelectionTimeout on the retry attempt.
|
||||
topology.select_server = raise_error
|
||||
return server
|
||||
|
||||
for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test):
|
||||
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
|
||||
listener.reset()
|
||||
topology.select_server = mock_select_server
|
||||
with self.assertRaises(ConnectionFailure, msg=msg):
|
||||
await method(*args, **kwargs)
|
||||
self.assertEqual(len(listener.started_events), 1, msg)
|
||||
|
||||
@async_client_context.require_replica_set
|
||||
@async_client_context.require_test_commands
|
||||
async def test_batch_splitting(self):
|
||||
"""Test retry succeeds after failures during batch splitting."""
|
||||
large = "s" * 1024 * 1024 * 15
|
||||
coll = self.db.retryable_write_test
|
||||
await coll.delete_many({})
|
||||
self.listener.reset()
|
||||
bulk_result = await coll.bulk_write(
|
||||
[
|
||||
InsertOne({"_id": 1, "l": large}),
|
||||
InsertOne({"_id": 2, "l": large}),
|
||||
InsertOne({"_id": 3, "l": large}),
|
||||
UpdateOne({"_id": 1, "l": large}, {"$unset": {"l": 1}, "$inc": {"count": 1}}),
|
||||
UpdateOne({"_id": 2, "l": large}, {"$set": {"foo": "bar"}}),
|
||||
DeleteOne({"l": large}),
|
||||
DeleteOne({"l": large}),
|
||||
]
|
||||
)
|
||||
# Each command should fail and be retried.
|
||||
# With OP_MSG 3 inserts are one batch. 2 updates another.
|
||||
# 2 deletes a third.
|
||||
self.assertEqual(len(self.listener.started_events), 6)
|
||||
self.assertEqual(await coll.find_one(), {"_id": 1, "count": 1})
|
||||
# Assert the final result
|
||||
expected_result = {
|
||||
"writeErrors": [],
|
||||
"writeConcernErrors": [],
|
||||
"nInserted": 3,
|
||||
"nUpserted": 0,
|
||||
"nMatched": 2,
|
||||
"nModified": 2,
|
||||
"nRemoved": 2,
|
||||
"upserted": [],
|
||||
}
|
||||
self.assertEqual(bulk_result.bulk_api_result, expected_result)
|
||||
|
||||
@async_client_context.require_replica_set
|
||||
@async_client_context.require_test_commands
|
||||
async def test_batch_splitting_retry_fails(self):
|
||||
"""Test retry fails during batch splitting."""
|
||||
large = "s" * 1024 * 1024 * 15
|
||||
coll = self.db.retryable_write_test
|
||||
await coll.delete_many({})
|
||||
await self.client.admin.command(
|
||||
SON(
|
||||
[
|
||||
("configureFailPoint", "onPrimaryTransactionalWrite"),
|
||||
("mode", {"skip": 3}), # The number of _documents_ to skip.
|
||||
("data", {"failBeforeCommitExceptionCode": 1}),
|
||||
]
|
||||
)
|
||||
)
|
||||
self.listener.reset()
|
||||
async with self.client.start_session() as session:
|
||||
initial_txn = session._transaction_id
|
||||
try:
|
||||
await coll.bulk_write(
|
||||
[
|
||||
InsertOne({"_id": 1, "l": large}),
|
||||
InsertOne({"_id": 2, "l": large}),
|
||||
InsertOne({"_id": 3, "l": large}),
|
||||
InsertOne({"_id": 4, "l": large}),
|
||||
],
|
||||
session=session,
|
||||
)
|
||||
except ConnectionFailure:
|
||||
pass
|
||||
else:
|
||||
self.fail("bulk_write should have failed")
|
||||
|
||||
started = self.listener.started_events
|
||||
self.assertEqual(len(started), 3)
|
||||
self.assertEqual(len(self.listener.succeeded_events), 1)
|
||||
expected_txn = Int64(initial_txn + 1)
|
||||
self.assertEqual(started[0].command["txnNumber"], expected_txn)
|
||||
self.assertEqual(started[0].command["lsid"], session.session_id)
|
||||
expected_txn = Int64(initial_txn + 2)
|
||||
self.assertEqual(started[1].command["txnNumber"], expected_txn)
|
||||
self.assertEqual(started[1].command["lsid"], session.session_id)
|
||||
started[1].command.pop("$clusterTime")
|
||||
started[2].command.pop("$clusterTime")
|
||||
self.assertEqual(started[1].command, started[2].command)
|
||||
final_txn = session._transaction_id
|
||||
self.assertEqual(final_txn, expected_txn)
|
||||
self.assertEqual(await coll.find_one(projection={"_id": True}), {"_id": 1})
|
||||
|
||||
@async_client_context.require_multiple_mongoses
|
||||
@async_client_context.require_failCommand_fail_point
|
||||
async def test_retryable_writes_in_sharded_cluster_multiple_available(self):
|
||||
fail_command = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 1},
|
||||
"data": {
|
||||
"failCommands": ["insert"],
|
||||
"closeConnection": True,
|
||||
"appName": "retryableWriteTest",
|
||||
},
|
||||
}
|
||||
|
||||
mongos_clients = []
|
||||
|
||||
for mongos in async_client_context.mongos_seeds().split(","):
|
||||
client = await self.async_rs_or_single_client(mongos)
|
||||
await async_set_fail_point(client, fail_command)
|
||||
mongos_clients.append(client)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = await self.async_rs_or_single_client(
|
||||
async_client_context.mongos_seeds(),
|
||||
appName="retryableWriteTest",
|
||||
event_listeners=[listener],
|
||||
retryWrites=True,
|
||||
)
|
||||
|
||||
with self.assertRaises(AutoReconnect):
|
||||
await client.t.t.insert_one({"x": 1})
|
||||
|
||||
# Disable failpoints on each mongos
|
||||
for client in mongos_clients:
|
||||
fail_command["mode"] = "off"
|
||||
await async_set_fail_point(client, fail_command)
|
||||
|
||||
self.assertEqual(len(listener.failed_events), 2)
|
||||
self.assertEqual(len(listener.succeeded_events), 0)
|
||||
|
||||
|
||||
class TestWriteConcernError(AsyncIntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
RUN_ON_SERVERLESS = True
|
||||
fail_insert: dict
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_replica_set
|
||||
@async_client_context.require_no_mmap
|
||||
@async_client_context.require_failCommand_fail_point
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.fail_insert = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 2},
|
||||
"data": {
|
||||
"failCommands": ["insert"],
|
||||
"writeConcernError": {"code": 91, "errmsg": "Replication is being shut down"},
|
||||
},
|
||||
}
|
||||
|
||||
@async_client_context.require_version_min(4, 0)
|
||||
@client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05)
|
||||
async def test_RetryableWriteError_error_label(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await self.async_rs_or_single_client(retryWrites=True, event_listeners=[listener])
|
||||
|
||||
# Ensure collection exists.
|
||||
await client.pymongo_test.testcoll.insert_one({})
|
||||
|
||||
async with self.fail_point(self.fail_insert):
|
||||
with self.assertRaises(WriteConcernError) as cm:
|
||||
await client.pymongo_test.testcoll.insert_one({})
|
||||
self.assertTrue(cm.exception.has_error_label("RetryableWriteError"))
|
||||
|
||||
if async_client_context.version >= Version(4, 4):
|
||||
# In MongoDB 4.4+ we rely on the server returning the error label.
|
||||
self.assertIn("RetryableWriteError", listener.succeeded_events[-1].reply["errorLabels"])
|
||||
|
||||
@async_client_context.require_version_min(4, 4)
|
||||
async def test_RetryableWriteError_error_label_RawBSONDocument(self):
|
||||
# using RawBSONDocument should not cause errorLabel parsing to fail
|
||||
async with self.fail_point(self.fail_insert):
|
||||
async with self.client.start_session() as s:
|
||||
s._start_retryable_write()
|
||||
result = await self.client.pymongo_test.command(
|
||||
"insert",
|
||||
"testcoll",
|
||||
documents=[{"_id": 1}],
|
||||
txnNumber=s._transaction_id,
|
||||
session=s,
|
||||
codec_options=DEFAULT_CODEC_OPTIONS.with_options(
|
||||
document_class=RawBSONDocument
|
||||
),
|
||||
)
|
||||
|
||||
self.assertIn("writeConcernError", result)
|
||||
self.assertIn("RetryableWriteError", result["errorLabels"])
|
||||
|
||||
|
||||
class InsertThread(threading.Thread):
|
||||
def __init__(self, collection):
|
||||
super().__init__()
|
||||
self.daemon = True
|
||||
self.collection = collection
|
||||
self.passed = False
|
||||
|
||||
async def run(self):
|
||||
await self.collection.insert_one({})
|
||||
self.passed = True
|
||||
|
||||
|
||||
class TestPoolPausedError(AsyncIntegrationTest):
|
||||
# Pools don't get paused in load balanced mode.
|
||||
RUN_ON_LOAD_BALANCER = False
|
||||
RUN_ON_SERVERLESS = False
|
||||
|
||||
@async_client_context.require_sync
|
||||
@async_client_context.require_failCommand_blockConnection
|
||||
@async_client_context.require_retryable_writes
|
||||
@client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05)
|
||||
async def test_pool_paused_error_is_retryable(self):
|
||||
cmap_listener = CMAPListener()
|
||||
cmd_listener = OvertCommandListener()
|
||||
client = await self.async_rs_or_single_client(
|
||||
maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]
|
||||
)
|
||||
for _ in range(10):
|
||||
cmap_listener.reset()
|
||||
cmd_listener.reset()
|
||||
threads = [InsertThread(client.pymongo_test.test) for _ in range(2)]
|
||||
fail_command = {
|
||||
"mode": {"times": 1},
|
||||
"data": {
|
||||
"failCommands": ["insert"],
|
||||
"blockConnection": True,
|
||||
"blockTimeMS": 1000,
|
||||
"errorCode": 91,
|
||||
"errorLabels": ["RetryableWriteError"],
|
||||
},
|
||||
}
|
||||
async with self.fail_point(fail_command):
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
for thread in threads:
|
||||
self.assertTrue(thread.passed)
|
||||
# It's possible that SDAM can rediscover the server and mark the
|
||||
# pool ready before the thread in the wait queue has a chance
|
||||
# to run. Repeat the test until the thread actually encounters
|
||||
# a PoolClearedError.
|
||||
if cmap_listener.event_count(ConnectionCheckOutFailedEvent):
|
||||
break
|
||||
|
||||
# Via CMAP monitoring, assert that the first check out succeeds.
|
||||
cmap_events = cmap_listener.events_by_type(
|
||||
(ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent, PoolClearedEvent)
|
||||
)
|
||||
msg = pprint.pformat(cmap_listener.events)
|
||||
self.assertIsInstance(cmap_events[0], ConnectionCheckedOutEvent, msg)
|
||||
self.assertIsInstance(cmap_events[1], PoolClearedEvent, msg)
|
||||
self.assertIsInstance(cmap_events[2], ConnectionCheckOutFailedEvent, msg)
|
||||
self.assertEqual(cmap_events[2].reason, ConnectionCheckOutFailedReason.CONN_ERROR, msg)
|
||||
self.assertIsInstance(cmap_events[3], ConnectionCheckedOutEvent, msg)
|
||||
|
||||
# Connection check out failures are not reflected in command
|
||||
# monitoring because we only publish command events _after_ checking
|
||||
# out a connection.
|
||||
started = cmd_listener.started_events
|
||||
msg = pprint.pformat(cmd_listener.results)
|
||||
self.assertEqual(3, len(started), msg)
|
||||
succeeded = cmd_listener.succeeded_events
|
||||
self.assertEqual(2, len(succeeded), msg)
|
||||
failed = cmd_listener.failed_events
|
||||
self.assertEqual(1, len(failed), msg)
|
||||
|
||||
@async_client_context.require_sync
|
||||
@async_client_context.require_failCommand_fail_point
|
||||
@async_client_context.require_replica_set
|
||||
@async_client_context.require_version_min(
|
||||
6, 0, 0
|
||||
) # the spec requires that this prose test only be run on 6.0+
|
||||
@client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05)
|
||||
async def test_returns_original_error_code(
|
||||
self,
|
||||
):
|
||||
cmd_listener = InsertEventListener()
|
||||
client = await self.async_rs_or_single_client(
|
||||
retryWrites=True, event_listeners=[cmd_listener]
|
||||
)
|
||||
await client.test.test.drop()
|
||||
cmd_listener.reset()
|
||||
await client.admin.command(
|
||||
{
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 1},
|
||||
"data": {
|
||||
"writeConcernError": {
|
||||
"code": 91,
|
||||
"errorLabels": ["RetryableWriteError"],
|
||||
},
|
||||
"failCommands": ["insert"],
|
||||
},
|
||||
}
|
||||
)
|
||||
with self.assertRaises(WriteConcernError) as exc:
|
||||
await client.test.test.insert_one({"_id": 1})
|
||||
self.assertEqual(exc.exception.code, 91)
|
||||
await client.admin.command(
|
||||
{
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": "off",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# TODO: Make this a real integration test where we stepdown the primary.
|
||||
class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest):
|
||||
@async_client_context.require_replica_set
|
||||
@async_client_context.require_no_mmap
|
||||
async def test_increment_transaction_id_without_sending_command(self):
|
||||
"""Test that the txnNumber field is properly incremented, even when
|
||||
the first attempt fails before sending the command.
|
||||
"""
|
||||
listener = OvertCommandListener()
|
||||
client = await self.async_rs_or_single_client(retryWrites=True, event_listeners=[listener])
|
||||
topology = client._topology
|
||||
select_server = topology.select_server
|
||||
|
||||
def raise_connection_err_select_server(*args, **kwargs):
|
||||
# Raise ConnectionFailure on the first attempt and perform
|
||||
# normal selection on the retry attempt.
|
||||
topology.select_server = select_server
|
||||
raise ConnectionFailure("Connection refused")
|
||||
|
||||
for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test):
|
||||
listener.reset()
|
||||
topology.select_server = raise_connection_err_select_server
|
||||
async with client.start_session() as session:
|
||||
kwargs = copy.deepcopy(kwargs)
|
||||
kwargs["session"] = session
|
||||
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
|
||||
initial_txn_id = session._transaction_id
|
||||
|
||||
# Each operation should fail on the first attempt and succeed
|
||||
# on the second.
|
||||
await method(*args, **kwargs)
|
||||
self.assertEqual(len(listener.started_events), 1, msg)
|
||||
retry_cmd = listener.started_events[0].command
|
||||
sent_txn_id = retry_cmd["txnNumber"]
|
||||
final_txn_id = session._transaction_id
|
||||
self.assertEqual(Int64(initial_txn_id + 1), sent_txn_id, msg)
|
||||
self.assertEqual(sent_txn_id, final_txn_id, msg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
42
test/bson_binary_vector/float32.json
Normal file
42
test/bson_binary_vector/float32.json
Normal file
@ -0,0 +1,42 @@
|
||||
{
|
||||
"description": "Tests of Binary subtype 9, Vectors, with dtype FLOAT32",
|
||||
"test_key": "vector",
|
||||
"tests": [
|
||||
{
|
||||
"description": "Simple Vector FLOAT32",
|
||||
"valid": true,
|
||||
"vector": [127.0, 7.0],
|
||||
"dtype_hex": "0x27",
|
||||
"dtype_alias": "FLOAT32",
|
||||
"padding": 0,
|
||||
"canonical_bson": "1C00000005766563746F72000A0000000927000000FE420000E04000"
|
||||
},
|
||||
{
|
||||
"description": "Empty Vector FLOAT32",
|
||||
"valid": true,
|
||||
"vector": [],
|
||||
"dtype_hex": "0x27",
|
||||
"dtype_alias": "FLOAT32",
|
||||
"padding": 0,
|
||||
"canonical_bson": "1400000005766563746F72000200000009270000"
|
||||
},
|
||||
{
|
||||
"description": "Infinity Vector FLOAT32",
|
||||
"valid": true,
|
||||
"vector": ["-inf", 0.0, "inf"],
|
||||
"dtype_hex": "0x27",
|
||||
"dtype_alias": "FLOAT32",
|
||||
"padding": 0,
|
||||
"canonical_bson": "2000000005766563746F72000E000000092700000080FF000000000000807F00"
|
||||
},
|
||||
{
|
||||
"description": "FLOAT32 with padding",
|
||||
"valid": false,
|
||||
"vector": [127.0, 7.0],
|
||||
"dtype_hex": "0x27",
|
||||
"dtype_alias": "FLOAT32",
|
||||
"padding": 3
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
57
test/bson_binary_vector/int8.json
Normal file
57
test/bson_binary_vector/int8.json
Normal file
@ -0,0 +1,57 @@
|
||||
{
|
||||
"description": "Tests of Binary subtype 9, Vectors, with dtype INT8",
|
||||
"test_key": "vector",
|
||||
"tests": [
|
||||
{
|
||||
"description": "Simple Vector INT8",
|
||||
"valid": true,
|
||||
"vector": [127, 7],
|
||||
"dtype_hex": "0x03",
|
||||
"dtype_alias": "INT8",
|
||||
"padding": 0,
|
||||
"canonical_bson": "1600000005766563746F7200040000000903007F0700"
|
||||
},
|
||||
{
|
||||
"description": "Empty Vector INT8",
|
||||
"valid": true,
|
||||
"vector": [],
|
||||
"dtype_hex": "0x03",
|
||||
"dtype_alias": "INT8",
|
||||
"padding": 0,
|
||||
"canonical_bson": "1400000005766563746F72000200000009030000"
|
||||
},
|
||||
{
|
||||
"description": "Overflow Vector INT8",
|
||||
"valid": false,
|
||||
"vector": [128],
|
||||
"dtype_hex": "0x03",
|
||||
"dtype_alias": "INT8",
|
||||
"padding": 0
|
||||
},
|
||||
{
|
||||
"description": "Underflow Vector INT8",
|
||||
"valid": false,
|
||||
"vector": [-129],
|
||||
"dtype_hex": "0x03",
|
||||
"dtype_alias": "INT8",
|
||||
"padding": 0
|
||||
},
|
||||
{
|
||||
"description": "INT8 with padding",
|
||||
"valid": false,
|
||||
"vector": [127, 7],
|
||||
"dtype_hex": "0x03",
|
||||
"dtype_alias": "INT8",
|
||||
"padding": 3
|
||||
},
|
||||
{
|
||||
"description": "INT8 with float inputs",
|
||||
"valid": false,
|
||||
"vector": [127.77, 7.77],
|
||||
"dtype_hex": "0x03",
|
||||
"dtype_alias": "INT8",
|
||||
"padding": 0
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
50
test/bson_binary_vector/packed_bit.json
Normal file
50
test/bson_binary_vector/packed_bit.json
Normal file
@ -0,0 +1,50 @@
|
||||
{
|
||||
"description": "Tests of Binary subtype 9, Vectors, with dtype PACKED_BIT",
|
||||
"test_key": "vector",
|
||||
"tests": [
|
||||
{
|
||||
"description": "Simple Vector PACKED_BIT",
|
||||
"valid": true,
|
||||
"vector": [127, 7],
|
||||
"dtype_hex": "0x10",
|
||||
"dtype_alias": "PACKED_BIT",
|
||||
"padding": 0,
|
||||
"canonical_bson": "1600000005766563746F7200040000000910007F0700"
|
||||
},
|
||||
{
|
||||
"description": "Empty Vector PACKED_BIT",
|
||||
"valid": true,
|
||||
"vector": [],
|
||||
"dtype_hex": "0x10",
|
||||
"dtype_alias": "PACKED_BIT",
|
||||
"padding": 0,
|
||||
"canonical_bson": "1400000005766563746F72000200000009100000"
|
||||
},
|
||||
{
|
||||
"description": "PACKED_BIT with padding",
|
||||
"valid": true,
|
||||
"vector": [127, 7],
|
||||
"dtype_hex": "0x10",
|
||||
"dtype_alias": "PACKED_BIT",
|
||||
"padding": 3,
|
||||
"canonical_bson": "1600000005766563746F7200040000000910037F0700"
|
||||
},
|
||||
{
|
||||
"description": "Overflow Vector PACKED_BIT",
|
||||
"valid": false,
|
||||
"vector": [256],
|
||||
"dtype_hex": "0x10",
|
||||
"dtype_alias": "PACKED_BIT",
|
||||
"padding": 0
|
||||
},
|
||||
{
|
||||
"description": "Underflow Vector PACKED_BIT",
|
||||
"valid": false,
|
||||
"vector": [-1],
|
||||
"dtype_hex": "0x10",
|
||||
"dtype_alias": "PACKED_BIT",
|
||||
"padding": 0
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -74,6 +74,36 @@
|
||||
"description": "$type query operator (conflicts with legacy $binary form with $type field)",
|
||||
"canonical_bson": "180000000378001000000010247479706500020000000000",
|
||||
"canonical_extjson": "{\"x\" : { \"$type\" : {\"$numberInt\": \"2\"}}}"
|
||||
},
|
||||
{
|
||||
"description": "subtype 0x09 Vector FLOAT32",
|
||||
"canonical_bson": "170000000578000A0000000927000000FE420000E04000",
|
||||
"canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwAAAP5CAADgQA==\", \"subType\": \"09\"}}}"
|
||||
},
|
||||
{
|
||||
"description": "subtype 0x09 Vector INT8",
|
||||
"canonical_bson": "11000000057800040000000903007F0700",
|
||||
"canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"AwB/Bw==\", \"subType\": \"09\"}}}"
|
||||
},
|
||||
{
|
||||
"description": "subtype 0x09 Vector PACKED_BIT",
|
||||
"canonical_bson": "11000000057800040000000910007F0700",
|
||||
"canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAB/Bw==\", \"subType\": \"09\"}}}"
|
||||
},
|
||||
{
|
||||
"description": "subtype 0x09 Vector (Zero-length) FLOAT32",
|
||||
"canonical_bson": "0F0000000578000200000009270000",
|
||||
"canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwA=\", \"subType\": \"09\"}}}"
|
||||
},
|
||||
{
|
||||
"description": "subtype 0x09 Vector (Zero-length) INT8",
|
||||
"canonical_bson": "0F0000000578000200000009030000",
|
||||
"canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"AwA=\", \"subType\": \"09\"}}}"
|
||||
},
|
||||
{
|
||||
"description": "subtype 0x09 Vector (Zero-length) PACKED_BIT",
|
||||
"canonical_bson": "0F0000000578000200000009100000",
|
||||
"canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAA=\", \"subType\": \"09\"}}}"
|
||||
}
|
||||
],
|
||||
"decodeErrors": [
|
||||
|
||||
@ -28,6 +28,7 @@ import time
|
||||
import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
@ -47,6 +48,8 @@ from pymongo.uri_parser import parse_uri
|
||||
if HAVE_SSL:
|
||||
import ssl
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Enable debug output for uncollectable objects. PyPy does not have set_debug.
|
||||
if hasattr(gc, "set_debug"):
|
||||
gc.set_debug(
|
||||
|
||||
@ -49,8 +49,9 @@ from bson import (
|
||||
decode_iter,
|
||||
encode,
|
||||
is_valid,
|
||||
json_util,
|
||||
)
|
||||
from bson.binary import USER_DEFINED_SUBTYPE, Binary, UuidRepresentation
|
||||
from bson.binary import USER_DEFINED_SUBTYPE, Binary, BinaryVectorDtype, UuidRepresentation
|
||||
from bson.code import Code
|
||||
from bson.codec_options import CodecOptions, DatetimeConversion
|
||||
from bson.datetime_ms import _DATETIME_ERROR_SUGGESTION
|
||||
@ -148,6 +149,9 @@ class TestBSON(unittest.TestCase):
|
||||
helper({"a binary": Binary(b"test", 128)})
|
||||
helper({"a binary": Binary(b"test", 254)})
|
||||
helper({"another binary": Binary(b"test", 2)})
|
||||
helper({"binary packed bit vector": Binary(b"\x10\x00\x7f\x07", 9)})
|
||||
helper({"binary int8 vector": Binary(b"\x03\x00\x7f\x07", 9)})
|
||||
helper({"binary float32 vector": Binary(b"'\x00\x00\x00\xfeB\x00\x00\xe0@", 9)})
|
||||
helper(SON([("test dst", datetime.datetime(1993, 4, 4, 2))]))
|
||||
helper(SON([("test negative dst", datetime.datetime(1, 1, 1, 1, 1, 1))]))
|
||||
helper({"big float": float(10000000000)})
|
||||
@ -447,6 +451,20 @@ class TestBSON(unittest.TestCase):
|
||||
encode({"test": Binary(b"test", 128)}),
|
||||
b"\x14\x00\x00\x00\x05\x74\x65\x73\x74\x00\x04\x00\x00\x00\x80\x74\x65\x73\x74\x00",
|
||||
)
|
||||
self.assertEqual(
|
||||
encode({"vector_int8": Binary.from_vector([-128, -1, 127], BinaryVectorDtype.INT8)}),
|
||||
b"\x1c\x00\x00\x00\x05vector_int8\x00\x05\x00\x00\x00\t\x03\x00\x80\xff\x7f\x00",
|
||||
)
|
||||
self.assertEqual(
|
||||
encode({"vector_bool": Binary.from_vector([1, 127], BinaryVectorDtype.PACKED_BIT)}),
|
||||
b"\x1b\x00\x00\x00\x05vector_bool\x00\x04\x00\x00\x00\t\x10\x00\x01\x7f\x00",
|
||||
)
|
||||
self.assertEqual(
|
||||
encode(
|
||||
{"vector_float32": Binary.from_vector([-1.1, 1.1e10], BinaryVectorDtype.FLOAT32)}
|
||||
),
|
||||
b"$\x00\x00\x00\x05vector_float32\x00\n\x00\x00\x00\t'\x00\xcd\xcc\x8c\xbf\xac\xe9#P\x00",
|
||||
)
|
||||
self.assertEqual(encode({"test": None}), b"\x0B\x00\x00\x00\x0A\x74\x65\x73\x74\x00\x00")
|
||||
self.assertEqual(
|
||||
encode({"date": datetime.datetime(2007, 1, 8, 0, 30, 11)}),
|
||||
@ -711,9 +729,66 @@ class TestBSON(unittest.TestCase):
|
||||
transformed = bin.as_uuid(UuidRepresentation.PYTHON_LEGACY)
|
||||
self.assertEqual(id, transformed)
|
||||
|
||||
# The C extension was segfaulting on unicode RegExs, so we have this test
|
||||
# that doesn't really test anything but the lack of a segfault.
|
||||
def test_vector(self):
|
||||
"""Tests of subtype 9"""
|
||||
# We start with valid cases, across the 3 dtypes implemented.
|
||||
# Work with a simple vector that can be interpreted as int8, float32, or ubyte
|
||||
list_vector = [127, 7]
|
||||
# As INT8, vector has length 2
|
||||
binary_vector = Binary.from_vector(list_vector, BinaryVectorDtype.INT8)
|
||||
vector = binary_vector.as_vector()
|
||||
assert vector.data == list_vector
|
||||
# test encoding roundtrip
|
||||
assert {"vector": binary_vector} == decode(encode({"vector": binary_vector}))
|
||||
# test json roundtrip
|
||||
assert binary_vector == json_util.loads(json_util.dumps(binary_vector))
|
||||
|
||||
# For vectors of bits, aka PACKED_BIT type, vector has length 8 * 2
|
||||
packed_bit_binary = Binary.from_vector(list_vector, BinaryVectorDtype.PACKED_BIT)
|
||||
packed_bit_vec = packed_bit_binary.as_vector()
|
||||
assert packed_bit_vec.data == list_vector
|
||||
|
||||
# A padding parameter permits vectors of length that aren't divisible by 8
|
||||
# The following ignores the last 3 bits in list_vector,
|
||||
# hence it's length is 8 * len(list_vector) - padding
|
||||
padding = 3
|
||||
padded_vec = Binary.from_vector(list_vector, BinaryVectorDtype.PACKED_BIT, padding=padding)
|
||||
assert padded_vec.as_vector().data == list_vector
|
||||
# To visualize how this looks as a binary vector..
|
||||
uncompressed = ""
|
||||
for val in list_vector:
|
||||
uncompressed += format(val, "08b")
|
||||
assert uncompressed[:-padding] == "0111111100000"
|
||||
|
||||
# It is worthwhile explicitly showing the values encoded to BSON
|
||||
padded_doc = {"padded_vec": padded_vec}
|
||||
assert (
|
||||
encode(padded_doc)
|
||||
== b"\x1a\x00\x00\x00\x05padded_vec\x00\x04\x00\x00\x00\t\x10\x03\x7f\x07\x00"
|
||||
)
|
||||
# and dumped to json
|
||||
assert (
|
||||
json_util.dumps(padded_doc)
|
||||
== '{"padded_vec": {"$binary": {"base64": "EAN/Bw==", "subType": "09"}}}'
|
||||
)
|
||||
|
||||
# FLOAT32 is also implemented
|
||||
float_binary = Binary.from_vector(list_vector, BinaryVectorDtype.FLOAT32)
|
||||
assert all(isinstance(d, float) for d in float_binary.as_vector().data)
|
||||
|
||||
# Now some invalid cases
|
||||
for x in [-1, 257]:
|
||||
try:
|
||||
Binary.from_vector([x], BinaryVectorDtype.PACKED_BIT)
|
||||
except Exception as exc:
|
||||
self.assertTrue(isinstance(exc, struct.error))
|
||||
else:
|
||||
self.fail("Failed to raise an exception.")
|
||||
|
||||
def test_unicode_regex(self):
|
||||
"""Tests we do not get a segfault for C extension on unicode RegExs.
|
||||
This had been happening.
|
||||
"""
|
||||
regex = re.compile("revisi\xf3n")
|
||||
decode(encode({"regex": regex}))
|
||||
|
||||
|
||||
105
test/test_bson_binary_vector.py
Normal file
105
test/test_bson_binary_vector.py
Normal file
@ -0,0 +1,105 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import binascii
|
||||
import codecs
|
||||
import json
|
||||
import struct
|
||||
from pathlib import Path
|
||||
from test import unittest
|
||||
|
||||
from bson import decode, encode
|
||||
from bson.binary import Binary, BinaryVectorDtype
|
||||
|
||||
_TEST_PATH = Path(__file__).parent / "bson_binary_vector"
|
||||
|
||||
|
||||
class TestBSONBinaryVector(unittest.TestCase):
|
||||
"""Runs Binary Vector subtype tests.
|
||||
|
||||
Follows the style of the BSON corpus specification tests.
|
||||
Tests are automatically generated on import
|
||||
from json files in _TEST_PATH via `create_tests`.
|
||||
The actual tests are defined in the inner function `run_test`
|
||||
of the test generator `create_test`."""
|
||||
|
||||
|
||||
def create_test(case_spec):
|
||||
"""Create standard test given specification in json.
|
||||
|
||||
We use the naming convention expected (exp) and observed (obj)
|
||||
to differentiate what is in the json (expected or suffix _exp)
|
||||
from what is produced by the API (observed or suffix _obs)
|
||||
"""
|
||||
test_key = case_spec.get("test_key")
|
||||
|
||||
def run_test(self):
|
||||
for test_case in case_spec.get("tests", []):
|
||||
description = test_case["description"]
|
||||
vector_exp = test_case["vector"]
|
||||
dtype_hex_exp = test_case["dtype_hex"]
|
||||
dtype_alias_exp = test_case.get("dtype_alias")
|
||||
padding_exp = test_case.get("padding", 0)
|
||||
canonical_bson_exp = test_case.get("canonical_bson")
|
||||
# Convert dtype hex string into bytes
|
||||
dtype_exp = BinaryVectorDtype(int(dtype_hex_exp, 16).to_bytes(1, byteorder="little"))
|
||||
|
||||
if test_case["valid"]:
|
||||
# Convert bson string to bytes
|
||||
cB_exp = binascii.unhexlify(canonical_bson_exp.encode("utf8"))
|
||||
decoded_doc = decode(cB_exp)
|
||||
binary_obs = decoded_doc[test_key]
|
||||
# Handle special float cases like '-inf'
|
||||
if dtype_exp in [BinaryVectorDtype.FLOAT32]:
|
||||
vector_exp = [float(x) for x in vector_exp]
|
||||
|
||||
# Test round-tripping canonical bson.
|
||||
self.assertEqual(encode(decoded_doc), cB_exp, description)
|
||||
|
||||
# Test BSON to Binary Vector
|
||||
vector_obs = binary_obs.as_vector()
|
||||
self.assertEqual(vector_obs.dtype, dtype_exp, description)
|
||||
if dtype_alias_exp:
|
||||
self.assertEqual(
|
||||
vector_obs.dtype, BinaryVectorDtype[dtype_alias_exp], description
|
||||
)
|
||||
self.assertEqual(vector_obs.data, vector_exp, description)
|
||||
self.assertEqual(vector_obs.padding, padding_exp, description)
|
||||
|
||||
# Test Binary Vector to BSON
|
||||
vector_exp = Binary.from_vector(vector_exp, dtype_exp, padding_exp)
|
||||
cB_obs = binascii.hexlify(encode({test_key: vector_exp})).decode().upper()
|
||||
self.assertEqual(cB_obs, canonical_bson_exp, description)
|
||||
|
||||
else:
|
||||
with self.assertRaises((struct.error, ValueError), msg=description):
|
||||
Binary.from_vector(vector_exp, dtype_exp, padding_exp)
|
||||
|
||||
return run_test
|
||||
|
||||
|
||||
def create_tests():
|
||||
for filename in _TEST_PATH.glob("*.json"):
|
||||
with codecs.open(str(filename), encoding="utf-8") as test_file:
|
||||
test_method = create_test(json.load(test_file))
|
||||
setattr(TestBSONBinaryVector, "test_" + filename.stem, test_method)
|
||||
|
||||
|
||||
create_tests()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -44,8 +44,7 @@ from pymongo.monitoring import (
|
||||
PoolClearedEvent,
|
||||
)
|
||||
|
||||
# Location of JSON test specifications.
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "retryable_reads", "legacy")
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestClientOptions(PyMongoTestCase):
|
||||
@ -83,6 +82,7 @@ class TestPoolPausedError(IntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = False
|
||||
RUN_ON_SERVERLESS = False
|
||||
|
||||
@client_context.require_sync
|
||||
@client_context.require_failCommand_blockConnection
|
||||
@client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05)
|
||||
def test_pool_paused_error_is_retryable(self):
|
||||
@ -94,7 +94,6 @@ class TestPoolPausedError(IntegrationTest):
|
||||
client = self.rs_or_single_client(
|
||||
maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
for _ in range(10):
|
||||
cmap_listener.reset()
|
||||
cmd_listener.reset()
|
||||
@ -165,7 +164,6 @@ class TestRetryableReads(IntegrationTest):
|
||||
for mongos in client_context.mongos_seeds().split(","):
|
||||
client = self.rs_or_single_client(mongos)
|
||||
set_fail_point(client, fail_command)
|
||||
self.addCleanup(client.close)
|
||||
mongos_clients.append(client)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Test retryable writes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import pprint
|
||||
import sys
|
||||
@ -22,7 +23,13 @@ import threading
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, SkipTest, client_context, client_knobs, unittest
|
||||
from test import (
|
||||
IntegrationTest,
|
||||
SkipTest,
|
||||
client_context,
|
||||
unittest,
|
||||
)
|
||||
from test.helpers import client_knobs
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
DeprecationFilter,
|
||||
@ -61,6 +68,8 @@ from pymongo.operations import (
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class InsertEventListener(EventListener):
|
||||
def succeeded(self, event: CommandSucceededEvent) -> None:
|
||||
@ -125,22 +134,22 @@ class IgnoreDeprecationsTest(IntegrationTest):
|
||||
deprecation_filter: DeprecationFilter
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.deprecation_filter = DeprecationFilter()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
def _tearDown_class(cls):
|
||||
cls.deprecation_filter.stop()
|
||||
super().tearDownClass()
|
||||
super()._tearDown_class()
|
||||
|
||||
|
||||
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
|
||||
knobs: client_knobs
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
@ -148,10 +157,10 @@ class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
|
||||
cls.db = cls.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
cls.client.close()
|
||||
super().tearDownClass()
|
||||
super()._tearDown_class()
|
||||
|
||||
@client_context.require_no_standalone
|
||||
def test_actionable_error_message(self):
|
||||
@ -174,8 +183,8 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
|
||||
@classmethod
|
||||
@client_context.require_no_mmap
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
@ -186,10 +195,10 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
cls.db = cls.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
cls.client.close()
|
||||
super().tearDownClass()
|
||||
super()._tearDown_class()
|
||||
|
||||
def setUp(self):
|
||||
if client_context.is_rs and client_context.test_commands_enabled:
|
||||
@ -206,7 +215,6 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
def test_supported_single_statement_no_retry(self):
|
||||
listener = OvertCommandListener()
|
||||
client = self.rs_or_single_client(retryWrites=False, event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test):
|
||||
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
|
||||
listener.reset()
|
||||
@ -319,7 +327,6 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
"""
|
||||
listener = OvertCommandListener()
|
||||
client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
topology = client._topology
|
||||
select_server = topology.select_server
|
||||
|
||||
@ -446,7 +453,6 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
for mongos in client_context.mongos_seeds().split(","):
|
||||
client = self.rs_or_single_client(mongos)
|
||||
set_fail_point(client, fail_command)
|
||||
self.addCleanup(client.close)
|
||||
mongos_clients.append(client)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
@ -478,8 +484,8 @@ class TestWriteConcernError(IntegrationTest):
|
||||
@client_context.require_replica_set
|
||||
@client_context.require_no_mmap
|
||||
@client_context.require_failCommand_fail_point
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.fail_insert = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 2},
|
||||
@ -494,7 +500,6 @@ class TestWriteConcernError(IntegrationTest):
|
||||
def test_RetryableWriteError_error_label(self):
|
||||
listener = OvertCommandListener()
|
||||
client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
|
||||
# Ensure collection exists.
|
||||
client.pymongo_test.testcoll.insert_one({})
|
||||
@ -546,6 +551,7 @@ class TestPoolPausedError(IntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = False
|
||||
RUN_ON_SERVERLESS = False
|
||||
|
||||
@client_context.require_sync
|
||||
@client_context.require_failCommand_blockConnection
|
||||
@client_context.require_retryable_writes
|
||||
@client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05)
|
||||
@ -555,7 +561,6 @@ class TestPoolPausedError(IntegrationTest):
|
||||
client = self.rs_or_single_client(
|
||||
maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
for _ in range(10):
|
||||
cmap_listener.reset()
|
||||
cmd_listener.reset()
|
||||
@ -606,6 +611,7 @@ class TestPoolPausedError(IntegrationTest):
|
||||
failed = cmd_listener.failed_events
|
||||
self.assertEqual(1, len(failed), msg)
|
||||
|
||||
@client_context.require_sync
|
||||
@client_context.require_failCommand_fail_point
|
||||
@client_context.require_replica_set
|
||||
@client_context.require_version_min(
|
||||
@ -618,7 +624,6 @@ class TestPoolPausedError(IntegrationTest):
|
||||
cmd_listener = InsertEventListener()
|
||||
client = self.rs_or_single_client(retryWrites=True, event_listeners=[cmd_listener])
|
||||
client.test.test.drop()
|
||||
self.addCleanup(client.close)
|
||||
cmd_listener.reset()
|
||||
client.admin.command(
|
||||
{
|
||||
@ -654,7 +659,6 @@ class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest):
|
||||
"""
|
||||
listener = OvertCommandListener()
|
||||
client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
topology = client._topology
|
||||
select_server = topology.select_server
|
||||
|
||||
|
||||
@ -1157,3 +1157,9 @@ def set_fail_point(client, command_args):
|
||||
cmd = SON([("configureFailPoint", "failCommand")])
|
||||
cmd.update(command_args)
|
||||
client.admin.command(cmd)
|
||||
|
||||
|
||||
async def async_set_fail_point(client, command_args):
|
||||
cmd = SON([("configureFailPoint", "failCommand")])
|
||||
cmd.update(command_args)
|
||||
await client.admin.command(cmd)
|
||||
|
||||
@ -104,6 +104,7 @@ replacements = {
|
||||
"PyMongo|c|async": "PyMongo|c",
|
||||
"AsyncTestGridFile": "TestGridFile",
|
||||
"AsyncTestGridFileNoConnect": "TestGridFileNoConnect",
|
||||
"async_set_fail_point": "set_fail_point",
|
||||
}
|
||||
|
||||
docstring_replacements: dict[tuple[str, str], str] = {
|
||||
@ -173,6 +174,7 @@ sync_gridfs_files = [
|
||||
converted_tests = [
|
||||
"__init__.py",
|
||||
"conftest.py",
|
||||
"helpers.py",
|
||||
"pymongo_mocks.py",
|
||||
"utils_spec_runner.py",
|
||||
"qcheck.py",
|
||||
@ -191,6 +193,8 @@ converted_tests = [
|
||||
"test_logger.py",
|
||||
"test_monitoring.py",
|
||||
"test_raw_bson.py",
|
||||
"test_retryable_reads.py",
|
||||
"test_retryable_writes.py",
|
||||
"test_session.py",
|
||||
"test_transactions.py",
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user