PYTHON-4790 Migrate test_retryable_writes.py to async (#1876)

This commit is contained in:
Iris 2024-10-01 08:39:57 -07:00 committed by GitHub
parent c0f7810d56
commit 8791aa00ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 1092 additions and 22 deletions

View File

@ -0,0 +1,360 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
from __future__ import annotations
import base64
import gc
import multiprocessing
import os
import signal
import socket
import subprocess
import sys
import threading
import time
import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
try:
import ipaddress
HAVE_IPADDRESS = True
except ImportError:
HAVE_IPADDRESS = False
from functools import wraps
from typing import Any, Callable, Dict, Generator, no_type_check
from unittest import SkipTest
from bson.son import SON
from pymongo import common, message
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
from pymongo.uri_parser import parse_uri
if HAVE_SSL:
import ssl
_IS_SYNC = False
# Enable debug output for uncollectable objects. PyPy does not have set_debug.
if hasattr(gc, "set_debug"):
gc.set_debug(
gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0)
)
# The host and port of a single mongod or mongos, or the seed host
# for a replica set.
host = os.environ.get("DB_IP", "localhost")
port = int(os.environ.get("DB_PORT", 27017))
IS_SRV = "mongodb+srv" in host
db_user = os.environ.get("DB_USER", "user")
db_pwd = os.environ.get("DB_PASSWORD", "password")
CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates")
CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem"))
CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem"))
TLS_OPTIONS: Dict = {"tls": True}
if CLIENT_PEM:
TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM
if CA_PEM:
TLS_OPTIONS["tlsCAFile"] = CA_PEM
COMPRESSORS = os.environ.get("COMPRESSORS")
MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION")
TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER"))
TEST_SERVERLESS = bool(os.environ.get("TEST_SERVERLESS"))
SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI")
MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI")
if TEST_LOADBALANCER:
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
host, port = res["nodelist"][0]
db_user = res["username"] or db_user
db_pwd = res["password"] or db_pwd
elif TEST_SERVERLESS:
TEST_LOADBALANCER = True
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
host, port = res["nodelist"][0]
db_user = res["username"] or db_user
db_pwd = res["password"] or db_pwd
TLS_OPTIONS = {"tls": True}
# Spec says serverless tests must be run with compression.
COMPRESSORS = COMPRESSORS or "zlib"
# Shared KMS data.
LOCAL_MASTER_KEY = base64.b64decode(
b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ"
b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
)
AWS_CREDS = {
"accessKeyId": os.environ.get("FLE_AWS_KEY", ""),
"secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""),
}
AWS_CREDS_2 = {
"accessKeyId": os.environ.get("FLE_AWS_KEY2", ""),
"secretAccessKey": os.environ.get("FLE_AWS_SECRET2", ""),
}
AZURE_CREDS = {
"tenantId": os.environ.get("FLE_AZURE_TENANTID", ""),
"clientId": os.environ.get("FLE_AZURE_CLIENTID", ""),
"clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""),
}
GCP_CREDS = {
"email": os.environ.get("FLE_GCP_EMAIL", ""),
"privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""),
}
KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")}
# Ensure Evergreen metadata doesn't result in truncation
os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000")
def is_server_resolvable():
"""Returns True if 'server' is resolvable."""
socket_timeout = socket.getdefaulttimeout()
socket.setdefaulttimeout(1)
try:
try:
socket.gethostbyname("server")
return True
except OSError:
return False
finally:
socket.setdefaulttimeout(socket_timeout)
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
cmd = SON([("createUser", user)])
# X509 doesn't use a password
if pwd:
cmd["pwd"] = pwd
cmd["roles"] = roles or ["root"]
cmd.update(**kwargs)
return authdb.command(cmd)
class client_knobs:
def __init__(
self,
heartbeat_frequency=None,
min_heartbeat_interval=None,
kill_cursor_frequency=None,
events_queue_frequency=None,
):
self.heartbeat_frequency = heartbeat_frequency
self.min_heartbeat_interval = min_heartbeat_interval
self.kill_cursor_frequency = kill_cursor_frequency
self.events_queue_frequency = events_queue_frequency
self.old_heartbeat_frequency = None
self.old_min_heartbeat_interval = None
self.old_kill_cursor_frequency = None
self.old_events_queue_frequency = None
self._enabled = False
self._stack = None
def enable(self):
self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL
self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY
self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY
if self.heartbeat_frequency is not None:
common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency
if self.min_heartbeat_interval is not None:
common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval
if self.kill_cursor_frequency is not None:
common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency
if self.events_queue_frequency is not None:
common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency
self._enabled = True
# Store the allocation traceback to catch non-disabled client_knobs.
self._stack = "".join(traceback.format_stack())
def __enter__(self):
self.enable()
@no_type_check
def disable(self):
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency
common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency
self._enabled = False
def __exit__(self, exc_type, exc_val, exc_tb):
self.disable()
def __call__(self, func):
def make_wrapper(f):
@wraps(f)
async def wrap(*args, **kwargs):
with self:
return await f(*args, **kwargs)
return wrap
return make_wrapper(func)
def __del__(self):
if self._enabled:
msg = (
"ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, "
"MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, "
"EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format(
common.HEARTBEAT_FREQUENCY,
common.MIN_HEARTBEAT_INTERVAL,
common.KILL_CURSOR_FREQUENCY,
common.EVENTS_QUEUE_FREQUENCY,
self._stack,
)
)
self.disable()
raise Exception(msg)
def _all_users(db):
return {u["user"] for u in db.command("usersInfo").get("users", [])}
def sanitize_cmd(cmd):
cp = cmd.copy()
cp.pop("$clusterTime", None)
cp.pop("$db", None)
cp.pop("$readPreference", None)
cp.pop("lsid", None)
if MONGODB_API_VERSION:
# Stable API parameters
cp.pop("apiVersion", None)
# OP_MSG encoding may move the payload type one field to the
# end of the command. Do the same here.
name = next(iter(cp))
try:
identifier = message._FIELD_MAP[name]
docs = cp.pop(identifier)
cp[identifier] = docs
except KeyError:
pass
return cp
def sanitize_reply(reply):
cp = reply.copy()
cp.pop("$clusterTime", None)
cp.pop("operationTime", None)
return cp
def print_thread_tracebacks() -> None:
"""Print all Python thread tracebacks."""
for thread_id, frame in sys._current_frames().items():
sys.stderr.write(f"\n--- Traceback for thread {thread_id} ---\n")
traceback.print_stack(frame, file=sys.stderr)
def print_thread_stacks(pid: int) -> None:
"""Print all C-level thread stacks for a given process id."""
if sys.platform == "darwin":
cmd = ["lldb", "--attach-pid", f"{pid}", "--batch", "--one-line", '"thread backtrace all"']
else:
cmd = ["gdb", f"--pid={pid}", "--batch", '--eval-command="thread apply all bt"']
try:
res = subprocess.run(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
)
except Exception as exc:
sys.stderr.write(f"Could not print C-level thread stacks because {cmd[0]} failed: {exc}")
else:
sys.stderr.write(res.stdout)
# Global knobs to speed up the test suite.
global_knobs = client_knobs(events_queue_frequency=0.05)
def _get_executors(topology):
executors = []
for server in topology._servers.values():
# Some MockMonitor do not have an _executor.
if hasattr(server._monitor, "_executor"):
executors.append(server._monitor._executor)
if hasattr(server._monitor, "_rtt_monitor"):
executors.append(server._monitor._rtt_monitor._executor)
executors.append(topology._Topology__events_executor)
if topology._srv_monitor:
executors.append(topology._srv_monitor._executor)
return [e for e in executors if e is not None]
def print_running_topology(topology):
running = [e for e in _get_executors(topology) if not e._stopped]
if running:
print(
"WARNING: found Topology with running threads:\n"
f" Threads: {running}\n"
f" Topology: {topology}\n"
f" Creation traceback:\n{topology._settings._stack}"
)
def test_cases(suite):
"""Iterator over all TestCases within a TestSuite."""
for suite_or_case in suite._tests:
if isinstance(suite_or_case, unittest.TestCase):
# unittest.TestCase
yield suite_or_case
else:
# unittest.TestSuite
yield from test_cases(suite_or_case)
# Helper method to workaround https://bugs.python.org/issue21724
def clear_warning_registry():
"""Clear the __warningregistry__ for all modules."""
for _, module in list(sys.modules.items()):
if hasattr(module, "__warningregistry__"):
module.__warningregistry__ = {} # type:ignore[attr-defined]
class SystemCertsPatcher:
def __init__(self, ca_certs):
if (
ssl.OPENSSL_VERSION.lower().startswith("libressl")
and sys.platform == "darwin"
and not _ssl.IS_PYOPENSSL
):
raise SkipTest(
"LibreSSL on OSX doesn't support setting CA certificates "
"using SSL_CERT_FILE environment variable."
)
self.original_certs = os.environ.get("SSL_CERT_FILE")
# Tell OpenSSL where CA certificates live.
os.environ["SSL_CERT_FILE"] = ca_certs
def disable(self):
if self.original_certs is None:
os.environ.pop("SSL_CERT_FILE")
else:
os.environ["SSL_CERT_FILE"] = self.original_certs

View File

@ -0,0 +1,694 @@
# Copyright 2017 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test retryable writes."""
from __future__ import annotations
import asyncio
import copy
import pprint
import sys
import threading
sys.path[0:0] = [""]
from test.asynchronous import (
AsyncIntegrationTest,
SkipTest,
async_client_context,
unittest,
)
from test.asynchronous.helpers import client_knobs
from test.utils import (
CMAPListener,
DeprecationFilter,
EventListener,
OvertCommandListener,
async_set_fail_point,
)
from test.version import Version
from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.int64 import Int64
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.errors import (
AutoReconnect,
ConnectionFailure,
OperationFailure,
ServerSelectionTimeoutError,
WriteConcernError,
)
from pymongo.monitoring import (
CommandSucceededEvent,
ConnectionCheckedOutEvent,
ConnectionCheckOutFailedEvent,
ConnectionCheckOutFailedReason,
PoolClearedEvent,
)
from pymongo.operations import (
DeleteMany,
DeleteOne,
InsertOne,
ReplaceOne,
UpdateMany,
UpdateOne,
)
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
class InsertEventListener(EventListener):
def succeeded(self, event: CommandSucceededEvent) -> None:
super().succeeded(event)
if (
event.command_name == "insert"
and event.reply.get("writeConcernError", {}).get("code", None) == 91
):
async_client_context.client.admin.command(
{
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"errorCode": 10107,
"errorLabels": ["RetryableWriteError", "NoWritesPerformed"],
"failCommands": ["insert"],
},
}
)
def retryable_single_statement_ops(coll):
return [
(coll.bulk_write, [[InsertOne({}), InsertOne({})]], {}),
(coll.bulk_write, [[InsertOne({}), InsertOne({})]], {"ordered": False}),
(coll.bulk_write, [[ReplaceOne({}, {"a1": 1})]], {}),
(coll.bulk_write, [[ReplaceOne({}, {"a2": 1}), ReplaceOne({}, {"a3": 1})]], {}),
(
coll.bulk_write,
[[UpdateOne({}, {"$set": {"a4": 1}}), UpdateOne({}, {"$set": {"a5": 1}})]],
{},
),
(coll.bulk_write, [[DeleteOne({})]], {}),
(coll.bulk_write, [[DeleteOne({}), DeleteOne({})]], {}),
(coll.insert_one, [{}], {}),
(coll.insert_many, [[{}, {}]], {}),
(coll.replace_one, [{}, {"a6": 1}], {}),
(coll.update_one, [{}, {"$set": {"a7": 1}}], {}),
(coll.delete_one, [{}], {}),
(coll.find_one_and_replace, [{}, {"a8": 1}], {}),
(coll.find_one_and_update, [{}, {"$set": {"a9": 1}}], {}),
(coll.find_one_and_delete, [{}, {"a10": 1}], {}),
]
def non_retryable_single_statement_ops(coll):
return [
(
coll.bulk_write,
[[UpdateOne({}, {"$set": {"a": 1}}), UpdateMany({}, {"$set": {"a": 1}})]],
{},
),
(coll.bulk_write, [[DeleteOne({}), DeleteMany({})]], {}),
(coll.update_many, [{}, {"$set": {"a": 1}}], {}),
(coll.delete_many, [{}], {}),
]
class IgnoreDeprecationsTest(AsyncIntegrationTest):
RUN_ON_LOAD_BALANCER = True
RUN_ON_SERVERLESS = True
deprecation_filter: DeprecationFilter
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.deprecation_filter = DeprecationFilter()
@classmethod
async def _tearDown_class(cls):
cls.deprecation_filter.stop()
await super()._tearDown_class()
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
knobs: client_knobs
@classmethod
async def _setup_class(cls):
await super()._setup_class()
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
cls.client = await cls.unmanaged_async_rs_or_single_client(retryWrites=True)
cls.db = cls.client.pymongo_test
@classmethod
async def _tearDown_class(cls):
cls.knobs.disable()
await cls.client.close()
await super()._tearDown_class()
@async_client_context.require_no_standalone
async def test_actionable_error_message(self):
if async_client_context.storage_engine != "mmapv1":
raise SkipTest("This cluster is not running MMAPv1")
expected_msg = (
"This MongoDB deployment does not support retryable "
"writes. Please add retryWrites=false to your "
"connection string."
)
for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test):
with self.assertRaisesRegex(OperationFailure, expected_msg):
await method(*args, **kwargs)
class TestRetryableWrites(IgnoreDeprecationsTest):
listener: OvertCommandListener
knobs: client_knobs
@classmethod
@async_client_context.require_no_mmap
async def _setup_class(cls):
await super()._setup_class()
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
cls.listener = OvertCommandListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(
retryWrites=True, event_listeners=[cls.listener]
)
cls.db = cls.client.pymongo_test
@classmethod
async def _tearDown_class(cls):
cls.knobs.disable()
await cls.client.close()
await super()._tearDown_class()
async def asyncSetUp(self):
if async_client_context.is_rs and async_client_context.test_commands_enabled:
await self.client.admin.command(
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
)
async def asyncTearDown(self):
if async_client_context.is_rs and async_client_context.test_commands_enabled:
await self.client.admin.command(
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
)
async def test_supported_single_statement_no_retry(self):
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(retryWrites=False, event_listeners=[listener])
for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test):
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
listener.reset()
await method(*args, **kwargs)
for event in listener.started_events:
self.assertNotIn(
"txnNumber",
event.command,
f"{msg} sent txnNumber with {event.command_name}",
)
@async_client_context.require_no_standalone
async def test_supported_single_statement_supported_cluster(self):
for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test):
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
self.listener.reset()
await method(*args, **kwargs)
commands_started = self.listener.started_events
self.assertEqual(len(self.listener.succeeded_events), 1, msg)
first_attempt = commands_started[0]
self.assertIn(
"lsid",
first_attempt.command,
f"{msg} sent no lsid with {first_attempt.command_name}",
)
initial_session_id = first_attempt.command["lsid"]
self.assertIn(
"txnNumber",
first_attempt.command,
f"{msg} sent no txnNumber with {first_attempt.command_name}",
)
# There should be no retry when the failpoint is not active.
if async_client_context.is_mongos or not async_client_context.test_commands_enabled:
self.assertEqual(len(commands_started), 1)
continue
initial_transaction_id = first_attempt.command["txnNumber"]
retry_attempt = commands_started[1]
self.assertIn(
"lsid",
retry_attempt.command,
f"{msg} sent no lsid with {first_attempt.command_name}",
)
self.assertEqual(retry_attempt.command["lsid"], initial_session_id, msg)
self.assertIn(
"txnNumber",
retry_attempt.command,
f"{msg} sent no txnNumber with {first_attempt.command_name}",
)
self.assertEqual(retry_attempt.command["txnNumber"], initial_transaction_id, msg)
async def test_supported_single_statement_unsupported_cluster(self):
if async_client_context.is_rs or async_client_context.is_mongos:
raise SkipTest("This cluster supports retryable writes")
for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test):
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
self.listener.reset()
await method(*args, **kwargs)
for event in self.listener.started_events:
self.assertNotIn(
"txnNumber",
event.command,
f"{msg} sent txnNumber with {event.command_name}",
)
async def test_unsupported_single_statement(self):
coll = self.db.retryable_write_test
await coll.insert_many([{}, {}])
coll_w0 = coll.with_options(write_concern=WriteConcern(w=0))
for method, args, kwargs in non_retryable_single_statement_ops(
coll
) + retryable_single_statement_ops(coll_w0):
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
self.listener.reset()
await method(*args, **kwargs)
started_events = self.listener.started_events
self.assertEqual(len(self.listener.succeeded_events), len(started_events), msg)
self.assertEqual(len(self.listener.failed_events), 0, msg)
for event in started_events:
self.assertNotIn(
"txnNumber",
event.command,
f"{msg} sent txnNumber with {event.command_name}",
)
async def test_server_selection_timeout_not_retried(self):
"""A ServerSelectionTimeoutError is not retried."""
listener = OvertCommandListener()
client = self.simple_client(
"somedomainthatdoesntexist.org",
serverSelectionTimeoutMS=1,
retryWrites=True,
event_listeners=[listener],
)
for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test):
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
listener.reset()
with self.assertRaises(ServerSelectionTimeoutError, msg=msg):
await method(*args, **kwargs)
self.assertEqual(len(listener.started_events), 0, msg)
@async_client_context.require_replica_set
@async_client_context.require_test_commands
async def test_retry_timeout_raises_original_error(self):
"""A ServerSelectionTimeoutError on the retry attempt raises the
original error.
"""
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(retryWrites=True, event_listeners=[listener])
topology = client._topology
select_server = topology.select_server
def mock_select_server(*args, **kwargs):
server = select_server(*args, **kwargs)
def raise_error(*args, **kwargs):
raise ServerSelectionTimeoutError("No primary available for writes")
# Raise ServerSelectionTimeout on the retry attempt.
topology.select_server = raise_error
return server
for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test):
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
listener.reset()
topology.select_server = mock_select_server
with self.assertRaises(ConnectionFailure, msg=msg):
await method(*args, **kwargs)
self.assertEqual(len(listener.started_events), 1, msg)
@async_client_context.require_replica_set
@async_client_context.require_test_commands
async def test_batch_splitting(self):
"""Test retry succeeds after failures during batch splitting."""
large = "s" * 1024 * 1024 * 15
coll = self.db.retryable_write_test
await coll.delete_many({})
self.listener.reset()
bulk_result = await coll.bulk_write(
[
InsertOne({"_id": 1, "l": large}),
InsertOne({"_id": 2, "l": large}),
InsertOne({"_id": 3, "l": large}),
UpdateOne({"_id": 1, "l": large}, {"$unset": {"l": 1}, "$inc": {"count": 1}}),
UpdateOne({"_id": 2, "l": large}, {"$set": {"foo": "bar"}}),
DeleteOne({"l": large}),
DeleteOne({"l": large}),
]
)
# Each command should fail and be retried.
# With OP_MSG 3 inserts are one batch. 2 updates another.
# 2 deletes a third.
self.assertEqual(len(self.listener.started_events), 6)
self.assertEqual(await coll.find_one(), {"_id": 1, "count": 1})
# Assert the final result
expected_result = {
"writeErrors": [],
"writeConcernErrors": [],
"nInserted": 3,
"nUpserted": 0,
"nMatched": 2,
"nModified": 2,
"nRemoved": 2,
"upserted": [],
}
self.assertEqual(bulk_result.bulk_api_result, expected_result)
@async_client_context.require_replica_set
@async_client_context.require_test_commands
async def test_batch_splitting_retry_fails(self):
"""Test retry fails during batch splitting."""
large = "s" * 1024 * 1024 * 15
coll = self.db.retryable_write_test
await coll.delete_many({})
await self.client.admin.command(
SON(
[
("configureFailPoint", "onPrimaryTransactionalWrite"),
("mode", {"skip": 3}), # The number of _documents_ to skip.
("data", {"failBeforeCommitExceptionCode": 1}),
]
)
)
self.listener.reset()
async with self.client.start_session() as session:
initial_txn = session._transaction_id
try:
await coll.bulk_write(
[
InsertOne({"_id": 1, "l": large}),
InsertOne({"_id": 2, "l": large}),
InsertOne({"_id": 3, "l": large}),
InsertOne({"_id": 4, "l": large}),
],
session=session,
)
except ConnectionFailure:
pass
else:
self.fail("bulk_write should have failed")
started = self.listener.started_events
self.assertEqual(len(started), 3)
self.assertEqual(len(self.listener.succeeded_events), 1)
expected_txn = Int64(initial_txn + 1)
self.assertEqual(started[0].command["txnNumber"], expected_txn)
self.assertEqual(started[0].command["lsid"], session.session_id)
expected_txn = Int64(initial_txn + 2)
self.assertEqual(started[1].command["txnNumber"], expected_txn)
self.assertEqual(started[1].command["lsid"], session.session_id)
started[1].command.pop("$clusterTime")
started[2].command.pop("$clusterTime")
self.assertEqual(started[1].command, started[2].command)
final_txn = session._transaction_id
self.assertEqual(final_txn, expected_txn)
self.assertEqual(await coll.find_one(projection={"_id": True}), {"_id": 1})
@async_client_context.require_multiple_mongoses
@async_client_context.require_failCommand_fail_point
async def test_retryable_writes_in_sharded_cluster_multiple_available(self):
fail_command = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"closeConnection": True,
"appName": "retryableWriteTest",
},
}
mongos_clients = []
for mongos in async_client_context.mongos_seeds().split(","):
client = await self.async_rs_or_single_client(mongos)
await async_set_fail_point(client, fail_command)
mongos_clients.append(client)
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(
async_client_context.mongos_seeds(),
appName="retryableWriteTest",
event_listeners=[listener],
retryWrites=True,
)
with self.assertRaises(AutoReconnect):
await client.t.t.insert_one({"x": 1})
# Disable failpoints on each mongos
for client in mongos_clients:
fail_command["mode"] = "off"
await async_set_fail_point(client, fail_command)
self.assertEqual(len(listener.failed_events), 2)
self.assertEqual(len(listener.succeeded_events), 0)
class TestWriteConcernError(AsyncIntegrationTest):
RUN_ON_LOAD_BALANCER = True
RUN_ON_SERVERLESS = True
fail_insert: dict
@classmethod
@async_client_context.require_replica_set
@async_client_context.require_no_mmap
@async_client_context.require_failCommand_fail_point
async def _setup_class(cls):
await super()._setup_class()
cls.fail_insert = {
"configureFailPoint": "failCommand",
"mode": {"times": 2},
"data": {
"failCommands": ["insert"],
"writeConcernError": {"code": 91, "errmsg": "Replication is being shut down"},
},
}
@async_client_context.require_version_min(4, 0)
@client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05)
async def test_RetryableWriteError_error_label(self):
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(retryWrites=True, event_listeners=[listener])
# Ensure collection exists.
await client.pymongo_test.testcoll.insert_one({})
async with self.fail_point(self.fail_insert):
with self.assertRaises(WriteConcernError) as cm:
await client.pymongo_test.testcoll.insert_one({})
self.assertTrue(cm.exception.has_error_label("RetryableWriteError"))
if async_client_context.version >= Version(4, 4):
# In MongoDB 4.4+ we rely on the server returning the error label.
self.assertIn("RetryableWriteError", listener.succeeded_events[-1].reply["errorLabels"])
@async_client_context.require_version_min(4, 4)
async def test_RetryableWriteError_error_label_RawBSONDocument(self):
# using RawBSONDocument should not cause errorLabel parsing to fail
async with self.fail_point(self.fail_insert):
async with self.client.start_session() as s:
s._start_retryable_write()
result = await self.client.pymongo_test.command(
"insert",
"testcoll",
documents=[{"_id": 1}],
txnNumber=s._transaction_id,
session=s,
codec_options=DEFAULT_CODEC_OPTIONS.with_options(
document_class=RawBSONDocument
),
)
self.assertIn("writeConcernError", result)
self.assertIn("RetryableWriteError", result["errorLabels"])
class InsertThread(threading.Thread):
def __init__(self, collection):
super().__init__()
self.daemon = True
self.collection = collection
self.passed = False
async def run(self):
await self.collection.insert_one({})
self.passed = True
class TestPoolPausedError(AsyncIntegrationTest):
# Pools don't get paused in load balanced mode.
RUN_ON_LOAD_BALANCER = False
RUN_ON_SERVERLESS = False
@async_client_context.require_sync
@async_client_context.require_failCommand_blockConnection
@async_client_context.require_retryable_writes
@client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05)
async def test_pool_paused_error_is_retryable(self):
cmap_listener = CMAPListener()
cmd_listener = OvertCommandListener()
client = await self.async_rs_or_single_client(
maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]
)
for _ in range(10):
cmap_listener.reset()
cmd_listener.reset()
threads = [InsertThread(client.pymongo_test.test) for _ in range(2)]
fail_command = {
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"blockConnection": True,
"blockTimeMS": 1000,
"errorCode": 91,
"errorLabels": ["RetryableWriteError"],
},
}
async with self.fail_point(fail_command):
for thread in threads:
thread.start()
for thread in threads:
thread.join()
for thread in threads:
self.assertTrue(thread.passed)
# It's possible that SDAM can rediscover the server and mark the
# pool ready before the thread in the wait queue has a chance
# to run. Repeat the test until the thread actually encounters
# a PoolClearedError.
if cmap_listener.event_count(ConnectionCheckOutFailedEvent):
break
# Via CMAP monitoring, assert that the first check out succeeds.
cmap_events = cmap_listener.events_by_type(
(ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent, PoolClearedEvent)
)
msg = pprint.pformat(cmap_listener.events)
self.assertIsInstance(cmap_events[0], ConnectionCheckedOutEvent, msg)
self.assertIsInstance(cmap_events[1], PoolClearedEvent, msg)
self.assertIsInstance(cmap_events[2], ConnectionCheckOutFailedEvent, msg)
self.assertEqual(cmap_events[2].reason, ConnectionCheckOutFailedReason.CONN_ERROR, msg)
self.assertIsInstance(cmap_events[3], ConnectionCheckedOutEvent, msg)
# Connection check out failures are not reflected in command
# monitoring because we only publish command events _after_ checking
# out a connection.
started = cmd_listener.started_events
msg = pprint.pformat(cmd_listener.results)
self.assertEqual(3, len(started), msg)
succeeded = cmd_listener.succeeded_events
self.assertEqual(2, len(succeeded), msg)
failed = cmd_listener.failed_events
self.assertEqual(1, len(failed), msg)
@async_client_context.require_sync
@async_client_context.require_failCommand_fail_point
@async_client_context.require_replica_set
@async_client_context.require_version_min(
6, 0, 0
) # the spec requires that this prose test only be run on 6.0+
@client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05)
async def test_returns_original_error_code(
self,
):
cmd_listener = InsertEventListener()
client = await self.async_rs_or_single_client(
retryWrites=True, event_listeners=[cmd_listener]
)
await client.test.test.drop()
cmd_listener.reset()
await client.admin.command(
{
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"writeConcernError": {
"code": 91,
"errorLabels": ["RetryableWriteError"],
},
"failCommands": ["insert"],
},
}
)
with self.assertRaises(WriteConcernError) as exc:
await client.test.test.insert_one({"_id": 1})
self.assertEqual(exc.exception.code, 91)
await client.admin.command(
{
"configureFailPoint": "failCommand",
"mode": "off",
}
)
# TODO: Make this a real integration test where we stepdown the primary.
class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest):
@async_client_context.require_replica_set
@async_client_context.require_no_mmap
async def test_increment_transaction_id_without_sending_command(self):
"""Test that the txnNumber field is properly incremented, even when
the first attempt fails before sending the command.
"""
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(retryWrites=True, event_listeners=[listener])
topology = client._topology
select_server = topology.select_server
def raise_connection_err_select_server(*args, **kwargs):
# Raise ConnectionFailure on the first attempt and perform
# normal selection on the retry attempt.
topology.select_server = select_server
raise ConnectionFailure("Connection refused")
for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test):
listener.reset()
topology.select_server = raise_connection_err_select_server
async with client.start_session() as session:
kwargs = copy.deepcopy(kwargs)
kwargs["session"] = session
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
initial_txn_id = session._transaction_id
# Each operation should fail on the first attempt and succeed
# on the second.
await method(*args, **kwargs)
self.assertEqual(len(listener.started_events), 1, msg)
retry_cmd = listener.started_events[0].command
sent_txn_id = retry_cmd["txnNumber"]
final_txn_id = session._transaction_id
self.assertEqual(Int64(initial_txn_id + 1), sent_txn_id, msg)
self.assertEqual(sent_txn_id, final_txn_id, msg)
if __name__ == "__main__":
unittest.main()

View File

@ -28,6 +28,7 @@ import time
import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
try:
import ipaddress
@ -47,6 +48,8 @@ from pymongo.uri_parser import parse_uri
if HAVE_SSL:
import ssl
_IS_SYNC = True
# Enable debug output for uncollectable objects. PyPy does not have set_debug.
if hasattr(gc, "set_debug"):
gc.set_debug(

View File

@ -15,6 +15,7 @@
"""Test retryable writes."""
from __future__ import annotations
import asyncio
import copy
import pprint
import sys
@ -22,7 +23,13 @@ import threading
sys.path[0:0] = [""]
from test import IntegrationTest, SkipTest, client_context, client_knobs, unittest
from test import (
IntegrationTest,
SkipTest,
client_context,
unittest,
)
from test.helpers import client_knobs
from test.utils import (
CMAPListener,
DeprecationFilter,
@ -61,6 +68,8 @@ from pymongo.operations import (
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.write_concern import WriteConcern
_IS_SYNC = True
class InsertEventListener(EventListener):
def succeeded(self, event: CommandSucceededEvent) -> None:
@ -125,22 +134,22 @@ class IgnoreDeprecationsTest(IntegrationTest):
deprecation_filter: DeprecationFilter
@classmethod
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
cls.deprecation_filter = DeprecationFilter()
@classmethod
def tearDownClass(cls):
def _tearDown_class(cls):
cls.deprecation_filter.stop()
super().tearDownClass()
super()._tearDown_class()
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
knobs: client_knobs
@classmethod
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
@ -148,10 +157,10 @@ class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
cls.db = cls.client.pymongo_test
@classmethod
def tearDownClass(cls):
def _tearDown_class(cls):
cls.knobs.disable()
cls.client.close()
super().tearDownClass()
super()._tearDown_class()
@client_context.require_no_standalone
def test_actionable_error_message(self):
@ -174,8 +183,8 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
@classmethod
@client_context.require_no_mmap
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
@ -186,10 +195,10 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
cls.db = cls.client.pymongo_test
@classmethod
def tearDownClass(cls):
def _tearDown_class(cls):
cls.knobs.disable()
cls.client.close()
super().tearDownClass()
super()._tearDown_class()
def setUp(self):
if client_context.is_rs and client_context.test_commands_enabled:
@ -206,7 +215,6 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
def test_supported_single_statement_no_retry(self):
listener = OvertCommandListener()
client = self.rs_or_single_client(retryWrites=False, event_listeners=[listener])
self.addCleanup(client.close)
for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test):
msg = f"{method.__name__}(*{args!r}, **{kwargs!r})"
listener.reset()
@ -319,7 +327,6 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
"""
listener = OvertCommandListener()
client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener])
self.addCleanup(client.close)
topology = client._topology
select_server = topology.select_server
@ -446,7 +453,6 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
for mongos in client_context.mongos_seeds().split(","):
client = self.rs_or_single_client(mongos)
set_fail_point(client, fail_command)
self.addCleanup(client.close)
mongos_clients.append(client)
listener = OvertCommandListener()
@ -478,8 +484,8 @@ class TestWriteConcernError(IntegrationTest):
@client_context.require_replica_set
@client_context.require_no_mmap
@client_context.require_failCommand_fail_point
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
cls.fail_insert = {
"configureFailPoint": "failCommand",
"mode": {"times": 2},
@ -494,7 +500,6 @@ class TestWriteConcernError(IntegrationTest):
def test_RetryableWriteError_error_label(self):
listener = OvertCommandListener()
client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener])
self.addCleanup(client.close)
# Ensure collection exists.
client.pymongo_test.testcoll.insert_one({})
@ -546,6 +551,7 @@ class TestPoolPausedError(IntegrationTest):
RUN_ON_LOAD_BALANCER = False
RUN_ON_SERVERLESS = False
@client_context.require_sync
@client_context.require_failCommand_blockConnection
@client_context.require_retryable_writes
@client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05)
@ -555,7 +561,6 @@ class TestPoolPausedError(IntegrationTest):
client = self.rs_or_single_client(
maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]
)
self.addCleanup(client.close)
for _ in range(10):
cmap_listener.reset()
cmd_listener.reset()
@ -606,6 +611,7 @@ class TestPoolPausedError(IntegrationTest):
failed = cmd_listener.failed_events
self.assertEqual(1, len(failed), msg)
@client_context.require_sync
@client_context.require_failCommand_fail_point
@client_context.require_replica_set
@client_context.require_version_min(
@ -618,7 +624,6 @@ class TestPoolPausedError(IntegrationTest):
cmd_listener = InsertEventListener()
client = self.rs_or_single_client(retryWrites=True, event_listeners=[cmd_listener])
client.test.test.drop()
self.addCleanup(client.close)
cmd_listener.reset()
client.admin.command(
{
@ -654,7 +659,6 @@ class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest):
"""
listener = OvertCommandListener()
client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener])
self.addCleanup(client.close)
topology = client._topology
select_server = topology.select_server

View File

@ -1157,3 +1157,9 @@ def set_fail_point(client, command_args):
cmd = SON([("configureFailPoint", "failCommand")])
cmd.update(command_args)
client.admin.command(cmd)
async def async_set_fail_point(client, command_args):
cmd = SON([("configureFailPoint", "failCommand")])
cmd.update(command_args)
await client.admin.command(cmd)

View File

@ -104,6 +104,7 @@ replacements = {
"PyMongo|c|async": "PyMongo|c",
"AsyncTestGridFile": "TestGridFile",
"AsyncTestGridFileNoConnect": "TestGridFileNoConnect",
"async_set_fail_point": "set_fail_point",
}
docstring_replacements: dict[tuple[str, str], str] = {
@ -173,6 +174,7 @@ sync_gridfs_files = [
converted_tests = [
"__init__.py",
"conftest.py",
"helpers.py",
"pymongo_mocks.py",
"utils_spec_runner.py",
"qcheck.py",
@ -191,6 +193,7 @@ converted_tests = [
"test_logger.py",
"test_monitoring.py",
"test_raw_bson.py",
"test_retryable_writes.py",
"test_session.py",
"test_transactions.py",
]