PYTHON-5668 - Merge backpressure branch into mainline (#2729)

Co-authored-by: Steven Silvester <steve.silvester@mongodb.com>
Co-authored-by: Shane Harvey <shnhrv@gmail.com>
Co-authored-by: Steven Silvester <steven.silvester@ieee.org>
Co-authored-by: Iris <58442094+sleepyStick@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Kevin Albertson <kevin.albertson@mongodb.com>
Co-authored-by: Casey Clements <caseyclements@users.noreply.github.com>
Co-authored-by: Sergey Zelenov <mail@zelenov.su>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Noah Stapp 2026-04-14 12:25:29 -04:00 committed by GitHub
parent ee20ef52ec
commit e1751ff253
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
54 changed files with 12043 additions and 1806 deletions

View File

@ -94,6 +94,9 @@ do
change-streams|change_streams) change-streams|change_streams)
cpjson change-streams/tests/ change_streams/ cpjson change-streams/tests/ change_streams/
;; ;;
client-backpressure|client_backpressure)
cpjson client-backpressure/tests client-backpressure
;;
client-side-encryption|csfle|fle) client-side-encryption|csfle|fle)
cpjson client-side-encryption/tests/ client-side-encryption/spec cpjson client-side-encryption/tests/ client-side-encryption/spec
cpjson client-side-encryption/corpus/ client-side-encryption/corpus cpjson client-side-encryption/corpus/ client-side-encryption/corpus

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,9 @@ PyMongo 4.17 brings a number of changes including:
- Added the :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.bind` and :meth:`~pymongo.client_session.ClientSession.bind` methods - Added the :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.bind` and :meth:`~pymongo.client_session.ClientSession.bind` methods
that allow users to bind a session to all database operations within the scope of a context manager instead of having to explicitly pass the session to each individual operation. that allow users to bind a session to all database operations within the scope of a context manager instead of having to explicitly pass the session to each individual operation.
See <PLACEHOLDER> for examples and more information. See <PLACEHOLDER> for examples and more information.
- Added support for MongoDB's Intelligent Workload Management (IWM) and ingress connection rate limiting features.
The driver now gracefully handles write-blocking scenarios and optimizes connection establishment during high-load conditions to maintain application availability.
See <DOCSP-55426> and <DOCSP-57078> for more information.
Changes in Version 4.16.0 (2026/01/07) Changes in Version 4.16.0 (2026/01/07)
-------------------------------------- --------------------------------------

View File

@ -59,6 +59,7 @@ from pymongo.errors import (
InvalidOperation, InvalidOperation,
NotPrimaryError, NotPrimaryError,
OperationFailure, OperationFailure,
PyMongoError,
WaitQueueTimeoutError, WaitQueueTimeoutError,
) )
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
@ -563,9 +564,17 @@ class _AsyncClientBulk:
error, ConnectionFailure error, ConnectionFailure
) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError)) ) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError))
retryable_label_error = isinstance(
error, PyMongoError
) and error.has_error_label("RetryableError")
# Synthesize the full bulk result without modifying the # Synthesize the full bulk result without modifying the
# current one because this write operation may be retried. # current one because this write operation may be retried.
if retryable and (retryable_top_level_error or retryable_network_error): if retryable and (
retryable_top_level_error
or retryable_network_error
or retryable_label_error
):
full = copy.deepcopy(full_result) full = copy.deepcopy(full_result)
_merge_command(self.ops, self.idx_offset, full, result) _merge_command(self.ops, self.idx_offset, full, result)
_throw_client_bulk_write_exception(full, self.verbose_results) _throw_client_bulk_write_exception(full, self.verbose_results)

View File

@ -135,7 +135,9 @@ Classes
from __future__ import annotations from __future__ import annotations
import asyncio
import collections import collections
import random
import time import time
import uuid import uuid
from collections.abc import Mapping as _Mapping from collections.abc import Mapping as _Mapping
@ -162,7 +164,9 @@ from pymongo.asynchronous.cursor_base import _ConnectionManager
from pymongo.errors import ( from pymongo.errors import (
ConfigurationError, ConfigurationError,
ConnectionFailure, ConnectionFailure,
ExecutionTimeout,
InvalidOperation, InvalidOperation,
NetworkTimeout,
OperationFailure, OperationFailure,
PyMongoError, PyMongoError,
WTimeoutError, WTimeoutError,
@ -427,6 +431,7 @@ class _Transaction:
self.recovery_token = None self.recovery_token = None
self.attempt = 0 self.attempt = 0
self.client = client self.client = client
self.has_completed_command = False
def active(self) -> bool: def active(self) -> bool:
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS) return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
@ -434,6 +439,9 @@ class _Transaction:
def starting(self) -> bool: def starting(self) -> bool:
return self.state == _TxnState.STARTING return self.state == _TxnState.STARTING
def set_starting(self) -> None:
self.state = _TxnState.STARTING
@property @property
def pinned_conn(self) -> Optional[AsyncConnection]: def pinned_conn(self) -> Optional[AsyncConnection]:
if self.active() and self.conn_mgr: if self.active() and self.conn_mgr:
@ -459,6 +467,7 @@ class _Transaction:
self.sharded = False self.sharded = False
self.recovery_token = None self.recovery_token = None
self.attempt = 0 self.attempt = 0
self.has_completed_command = False
def __del__(self) -> None: def __del__(self) -> None:
if self.conn_mgr: if self.conn_mgr:
@ -493,11 +502,29 @@ _UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( #
# This limit is non-configurable and was chosen to be twice the 60 second # This limit is non-configurable and was chosen to be twice the 60 second
# default value of MongoDB's `transactionLifetimeLimitSeconds` parameter. # default value of MongoDB's `transactionLifetimeLimitSeconds` parameter.
_WITH_TRANSACTION_RETRY_TIME_LIMIT = 120 _WITH_TRANSACTION_RETRY_TIME_LIMIT = 120
_BACKOFF_MAX = 0.500 # 500ms max backoff
_BACKOFF_INITIAL = 0.005 # 5ms initial backoff
def _within_time_limit(start_time: float) -> bool: def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
"""Are we within the with_transaction retry limit?""" """Are we within the with_transaction retry limit?"""
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT remaining = _csot.remaining()
if remaining is not None and remaining <= 0:
return False
return time.monotonic() + backoff - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
def _make_timeout_error(error: BaseException) -> PyMongoError:
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
if _csot.remaining() is not None:
timeout_error: PyMongoError = ExecutionTimeout(
str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}
)
else:
timeout_error = NetworkTimeout(str(error))
if isinstance(error, PyMongoError):
timeout_error._error_labels = error._error_labels.copy()
return timeout_error
_T = TypeVar("_T") _T = TypeVar("_T")
@ -744,7 +771,17 @@ class AsyncClientSession:
https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback
""" """
start_time = time.monotonic() start_time = time.monotonic()
retry = 0
last_error: Optional[BaseException] = None
while True: while True:
if retry: # Implement exponential backoff on retry.
jitter = random.random() # noqa: S311
backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX)
if not _within_time_limit(start_time, backoff):
assert last_error is not None
raise _make_timeout_error(last_error) from last_error
await asyncio.sleep(backoff)
retry += 1
await self.start_transaction( await self.start_transaction(
read_concern, write_concern, read_preference, max_commit_time_ms read_concern, write_concern, read_preference, max_commit_time_ms
) )
@ -752,15 +789,16 @@ class AsyncClientSession:
ret = await callback(self) ret = await callback(self)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup. # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as exc: except BaseException as exc:
last_error = exc
if self.in_transaction: if self.in_transaction:
await self.abort_transaction() await self.abort_transaction()
if ( if isinstance(exc, PyMongoError) and exc.has_error_label(
isinstance(exc, PyMongoError) "TransientTransactionError"
and exc.has_error_label("TransientTransactionError")
and _within_time_limit(start_time)
): ):
# Retry the entire transaction. if _within_time_limit(start_time):
continue # Retry the entire transaction.
continue
raise _make_timeout_error(last_error) from exc
raise raise
if not self.in_transaction: if not self.in_transaction:
@ -771,17 +809,18 @@ class AsyncClientSession:
try: try:
await self.commit_transaction() await self.commit_transaction()
except PyMongoError as exc: except PyMongoError as exc:
if ( last_error = exc
exc.has_error_label("UnknownTransactionCommitResult") if exc.has_error_label(
and _within_time_limit(start_time) "UnknownTransactionCommitResult"
and not _max_time_expired_error(exc) ) and not _max_time_expired_error(exc):
): if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the commit. # Retry the commit.
continue continue
if exc.has_error_label("TransientTransactionError") and _within_time_limit( if exc.has_error_label("TransientTransactionError"):
start_time if not _within_time_limit(start_time):
): raise _make_timeout_error(last_error) from exc
# Retry the entire transaction. # Retry the entire transaction.
break break
raise raise

View File

@ -20,7 +20,6 @@ from collections import abc
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncContextManager,
Callable, Callable,
Coroutine, Coroutine,
Generic, Generic,
@ -571,11 +570,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
await change_stream._initialize_cursor() await change_stream._initialize_cursor()
return change_stream return change_stream
async def _conn_for_writes(
self, session: Optional[AsyncClientSession], operation: str
) -> AsyncContextManager[AsyncConnection]:
return await self._database.client._conn_for_writes(session, operation)
async def _command( async def _command(
self, self,
conn: AsyncConnection, conn: AsyncConnection,
@ -652,7 +646,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
if "size" in options: if "size" in options:
options["size"] = float(options["size"]) options["size"] = float(options["size"])
cmd.update(options) cmd.update(options)
async with await self._conn_for_writes(session, operation=_Op.CREATE) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
if qev2_required and conn.max_wire_version < 21: if qev2_required and conn.max_wire_version < 21:
raise ConfigurationError( raise ConfigurationError(
"Driver support of Queryable Encryption is incompatible with server. " "Driver support of Queryable Encryption is incompatible with server. "
@ -669,6 +666,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
session=session, session=session,
) )
await self.database.client._retryable_write(False, inner, session, _Op.CREATE)
async def _create( async def _create(
self, self,
options: MutableMapping[str, Any], options: MutableMapping[str, Any],
@ -2240,7 +2239,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
command (like maxTimeMS) can be passed as keyword arguments. command (like maxTimeMS) can be passed as keyword arguments.
""" """
names = [] names = []
async with await self._conn_for_writes(session, operation=_Op.CREATE_INDEXES) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> list[str]:
supports_quorum = conn.max_wire_version >= 9 supports_quorum = conn.max_wire_version >= 9
def gen_indexes() -> Iterator[Mapping[str, Any]]: def gen_indexes() -> Iterator[Mapping[str, Any]]:
@ -2269,7 +2271,11 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
write_concern=self._write_concern_for(session), write_concern=self._write_concern_for(session),
session=session, session=session,
) )
return names return names
return await self.database.client._retryable_write(
False, inner, session, _Op.CREATE_INDEXES
)
async def create_index( async def create_index(
self, self,
@ -2422,7 +2428,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
kwargs["comment"] = comment kwargs["comment"] = comment
await self._drop_index("*", session=session, **kwargs) await self._drop_index("*", session=session, **kwargs)
@_csot.apply
async def drop_index( async def drop_index(
self, self,
index_or_name: _IndexKeyHint, index_or_name: _IndexKeyHint,
@ -2490,7 +2495,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs) cmd.update(kwargs)
if comment is not None: if comment is not None:
cmd["comment"] = comment cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.DROP_INDEXES) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command( await self._command(
conn, conn,
cmd, cmd,
@ -2500,6 +2508,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
session=session, session=session,
) )
await self.database.client._retryable_write(False, inner, session, _Op.DROP_INDEXES)
async def list_indexes( async def list_indexes(
self, self,
session: Optional[AsyncClientSession] = None, session: Optional[AsyncClientSession] = None,
@ -2763,17 +2773,22 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd = {"createSearchIndexes": self.name, "indexes": list(gen_indexes())} cmd = {"createSearchIndexes": self.name, "indexes": list(gen_indexes())}
cmd.update(kwargs) cmd.update(kwargs)
async with await self._conn_for_writes( async def inner(
session, operation=_Op.CREATE_SEARCH_INDEXES session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) as conn: ) -> list[str]:
resp = await self._command( resp = await self._command(
conn, conn,
cmd, cmd,
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
) )
return [index["name"] for index in resp["indexesCreated"]] return [index["name"] for index in resp["indexesCreated"]]
return await self.database.client._retryable_write(
False, inner, session, _Op.CREATE_SEARCH_INDEXES
)
async def drop_search_index( async def drop_search_index(
self, self,
name: str, name: str,
@ -2799,15 +2814,21 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs) cmd.update(kwargs)
if comment is not None: if comment is not None:
cmd["comment"] = comment cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.DROP_SEARCH_INDEXES) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command( await self._command(
conn, conn,
cmd, cmd,
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26], allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
) )
await self.database.client._retryable_write(False, inner, session, _Op.DROP_SEARCH_INDEXES)
async def update_search_index( async def update_search_index(
self, self,
name: str, name: str,
@ -2835,15 +2856,21 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs) cmd.update(kwargs)
if comment is not None: if comment is not None:
cmd["comment"] = comment cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.UPDATE_SEARCH_INDEX) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command( await self._command(
conn, conn,
cmd, cmd,
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26], allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
) )
await self.database.client._retryable_write(False, inner, session, _Op.UPDATE_SEARCH_INDEX)
async def options( async def options(
self, self,
session: Optional[AsyncClientSession] = None, session: Optional[AsyncClientSession] = None,
@ -2918,6 +2945,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
session, session,
retryable=not cmd._performs_write, retryable=not cmd._performs_write,
operation=_Op.AGGREGATE, operation=_Op.AGGREGATE,
is_aggregate_write=cmd._performs_write,
) )
async def aggregate( async def aggregate(
@ -3123,17 +3151,21 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
if comment is not None: if comment is not None:
cmd["comment"] = comment cmd["comment"] = comment
write_concern = self._write_concern_for_cmd(cmd, session) write_concern = self._write_concern_for_cmd(cmd, session)
client = self._database.client
async with await self._conn_for_writes(session, operation=_Op.RENAME) as conn: async def inner(
async with self._database.client._tmp_session(session) as s: session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
return await conn.command( ) -> MutableMapping[str, Any]:
"admin", return await conn.command(
cmd, "admin",
write_concern=write_concern, cmd,
parse_write_concern_error=True, write_concern=write_concern,
session=s, parse_write_concern_error=True,
client=self._database.client, session=session,
) client=client,
)
return await client._retryable_write(False, inner, session, _Op.RENAME)
async def distinct( async def distinct(
self, self,

View File

@ -931,14 +931,15 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
if read_preference is None: if read_preference is None:
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
async with await self._client._conn_for_reads(
read_preference, session, operation=command_name async def inner(
) as ( session: Optional[AsyncClientSession],
connection, _server: Server,
read_preference, conn: AsyncConnection,
): read_preference: _ServerMode,
) -> Union[dict[str, Any], _CodecDocumentType]:
return await self._command( return await self._command(
connection, conn,
command, command,
value, value,
check, check,
@ -949,6 +950,10 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
**kwargs, **kwargs,
) )
return await self._client._retryable_read(
inner, read_preference, session, command_name, None, False, is_run_command=True
)
@_csot.apply @_csot.apply
async def cursor_command( async def cursor_command(
self, self,
@ -1016,17 +1021,17 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
async with self._client._tmp_session(session) as tmp_session: async with self._client._tmp_session(session) as tmp_session:
opts = codec_options or DEFAULT_CODEC_OPTIONS opts = codec_options or DEFAULT_CODEC_OPTIONS
if read_preference is None: if read_preference is None:
read_preference = ( read_preference = (
tmp_session and tmp_session._txn_read_preference() tmp_session and tmp_session._txn_read_preference()
) or ReadPreference.PRIMARY ) or ReadPreference.PRIMARY
async with await self._client._conn_for_reads(
read_preference, tmp_session, command_name async def inner(
) as ( session: Optional[AsyncClientSession],
conn, _server: Server,
read_preference, conn: AsyncConnection,
): read_preference: _ServerMode,
) -> AsyncCommandCursor[_DocumentType]:
response = await self._command( response = await self._command(
conn, conn,
command, command,
@ -1035,7 +1040,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
None, None,
read_preference, read_preference,
opts, opts,
session=tmp_session, session=session,
**kwargs, **kwargs,
) )
coll = self.get_collection("$cmd", read_preference=read_preference) coll = self.get_collection("$cmd", read_preference=read_preference)
@ -1045,7 +1050,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
response["cursor"], response["cursor"],
conn.address, conn.address,
max_await_time_ms=max_await_time_ms, max_await_time_ms=max_await_time_ms,
session=tmp_session, session=session,
comment=comment, comment=comment,
) )
await cmd_cursor._maybe_pin_connection(conn) await cmd_cursor._maybe_pin_connection(conn)
@ -1053,6 +1058,10 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
else: else:
raise InvalidOperation("Command does not return a cursor.") raise InvalidOperation("Command does not return a cursor.")
return await self.client._retryable_read(
inner, read_preference, tmp_session, command_name, None, False
)
async def _retryable_read_command( async def _retryable_read_command(
self, self,
command: Union[str, MutableMapping[str, Any]], command: Union[str, MutableMapping[str, Any]],
@ -1254,9 +1263,11 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
if comment is not None: if comment is not None:
command["comment"] = comment command["comment"] = comment
async with await self._client._conn_for_writes(session, operation=_Op.DROP) as connection: async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> dict[str, Any]:
return await self._command( return await self._command(
connection, conn,
command, command,
allowable_errors=["ns not found", 26], allowable_errors=["ns not found", 26],
write_concern=self._write_concern_for(session), write_concern=self._write_concern_for(session),
@ -1264,6 +1275,8 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
session=session, session=session,
) )
return await self.client._retryable_write(False, inner, session, _Op.DROP)
@_csot.apply @_csot.apply
async def drop_collection( async def drop_collection(
self, self,

View File

@ -17,8 +17,11 @@ from __future__ import annotations
import asyncio import asyncio
import builtins import builtins
import functools
import random
import socket import socket
import sys import sys
import time as time # noqa: PLC0414 # needed in sync version
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -26,6 +29,8 @@ from typing import (
cast, cast,
) )
from pymongo import _csot
from pymongo.common import MAX_ADAPTIVE_RETRIES
from pymongo.errors import ( from pymongo.errors import (
OperationFailure, OperationFailure,
) )
@ -38,6 +43,7 @@ F = TypeVar("F", bound=Callable[..., Any])
def _handle_reauth(func: F) -> F: def _handle_reauth(func: F) -> F:
@functools.wraps(func)
async def inner(*args: Any, **kwargs: Any) -> Any: async def inner(*args: Any, **kwargs: Any) -> Any:
no_reauth = kwargs.pop("no_reauth", False) no_reauth = kwargs.pop("no_reauth", False)
from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.pool import AsyncConnection
@ -70,6 +76,46 @@ def _handle_reauth(func: F) -> F:
return cast(F, inner) return cast(F, inner)
_BACKOFF_INITIAL = 0.1
_BACKOFF_MAX = 10
def _backoff(
attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
) -> float:
jitter = random.random() # noqa: S311
return jitter * min(initial_delay * (2**attempt), max_delay)
class _RetryPolicy:
"""A retry limiter that performs exponential backoff with jitter."""
def __init__(
self,
attempts: int = MAX_ADAPTIVE_RETRIES,
backoff_initial: float = _BACKOFF_INITIAL,
backoff_max: float = _BACKOFF_MAX,
):
self.attempts = attempts
self.backoff_initial = backoff_initial
self.backoff_max = backoff_max
def backoff(self, attempt: int) -> float:
"""Return the backoff duration for the given attempt."""
return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)
async def should_retry(self, attempt: int, delay: float) -> bool:
"""Return if we have retry attempts remaining and the next backoff would not exceed a timeout."""
if attempt > self.attempts:
return False
if _csot.get_timeout():
if time.monotonic() + delay > _csot.get_deadline():
return False
return True
async def _getaddrinfo( async def _getaddrinfo(
host: Any, port: Any, **kwargs: Any host: Any, port: Any, **kwargs: Any
) -> list[ ) -> list[

View File

@ -35,6 +35,7 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import os import os
import time as time # noqa: PLC0414 # needed in sync version
import warnings import warnings
import weakref import weakref
from collections import defaultdict from collections import defaultdict
@ -67,6 +68,9 @@ from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterCh
from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.helpers import (
_RetryPolicy,
)
from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext from pymongo.asynchronous.topology import Topology, _ErrorContext
from pymongo.client_options import ClientOptions from pymongo.client_options import ClientOptions
@ -610,8 +614,18 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
client to use Stable API. See `versioned API <https://www.mongodb.com/docs/manual/reference/stable-api/#what-is-the-stable-api--and-should-you-use-it->`_ for client to use Stable API. See `versioned API <https://www.mongodb.com/docs/manual/reference/stable-api/#what-is-the-stable-api--and-should-you-use-it->`_ for
details. details.
| **Overload retry options:**
- `max_adaptive_retries`: (int) How many retries to allow for overload errors. Defaults to ``2``.
- `enable_overload_retargeting`: (boolean) Whether overload retargeting is enabled for this client.
If enabled, server overload errors will cause retry attempts to select a server that has not yet returned an overload error, if possible.
Defaults to ``False``.
.. seealso:: The MongoDB documentation on `connections <https://dochub.mongodb.org/core/connections>`_. .. seealso:: The MongoDB documentation on `connections <https://dochub.mongodb.org/core/connections>`_.
.. versionchanged:: 4.17
Added the ``max_adaptive_retries`` and ``enable_overload_retargeting`` URI and keyword arguments.
.. versionchanged:: 4.5 .. versionchanged:: 4.5
Added the ``serverMonitoringMode`` keyword argument. Added the ``serverMonitoringMode`` keyword argument.
@ -879,11 +893,14 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self._options.read_concern, self._options.read_concern,
) )
self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries)
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name) self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
self._opened = False self._opened = False
self._closed = False self._closed = False
self._loop: Optional[asyncio.AbstractEventLoop] = None self._loop: Optional[asyncio.AbstractEventLoop] = None
if not is_srv: if not is_srv:
self._init_background() self._init_background()
@ -1991,6 +2008,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
read_pref: Optional[_ServerMode] = None, read_pref: Optional[_ServerMode] = None,
retryable: bool = False, retryable: bool = False,
operation_id: Optional[int] = None, operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T: ) -> T:
"""Internal retryable helper for all client transactions. """Internal retryable helper for all client transactions.
@ -2002,6 +2021,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Server Address, defaults to None :param address: Server Address, defaults to None
:param read_pref: Topology of read operation, defaults to None :param read_pref: Topology of read operation, defaults to None
:param retryable: If the operation should be retried once, defaults to None :param retryable: If the operation should be retried once, defaults to None
:param is_run_command: If this is a runCommand operation, defaults to False
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
:return: Output of the calling func() :return: Output of the calling func()
""" """
@ -2016,6 +2037,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address=address, address=address,
retryable=retryable, retryable=retryable,
operation_id=operation_id, operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
).run() ).run()
async def _retryable_read( async def _retryable_read(
@ -2027,6 +2050,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address: Optional[_Address] = None, address: Optional[_Address] = None,
retryable: bool = True, retryable: bool = True,
operation_id: Optional[int] = None, operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T: ) -> T:
"""Execute an operation with consecutive retries if possible """Execute an operation with consecutive retries if possible
@ -2042,6 +2067,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Optional address when sending a message, defaults to None :param address: Optional address when sending a message, defaults to None
:param retryable: if we should attempt retries :param retryable: if we should attempt retries
(may not always be supported even if supplied), defaults to False (may not always be supported even if supplied), defaults to False
:param is_run_command: If this is a runCommand operation, defaults to False.
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
""" """
# Ensure that the client supports retrying on reads and there is no session in # Ensure that the client supports retrying on reads and there is no session in
@ -2060,6 +2087,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
read_pref=read_pref, read_pref=read_pref,
retryable=retryable, retryable=retryable,
operation_id=operation_id, operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
) )
async def _retryable_write( async def _retryable_write(
@ -2454,15 +2483,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
f"name_or_database must be an instance of str or a AsyncDatabase, not {type(name)}" f"name_or_database must be an instance of str or a AsyncDatabase, not {type(name)}"
) )
async with await self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn: await self[name].command(
await self[name]._command( {"dropDatabase": 1, "comment": comment},
conn, read_preference=ReadPreference.PRIMARY,
{"dropDatabase": 1, "comment": comment}, write_concern=self._write_concern_for(session),
read_preference=ReadPreference.PRIMARY, parse_write_concern_error=True,
write_concern=self._write_concern_for(session), session=session,
parse_write_concern_error=True, )
session=session,
)
@_csot.apply @_csot.apply
async def bulk_write( async def bulk_write(
@ -2746,12 +2773,15 @@ class _ClientConnectionRetryable(Generic[T]):
address: Optional[_Address] = None, address: Optional[_Address] = None,
retryable: bool = False, retryable: bool = False,
operation_id: Optional[int] = None, operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
): ):
self._last_error: Optional[Exception] = None self._last_error: Optional[Exception] = None
self._retrying = False self._retrying = False
self._always_retryable = False
self._multiple_retries = _csot.get_timeout() is not None self._multiple_retries = _csot.get_timeout() is not None
self._client = mongo_client self._client = mongo_client
self._retry_policy = mongo_client._retry_policy
self._func = func self._func = func
self._bulk = bulk self._bulk = bulk
self._session = session self._session = session
@ -2767,6 +2797,8 @@ class _ClientConnectionRetryable(Generic[T]):
self._operation = operation self._operation = operation
self._operation_id = operation_id self._operation_id = operation_id
self._attempt_number = 0 self._attempt_number = 0
self._is_run_command = is_run_command
self._is_aggregate_write = is_aggregate_write
async def run(self) -> T: async def run(self) -> T:
"""Runs the supplied func() and attempts a retry """Runs the supplied func() and attempts a retry
@ -2786,7 +2818,13 @@ class _ClientConnectionRetryable(Generic[T]):
while True: while True:
self._check_last_error(check_csot=True) self._check_last_error(check_csot=True)
try: try:
return await self._read() if self._is_read else await self._write() res = await self._read() if self._is_read else await self._write()
# Track whether the transaction has completed a command.
# If we need to apply backpressure to the first command,
# we will need to revert back to starting state.
if self._session is not None and self._session.in_transaction:
self._session._transaction.has_completed_command = True
return res
except ServerSelectionTimeoutError: except ServerSelectionTimeoutError:
# The application may think the write was never attempted # The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry # if we raise ServerSelectionTimeoutError on the retry
@ -2797,37 +2835,76 @@ class _ClientConnectionRetryable(Generic[T]):
# most likely be a waste of time. # most likely be a waste of time.
raise raise
except PyMongoError as exc: except PyMongoError as exc:
always_retryable = False
overloaded = False
exc_to_check = exc
if self._is_run_command and not (
self._client.options.retry_reads and self._client.options.retry_writes
):
raise
if self._is_aggregate_write and not self._client.options.retry_writes:
raise
# Execute specialized catch on read # Execute specialized catch on read
if self._is_read: if self._is_read:
if isinstance(exc, (ConnectionFailure, OperationFailure)): if isinstance(exc, (ConnectionFailure, OperationFailure)):
# ConnectionFailures do not supply a code property # ConnectionFailures do not supply a code property
exc_code = getattr(exc, "code", None) exc_code = getattr(exc, "code", None)
if self._is_not_eligible_for_retry() or ( overloaded = exc.has_error_label("SystemOverloadedError")
isinstance(exc, OperationFailure) always_retryable = exc.has_error_label("RetryableError") and overloaded
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES if not self._client.options.retry_reads or (
not always_retryable
and (
self._is_not_eligible_for_retry()
or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
)
)
): ):
raise raise
self._retrying = True self._retrying = True
self._last_error = exc self._last_error = exc
self._attempt_number += 1 self._attempt_number += 1
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if (
overloaded
and self._session is not None
and self._session.in_transaction
):
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
else: else:
raise raise
# Specialized catch on write operation # Specialized catch on write operation
if not self._is_read: if not self._is_read:
if not self._retryable: if isinstance(exc, ClientBulkWriteException) and isinstance(
exc.error, PyMongoError
):
exc_to_check = exc.error
retryable_write_label = exc_to_check.has_error_label("RetryableWriteError")
overloaded = exc_to_check.has_error_label("SystemOverloadedError")
always_retryable = exc_to_check.has_error_label("RetryableError") and overloaded
# Always retry abortTransaction and commitTransaction up to once
if self._operation not in ["abortTransaction", "commitTransaction"] and (
not self._client.options.retry_writes
or not (self._retryable or always_retryable)
):
raise raise
if isinstance(exc, ClientBulkWriteException) and exc.error: if retryable_write_label or always_retryable:
retryable_write_error_exc = isinstance(
exc.error, PyMongoError
) and exc.error.has_error_label("RetryableWriteError")
else:
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
if retryable_write_error_exc:
assert self._session assert self._session
await self._session._unpin() await self._session._unpin()
if not retryable_write_error_exc or self._is_not_eligible_for_retry(): if not always_retryable and (
if exc.has_error_label("NoWritesPerformed") and self._last_error: not retryable_write_label or self._is_not_eligible_for_retry()
):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc raise self._last_error from exc
else: else:
raise raise
@ -2836,18 +2913,34 @@ class _ClientConnectionRetryable(Generic[T]):
self._bulk.retrying = True self._bulk.retrying = True
else: else:
self._retrying = True self._retrying = True
if not exc.has_error_label("NoWritesPerformed"): if not exc_to_check.has_error_label("NoWritesPerformed"):
self._last_error = exc self._last_error = exc
if self._last_error is None: if self._last_error is None:
self._last_error = exc self._last_error = exc
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if overloaded and self._session is not None and self._session.in_transaction:
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
if ( if self._server is not None and (
self._server is not None self._client.topology_description.topology_type_name == "Sharded"
and self._client.topology_description.topology_type_name == "Sharded" or (overloaded and self._client.options.enable_overload_retargeting)
or exc.has_error_label("SystemOverloadedError")
): ):
self._deprioritized_servers.append(self._server) self._deprioritized_servers.append(self._server)
self._always_retryable = always_retryable
if overloaded:
delay = self._retry_policy.backoff(self._attempt_number)
if not await self._retry_policy.should_retry(self._attempt_number, delay):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
await asyncio.sleep(delay)
def _is_not_eligible_for_retry(self) -> bool: def _is_not_eligible_for_retry(self) -> bool:
"""Checks if the exchange is not eligible for retry""" """Checks if the exchange is not eligible for retry"""
return not self._retryable or (self._is_retrying() and not self._multiple_retries) return not self._retryable or (self._is_retrying() and not self._multiple_retries)
@ -2909,7 +3002,7 @@ class _ClientConnectionRetryable(Generic[T]):
and conn.supports_sessions and conn.supports_sessions
) )
is_mongos = conn.is_mongos is_mongos = conn.is_mongos
if not sessions_supported: if not self._always_retryable and not sessions_supported:
# A retry is not possible because this server does # A retry is not possible because this server does
# not support sessions raise the last error. # not support sessions raise the last error.
self._check_last_error() self._check_last_error()
@ -2941,7 +3034,7 @@ class _ClientConnectionRetryable(Generic[T]):
conn, conn,
read_pref, read_pref,
): ):
if self._retrying and not self._retryable: if self._retrying and not self._retryable and not self._always_retryable:
self._check_last_error() self._check_last_error()
if self._retrying: if self._retrying:
_debug_log( _debug_log(

View File

@ -19,6 +19,8 @@ import collections
import contextlib import contextlib
import logging import logging
import os import os
import socket
import ssl
import sys import sys
import time import time
import weakref import weakref
@ -52,10 +54,12 @@ from pymongo.errors import ( # type:ignore[attr-defined]
DocumentTooLarge, DocumentTooLarge,
ExecutionTimeout, ExecutionTimeout,
InvalidOperation, InvalidOperation,
NetworkTimeout,
NotPrimaryError, NotPrimaryError,
OperationFailure, OperationFailure,
PyMongoError, PyMongoError,
WaitQueueTimeoutError, WaitQueueTimeoutError,
_CertificateError,
) )
from pymongo.hello import Hello, HelloCompat from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _get_timeout_details, format_timeout_details from pymongo.helpers_shared import _get_timeout_details, format_timeout_details
@ -250,6 +254,7 @@ class AsyncConnection:
cmd = self.hello_cmd() cmd = self.hello_cmd()
performing_handshake = not self.performed_handshake performing_handshake = not self.performed_handshake
awaitable = False awaitable = False
cmd["backpressure"] = True
if performing_handshake: if performing_handshake:
self.performed_handshake = True self.performed_handshake = True
cmd["client"] = self.opts.metadata cmd["client"] = self.opts.metadata
@ -752,8 +757,8 @@ class Pool:
# Enforces: maxConnecting # Enforces: maxConnecting
# Also used for: clearing the wait queue # Also used for: clearing the wait queue
self._max_connecting_cond = _async_create_condition(self.lock) self._max_connecting_cond = _async_create_condition(self.lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0 self._pending = 0
self._max_connecting = self.opts.max_connecting
self._client_id = client_id self._client_id = client_id
if self.enabled_for_cmap: if self.enabled_for_cmap:
assert self.opts._event_listeners is not None assert self.opts._event_listeners is not None
@ -986,6 +991,21 @@ class Pool:
self.requests -= 1 self.requests -= 1
self.size_cond.notify() self.size_cond.notify()
def _handle_connection_error(self, error: BaseException) -> None:
# Handle system overload condition for non-sdam pools.
# Look for errors of type AutoReconnect and add error labels if appropriate.
if self.is_sdam or type(error) not in (AutoReconnect, NetworkTimeout):
return
assert isinstance(error, AutoReconnect) # Appease type checker.
# If the original error was a DNS, certificate, or SSL error, ignore it.
if isinstance(error.__cause__, (_CertificateError, SSLErrors, socket.gaierror)):
# End of file errors are excluded, because the server may have disconnected
# during the handshake.
if not isinstance(error.__cause__, (ssl.SSLEOFError, ssl.SSLZeroReturnError)):
return
error._add_error_label("SystemOverloadedError")
error._add_error_label("RetryableError")
async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnection: async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnection:
"""Connect to Mongo and return a new AsyncConnection. """Connect to Mongo and return a new AsyncConnection.
@ -1037,10 +1057,10 @@ class Pool:
reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR), reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
error=ConnectionClosedReason.ERROR, error=ConnectionClosedReason.ERROR,
) )
self._handle_connection_error(error)
if isinstance(error, (IOError, OSError, *SSLErrors)): if isinstance(error, (IOError, OSError, *SSLErrors)):
details = _get_timeout_details(self.opts) details = _get_timeout_details(self.opts)
_raise_connection_failure(self.address, error, timeout_details=details) _raise_connection_failure(self.address, error, timeout_details=details)
raise raise
conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type] conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
@ -1049,18 +1069,22 @@ class Pool:
self.active_contexts.discard(tmp_context) self.active_contexts.discard(tmp_context)
if tmp_context.cancelled: if tmp_context.cancelled:
conn.cancel_context.cancel() conn.cancel_context.cancel()
completed_hello = False
try: try:
if not self.is_sdam: if not self.is_sdam:
await conn.hello() await conn.hello()
completed_hello = True
self.is_writable = conn.is_writable self.is_writable = conn.is_writable
if handler: if handler:
handler.contribute_socket(conn, completed_handshake=False) handler.contribute_socket(conn, completed_handshake=False)
await conn.authenticate() await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup. # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException: except BaseException as e:
async with self.lock: async with self.lock:
self.active_contexts.discard(conn.cancel_context) self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
await conn.close_conn(ConnectionClosedReason.ERROR) await conn.close_conn(ConnectionClosedReason.ERROR)
raise raise
@ -1389,8 +1413,8 @@ class Pool:
:class:`~pymongo.errors.AutoReconnect` exceptions on server :class:`~pymongo.errors.AutoReconnect` exceptions on server
hiccups, etc. We only check if the socket was closed by an external hiccups, etc. We only check if the socket was closed by an external
error if it has been > 1 second since the socket was checked into the error if it has been > 1 second since the socket was checked into the
pool, to keep performance reasonable - we can't avoid AutoReconnects pool to keep performance reasonable -
completely anyway. we can't avoid AutoReconnects completely anyway.
""" """
idle_time_seconds = conn.idle_time_seconds() idle_time_seconds = conn.idle_time_seconds()
# If socket is idle, open a new one. # If socket is idle, open a new one.
@ -1401,8 +1425,9 @@ class Pool:
await conn.close_conn(ConnectionClosedReason.IDLE) await conn.close_conn(ConnectionClosedReason.IDLE)
return True return True
if self._check_interval_seconds is not None and ( check_interval_seconds = self._check_interval_seconds
self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds if check_interval_seconds is not None and (
check_interval_seconds == 0 or idle_time_seconds > check_interval_seconds
): ):
if conn.conn_closed(): if conn.conn_closed():
await conn.close_conn(ConnectionClosedReason.ERROR) await conn.close_conn(ConnectionClosedReason.ERROR)

View File

@ -913,7 +913,9 @@ class Topology:
# Clear the pool. # Clear the pool.
await server.reset(service_id) await server.reset(service_id)
elif isinstance(error, ConnectionFailure): elif isinstance(error, ConnectionFailure):
if isinstance(error, WaitQueueTimeoutError): if isinstance(error, WaitQueueTimeoutError) or (
error.has_error_label("SystemOverloadedError")
):
return return
# "Client MUST replace the server's description with type Unknown # "Client MUST replace the server's description with type Unknown
# ... MUST NOT request an immediate check of the server." # ... MUST NOT request an immediate check of the server."

View File

@ -235,6 +235,16 @@ class ClientOptions:
self.__server_monitoring_mode = options.get( self.__server_monitoring_mode = options.get(
"servermonitoringmode", common.SERVER_MONITORING_MODE "servermonitoringmode", common.SERVER_MONITORING_MODE
) )
self.__max_adaptive_retries = (
options.get("max_adaptive_retries", common.MAX_ADAPTIVE_RETRIES)
if "max_adaptive_retries" in options
else options.get("maxadaptiveretries", common.MAX_ADAPTIVE_RETRIES)
)
self.__enable_overload_retargeting = (
options.get("enable_overload_retargeting", common.ENABLE_OVERLOAD_RETARGETING)
if "enable_overload_retargeting" in options
else options.get("enableoverloadretargeting", common.ENABLE_OVERLOAD_RETARGETING)
)
@property @property
def _options(self) -> Mapping[str, Any]: def _options(self) -> Mapping[str, Any]:
@ -346,3 +356,19 @@ class ClientOptions:
.. versionadded:: 4.5 .. versionadded:: 4.5
""" """
return self.__server_monitoring_mode return self.__server_monitoring_mode
@property
def max_adaptive_retries(self) -> int:
"""The configured maxAdaptiveRetries option.
.. versionadded:: 4.17
"""
return self.__max_adaptive_retries
@property
def enable_overload_retargeting(self) -> bool:
"""The configured enableOverloadRetargeting option.
.. versionadded:: 4.17
"""
return self.__enable_overload_retargeting

View File

@ -140,6 +140,12 @@ SRV_SERVICE_NAME = "mongodb"
# Default value for serverMonitoringMode # Default value for serverMonitoringMode
SERVER_MONITORING_MODE = "auto" # poll/stream/auto SERVER_MONITORING_MODE = "auto" # poll/stream/auto
# Default value for max adaptive retries
MAX_ADAPTIVE_RETRIES = 2
# Default value for enableOverloadRetargeting
ENABLE_OVERLOAD_RETARGETING = False
# Auth mechanism properties that must raise an error instead of warning if they invalidate. # Auth mechanism properties that must raise an error instead of warning if they invalidate.
_MECH_PROP_MUST_RAISE = ["CANONICALIZE_HOST_NAME"] _MECH_PROP_MUST_RAISE = ["CANONICALIZE_HOST_NAME"]
@ -717,6 +723,8 @@ URI_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = {
"srvmaxhosts": validate_non_negative_integer, "srvmaxhosts": validate_non_negative_integer,
"timeoutms": validate_timeoutms, "timeoutms": validate_timeoutms,
"servermonitoringmode": validate_server_monitoring_mode, "servermonitoringmode": validate_server_monitoring_mode,
"maxadaptiveretries": validate_non_negative_integer,
"enableoverloadretargeting": validate_boolean_or_string,
} }
# Dictionary where keys are the names of URI options specific to pymongo, # Dictionary where keys are the names of URI options specific to pymongo,
@ -750,6 +758,8 @@ KW_VALIDATORS: dict[str, Callable[[Any, Any], Any]] = {
"server_selector": validate_is_callable_or_none, "server_selector": validate_is_callable_or_none,
"auto_encryption_opts": validate_auto_encryption_opts_or_none, "auto_encryption_opts": validate_auto_encryption_opts_or_none,
"authoidcallowedhosts": validate_list, "authoidcallowedhosts": validate_list,
"max_adaptive_retries": validate_non_negative_integer,
"enable_overload_retargeting": validate_boolean_or_string,
} }
# Dictionary where keys are any URI option name, and values are the # Dictionary where keys are any URI option name, and values are the

View File

@ -59,6 +59,7 @@ from pymongo.errors import (
InvalidOperation, InvalidOperation,
NotPrimaryError, NotPrimaryError,
OperationFailure, OperationFailure,
PyMongoError,
WaitQueueTimeoutError, WaitQueueTimeoutError,
) )
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
@ -561,9 +562,17 @@ class _ClientBulk:
error, ConnectionFailure error, ConnectionFailure
) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError)) ) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError))
retryable_label_error = isinstance(
error, PyMongoError
) and error.has_error_label("RetryableError")
# Synthesize the full bulk result without modifying the # Synthesize the full bulk result without modifying the
# current one because this write operation may be retried. # current one because this write operation may be retried.
if retryable and (retryable_top_level_error or retryable_network_error): if retryable and (
retryable_top_level_error
or retryable_network_error
or retryable_label_error
):
full = copy.deepcopy(full_result) full = copy.deepcopy(full_result)
_merge_command(self.ops, self.idx_offset, full, result) _merge_command(self.ops, self.idx_offset, full, result)
_throw_client_bulk_write_exception(full, self.verbose_results) _throw_client_bulk_write_exception(full, self.verbose_results)

View File

@ -136,6 +136,7 @@ Classes
from __future__ import annotations from __future__ import annotations
import collections import collections
import random
import time import time
import uuid import uuid
from collections.abc import Mapping as _Mapping from collections.abc import Mapping as _Mapping
@ -160,7 +161,9 @@ from pymongo import _csot
from pymongo.errors import ( from pymongo.errors import (
ConfigurationError, ConfigurationError,
ConnectionFailure, ConnectionFailure,
ExecutionTimeout,
InvalidOperation, InvalidOperation,
NetworkTimeout,
OperationFailure, OperationFailure,
PyMongoError, PyMongoError,
WTimeoutError, WTimeoutError,
@ -426,6 +429,7 @@ class _Transaction:
self.recovery_token = None self.recovery_token = None
self.attempt = 0 self.attempt = 0
self.client = client self.client = client
self.has_completed_command = False
def active(self) -> bool: def active(self) -> bool:
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS) return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
@ -433,6 +437,9 @@ class _Transaction:
def starting(self) -> bool: def starting(self) -> bool:
return self.state == _TxnState.STARTING return self.state == _TxnState.STARTING
def set_starting(self) -> None:
self.state = _TxnState.STARTING
@property @property
def pinned_conn(self) -> Optional[Connection]: def pinned_conn(self) -> Optional[Connection]:
if self.active() and self.conn_mgr: if self.active() and self.conn_mgr:
@ -458,6 +465,7 @@ class _Transaction:
self.sharded = False self.sharded = False
self.recovery_token = None self.recovery_token = None
self.attempt = 0 self.attempt = 0
self.has_completed_command = False
def __del__(self) -> None: def __del__(self) -> None:
if self.conn_mgr: if self.conn_mgr:
@ -492,11 +500,29 @@ _UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( #
# This limit is non-configurable and was chosen to be twice the 60 second # This limit is non-configurable and was chosen to be twice the 60 second
# default value of MongoDB's `transactionLifetimeLimitSeconds` parameter. # default value of MongoDB's `transactionLifetimeLimitSeconds` parameter.
_WITH_TRANSACTION_RETRY_TIME_LIMIT = 120 _WITH_TRANSACTION_RETRY_TIME_LIMIT = 120
_BACKOFF_MAX = 0.500 # 500ms max backoff
_BACKOFF_INITIAL = 0.005 # 5ms initial backoff
def _within_time_limit(start_time: float) -> bool: def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
"""Are we within the with_transaction retry limit?""" """Are we within the with_transaction retry limit?"""
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT remaining = _csot.remaining()
if remaining is not None and remaining <= 0:
return False
return time.monotonic() + backoff - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
def _make_timeout_error(error: BaseException) -> PyMongoError:
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
if _csot.remaining() is not None:
timeout_error: PyMongoError = ExecutionTimeout(
str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}
)
else:
timeout_error = NetworkTimeout(str(error))
if isinstance(error, PyMongoError):
timeout_error._error_labels = error._error_labels.copy()
return timeout_error
_T = TypeVar("_T") _T = TypeVar("_T")
@ -743,21 +769,32 @@ class ClientSession:
https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback
""" """
start_time = time.monotonic() start_time = time.monotonic()
retry = 0
last_error: Optional[BaseException] = None
while True: while True:
if retry: # Implement exponential backoff on retry.
jitter = random.random() # noqa: S311
backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX)
if not _within_time_limit(start_time, backoff):
assert last_error is not None
raise _make_timeout_error(last_error) from last_error
time.sleep(backoff)
retry += 1
self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms) self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms)
try: try:
ret = callback(self) ret = callback(self)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup. # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as exc: except BaseException as exc:
last_error = exc
if self.in_transaction: if self.in_transaction:
self.abort_transaction() self.abort_transaction()
if ( if isinstance(exc, PyMongoError) and exc.has_error_label(
isinstance(exc, PyMongoError) "TransientTransactionError"
and exc.has_error_label("TransientTransactionError")
and _within_time_limit(start_time)
): ):
# Retry the entire transaction. if _within_time_limit(start_time):
continue # Retry the entire transaction.
continue
raise _make_timeout_error(last_error) from exc
raise raise
if not self.in_transaction: if not self.in_transaction:
@ -768,17 +805,18 @@ class ClientSession:
try: try:
self.commit_transaction() self.commit_transaction()
except PyMongoError as exc: except PyMongoError as exc:
if ( last_error = exc
exc.has_error_label("UnknownTransactionCommitResult") if exc.has_error_label(
and _within_time_limit(start_time) "UnknownTransactionCommitResult"
and not _max_time_expired_error(exc) ) and not _max_time_expired_error(exc):
): if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the commit. # Retry the commit.
continue continue
if exc.has_error_label("TransientTransactionError") and _within_time_limit( if exc.has_error_label("TransientTransactionError"):
start_time if not _within_time_limit(start_time):
): raise _make_timeout_error(last_error) from exc
# Retry the entire transaction. # Retry the entire transaction.
break break
raise raise

View File

@ -21,7 +21,6 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable, Callable,
ContextManager,
Generic, Generic,
Iterable, Iterable,
Iterator, Iterator,
@ -572,11 +571,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
change_stream._initialize_cursor() change_stream._initialize_cursor()
return change_stream return change_stream
def _conn_for_writes(
self, session: Optional[ClientSession], operation: str
) -> ContextManager[Connection]:
return self._database.client._conn_for_writes(session, operation)
def _command( def _command(
self, self,
conn: Connection, conn: Connection,
@ -653,7 +647,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
if "size" in options: if "size" in options:
options["size"] = float(options["size"]) options["size"] = float(options["size"])
cmd.update(options) cmd.update(options)
with self._conn_for_writes(session, operation=_Op.CREATE) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> None:
if qev2_required and conn.max_wire_version < 21: if qev2_required and conn.max_wire_version < 21:
raise ConfigurationError( raise ConfigurationError(
"Driver support of Queryable Encryption is incompatible with server. " "Driver support of Queryable Encryption is incompatible with server. "
@ -670,6 +667,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
session=session, session=session,
) )
self.database.client._retryable_write(False, inner, session, _Op.CREATE)
def _create( def _create(
self, self,
options: MutableMapping[str, Any], options: MutableMapping[str, Any],
@ -2237,7 +2236,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
command (like maxTimeMS) can be passed as keyword arguments. command (like maxTimeMS) can be passed as keyword arguments.
""" """
names = [] names = []
with self._conn_for_writes(session, operation=_Op.CREATE_INDEXES) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> list[str]:
supports_quorum = conn.max_wire_version >= 9 supports_quorum = conn.max_wire_version >= 9
def gen_indexes() -> Iterator[Mapping[str, Any]]: def gen_indexes() -> Iterator[Mapping[str, Any]]:
@ -2266,7 +2268,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
write_concern=self._write_concern_for(session), write_concern=self._write_concern_for(session),
session=session, session=session,
) )
return names return names
return self.database.client._retryable_write(False, inner, session, _Op.CREATE_INDEXES)
def create_index( def create_index(
self, self,
@ -2419,7 +2423,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
kwargs["comment"] = comment kwargs["comment"] = comment
self._drop_index("*", session=session, **kwargs) self._drop_index("*", session=session, **kwargs)
@_csot.apply
def drop_index( def drop_index(
self, self,
index_or_name: _IndexKeyHint, index_or_name: _IndexKeyHint,
@ -2487,7 +2490,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs) cmd.update(kwargs)
if comment is not None: if comment is not None:
cmd["comment"] = comment cmd["comment"] = comment
with self._conn_for_writes(session, operation=_Op.DROP_INDEXES) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> None:
self._command( self._command(
conn, conn,
cmd, cmd,
@ -2497,6 +2503,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
session=session, session=session,
) )
self.database.client._retryable_write(False, inner, session, _Op.DROP_INDEXES)
def list_indexes( def list_indexes(
self, self,
session: Optional[ClientSession] = None, session: Optional[ClientSession] = None,
@ -2760,15 +2768,22 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd = {"createSearchIndexes": self.name, "indexes": list(gen_indexes())} cmd = {"createSearchIndexes": self.name, "indexes": list(gen_indexes())}
cmd.update(kwargs) cmd.update(kwargs)
with self._conn_for_writes(session, operation=_Op.CREATE_SEARCH_INDEXES) as conn: def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> list[str]:
resp = self._command( resp = self._command(
conn, conn,
cmd, cmd,
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
) )
return [index["name"] for index in resp["indexesCreated"]] return [index["name"] for index in resp["indexesCreated"]]
return self.database.client._retryable_write(
False, inner, session, _Op.CREATE_SEARCH_INDEXES
)
def drop_search_index( def drop_search_index(
self, self,
name: str, name: str,
@ -2794,15 +2809,21 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs) cmd.update(kwargs)
if comment is not None: if comment is not None:
cmd["comment"] = comment cmd["comment"] = comment
with self._conn_for_writes(session, operation=_Op.DROP_SEARCH_INDEXES) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> None:
self._command( self._command(
conn, conn,
cmd, cmd,
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26], allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
) )
self.database.client._retryable_write(False, inner, session, _Op.DROP_SEARCH_INDEXES)
def update_search_index( def update_search_index(
self, self,
name: str, name: str,
@ -2830,15 +2851,21 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs) cmd.update(kwargs)
if comment is not None: if comment is not None:
cmd["comment"] = comment cmd["comment"] = comment
with self._conn_for_writes(session, operation=_Op.UPDATE_SEARCH_INDEX) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> None:
self._command( self._command(
conn, conn,
cmd, cmd,
read_preference=ReadPreference.PRIMARY, read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26], allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
) )
self.database.client._retryable_write(False, inner, session, _Op.UPDATE_SEARCH_INDEX)
def options( def options(
self, self,
session: Optional[ClientSession] = None, session: Optional[ClientSession] = None,
@ -2911,6 +2938,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
session, session,
retryable=not cmd._performs_write, retryable=not cmd._performs_write,
operation=_Op.AGGREGATE, operation=_Op.AGGREGATE,
is_aggregate_write=cmd._performs_write,
) )
def aggregate( def aggregate(
@ -3116,17 +3144,21 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
if comment is not None: if comment is not None:
cmd["comment"] = comment cmd["comment"] = comment
write_concern = self._write_concern_for_cmd(cmd, session) write_concern = self._write_concern_for_cmd(cmd, session)
client = self._database.client
with self._conn_for_writes(session, operation=_Op.RENAME) as conn: def inner(
with self._database.client._tmp_session(session) as s: session: Optional[ClientSession], conn: Connection, _retryable_write: bool
return conn.command( ) -> MutableMapping[str, Any]:
"admin", return conn.command(
cmd, "admin",
write_concern=write_concern, cmd,
parse_write_concern_error=True, write_concern=write_concern,
session=s, parse_write_concern_error=True,
client=self._database.client, session=session,
) client=client,
)
return client._retryable_write(False, inner, session, _Op.RENAME)
def distinct( def distinct(
self, self,

View File

@ -931,12 +931,15 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if read_preference is None: if read_preference is None:
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
with self._client._conn_for_reads(read_preference, session, operation=command_name) as (
connection, def inner(
read_preference, session: Optional[ClientSession],
): _server: Server,
conn: Connection,
read_preference: _ServerMode,
) -> Union[dict[str, Any], _CodecDocumentType]:
return self._command( return self._command(
connection, conn,
command, command,
value, value,
check, check,
@ -947,6 +950,10 @@ class Database(common.BaseObject, Generic[_DocumentType]):
**kwargs, **kwargs,
) )
return self._client._retryable_read(
inner, read_preference, session, command_name, None, False, is_run_command=True
)
@_csot.apply @_csot.apply
def cursor_command( def cursor_command(
self, self,
@ -1014,15 +1021,17 @@ class Database(common.BaseObject, Generic[_DocumentType]):
with self._client._tmp_session(session) as tmp_session: with self._client._tmp_session(session) as tmp_session:
opts = codec_options or DEFAULT_CODEC_OPTIONS opts = codec_options or DEFAULT_CODEC_OPTIONS
if read_preference is None: if read_preference is None:
read_preference = ( read_preference = (
tmp_session and tmp_session._txn_read_preference() tmp_session and tmp_session._txn_read_preference()
) or ReadPreference.PRIMARY ) or ReadPreference.PRIMARY
with self._client._conn_for_reads(read_preference, tmp_session, command_name) as (
conn, def inner(
read_preference, session: Optional[ClientSession],
): _server: Server,
conn: Connection,
read_preference: _ServerMode,
) -> CommandCursor[_DocumentType]:
response = self._command( response = self._command(
conn, conn,
command, command,
@ -1031,7 +1040,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
None, None,
read_preference, read_preference,
opts, opts,
session=tmp_session, session=session,
**kwargs, **kwargs,
) )
coll = self.get_collection("$cmd", read_preference=read_preference) coll = self.get_collection("$cmd", read_preference=read_preference)
@ -1041,7 +1050,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
response["cursor"], response["cursor"],
conn.address, conn.address,
max_await_time_ms=max_await_time_ms, max_await_time_ms=max_await_time_ms,
session=tmp_session, session=session,
comment=comment, comment=comment,
) )
cmd_cursor._maybe_pin_connection(conn) cmd_cursor._maybe_pin_connection(conn)
@ -1049,6 +1058,10 @@ class Database(common.BaseObject, Generic[_DocumentType]):
else: else:
raise InvalidOperation("Command does not return a cursor.") raise InvalidOperation("Command does not return a cursor.")
return self.client._retryable_read(
inner, read_preference, tmp_session, command_name, None, False
)
def _retryable_read_command( def _retryable_read_command(
self, self,
command: Union[str, MutableMapping[str, Any]], command: Union[str, MutableMapping[str, Any]],
@ -1247,9 +1260,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if comment is not None: if comment is not None:
command["comment"] = comment command["comment"] = comment
with self._client._conn_for_writes(session, operation=_Op.DROP) as connection: def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> dict[str, Any]:
return self._command( return self._command(
connection, conn,
command, command,
allowable_errors=["ns not found", 26], allowable_errors=["ns not found", 26],
write_concern=self._write_concern_for(session), write_concern=self._write_concern_for(session),
@ -1257,6 +1272,8 @@ class Database(common.BaseObject, Generic[_DocumentType]):
session=session, session=session,
) )
return self.client._retryable_write(False, inner, session, _Op.DROP)
@_csot.apply @_csot.apply
def drop_collection( def drop_collection(
self, self,

View File

@ -17,8 +17,11 @@ from __future__ import annotations
import asyncio import asyncio
import builtins import builtins
import functools
import random
import socket import socket
import sys import sys
import time as time # noqa: PLC0414 # needed in sync version
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -26,6 +29,8 @@ from typing import (
cast, cast,
) )
from pymongo import _csot
from pymongo.common import MAX_ADAPTIVE_RETRIES
from pymongo.errors import ( from pymongo.errors import (
OperationFailure, OperationFailure,
) )
@ -38,6 +43,7 @@ F = TypeVar("F", bound=Callable[..., Any])
def _handle_reauth(func: F) -> F: def _handle_reauth(func: F) -> F:
@functools.wraps(func)
def inner(*args: Any, **kwargs: Any) -> Any: def inner(*args: Any, **kwargs: Any) -> Any:
no_reauth = kwargs.pop("no_reauth", False) no_reauth = kwargs.pop("no_reauth", False)
from pymongo.message import _BulkWriteContext from pymongo.message import _BulkWriteContext
@ -70,6 +76,46 @@ def _handle_reauth(func: F) -> F:
return cast(F, inner) return cast(F, inner)
_BACKOFF_INITIAL = 0.1
_BACKOFF_MAX = 10
def _backoff(
attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
) -> float:
jitter = random.random() # noqa: S311
return jitter * min(initial_delay * (2**attempt), max_delay)
class _RetryPolicy:
"""A retry limiter that performs exponential backoff with jitter."""
def __init__(
self,
attempts: int = MAX_ADAPTIVE_RETRIES,
backoff_initial: float = _BACKOFF_INITIAL,
backoff_max: float = _BACKOFF_MAX,
):
self.attempts = attempts
self.backoff_initial = backoff_initial
self.backoff_max = backoff_max
def backoff(self, attempt: int) -> float:
"""Return the backoff duration for the given attempt."""
return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)
def should_retry(self, attempt: int, delay: float) -> bool:
"""Return if we have retry attempts remaining and the next backoff would not exceed a timeout."""
if attempt > self.attempts:
return False
if _csot.get_timeout():
if time.monotonic() + delay > _csot.get_deadline():
return False
return True
def _getaddrinfo( def _getaddrinfo(
host: Any, port: Any, **kwargs: Any host: Any, port: Any, **kwargs: Any
) -> list[ ) -> list[

View File

@ -35,6 +35,7 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import os import os
import time as time # noqa: PLC0414 # needed in sync version
import warnings import warnings
import weakref import weakref
from collections import defaultdict from collections import defaultdict
@ -110,6 +111,9 @@ from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.helpers import (
_RetryPolicy,
)
from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext from pymongo.synchronous.topology import Topology, _ErrorContext
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
@ -610,8 +614,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
client to use Stable API. See `versioned API <https://www.mongodb.com/docs/manual/reference/stable-api/#what-is-the-stable-api--and-should-you-use-it->`_ for client to use Stable API. See `versioned API <https://www.mongodb.com/docs/manual/reference/stable-api/#what-is-the-stable-api--and-should-you-use-it->`_ for
details. details.
| **Overload retry options:**
- `max_adaptive_retries`: (int) How many retries to allow for overload errors. Defaults to ``2``.
- `enable_overload_retargeting`: (boolean) Whether overload retargeting is enabled for this client.
If enabled, server overload errors will cause retry attempts to select a server that has not yet returned an overload error, if possible.
Defaults to ``False``.
.. seealso:: The MongoDB documentation on `connections <https://dochub.mongodb.org/core/connections>`_. .. seealso:: The MongoDB documentation on `connections <https://dochub.mongodb.org/core/connections>`_.
.. versionchanged:: 4.17
Added the ``max_adaptive_retries`` and ``enable_overload_retargeting`` URI and keyword arguments.
.. versionchanged:: 4.5 .. versionchanged:: 4.5
Added the ``serverMonitoringMode`` keyword argument. Added the ``serverMonitoringMode`` keyword argument.
@ -879,11 +893,14 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self._options.read_concern, self._options.read_concern,
) )
self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries)
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name) self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
self._opened = False self._opened = False
self._closed = False self._closed = False
self._loop: Optional[asyncio.AbstractEventLoop] = None self._loop: Optional[asyncio.AbstractEventLoop] = None
if not is_srv: if not is_srv:
self._init_background() self._init_background()
@ -1987,6 +2004,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
read_pref: Optional[_ServerMode] = None, read_pref: Optional[_ServerMode] = None,
retryable: bool = False, retryable: bool = False,
operation_id: Optional[int] = None, operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T: ) -> T:
"""Internal retryable helper for all client transactions. """Internal retryable helper for all client transactions.
@ -1998,6 +2017,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Server Address, defaults to None :param address: Server Address, defaults to None
:param read_pref: Topology of read operation, defaults to None :param read_pref: Topology of read operation, defaults to None
:param retryable: If the operation should be retried once, defaults to None :param retryable: If the operation should be retried once, defaults to None
:param is_run_command: If this is a runCommand operation, defaults to False
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
:return: Output of the calling func() :return: Output of the calling func()
""" """
@ -2012,6 +2033,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
address=address, address=address,
retryable=retryable, retryable=retryable,
operation_id=operation_id, operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
).run() ).run()
def _retryable_read( def _retryable_read(
@ -2023,6 +2046,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
address: Optional[_Address] = None, address: Optional[_Address] = None,
retryable: bool = True, retryable: bool = True,
operation_id: Optional[int] = None, operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T: ) -> T:
"""Execute an operation with consecutive retries if possible """Execute an operation with consecutive retries if possible
@ -2038,6 +2063,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Optional address when sending a message, defaults to None :param address: Optional address when sending a message, defaults to None
:param retryable: if we should attempt retries :param retryable: if we should attempt retries
(may not always be supported even if supplied), defaults to False (may not always be supported even if supplied), defaults to False
:param is_run_command: If this is a runCommand operation, defaults to False.
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
""" """
# Ensure that the client supports retrying on reads and there is no session in # Ensure that the client supports retrying on reads and there is no session in
@ -2056,6 +2083,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
read_pref=read_pref, read_pref=read_pref,
retryable=retryable, retryable=retryable,
operation_id=operation_id, operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
) )
def _retryable_write( def _retryable_write(
@ -2444,15 +2473,13 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
f"name_or_database must be an instance of str or a Database, not {type(name)}" f"name_or_database must be an instance of str or a Database, not {type(name)}"
) )
with self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn: self[name].command(
self[name]._command( {"dropDatabase": 1, "comment": comment},
conn, read_preference=ReadPreference.PRIMARY,
{"dropDatabase": 1, "comment": comment}, write_concern=self._write_concern_for(session),
read_preference=ReadPreference.PRIMARY, parse_write_concern_error=True,
write_concern=self._write_concern_for(session), session=session,
parse_write_concern_error=True, )
session=session,
)
@_csot.apply @_csot.apply
def bulk_write( def bulk_write(
@ -2736,12 +2763,15 @@ class _ClientConnectionRetryable(Generic[T]):
address: Optional[_Address] = None, address: Optional[_Address] = None,
retryable: bool = False, retryable: bool = False,
operation_id: Optional[int] = None, operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
): ):
self._last_error: Optional[Exception] = None self._last_error: Optional[Exception] = None
self._retrying = False self._retrying = False
self._always_retryable = False
self._multiple_retries = _csot.get_timeout() is not None self._multiple_retries = _csot.get_timeout() is not None
self._client = mongo_client self._client = mongo_client
self._retry_policy = mongo_client._retry_policy
self._func = func self._func = func
self._bulk = bulk self._bulk = bulk
self._session = session self._session = session
@ -2757,6 +2787,8 @@ class _ClientConnectionRetryable(Generic[T]):
self._operation = operation self._operation = operation
self._operation_id = operation_id self._operation_id = operation_id
self._attempt_number = 0 self._attempt_number = 0
self._is_run_command = is_run_command
self._is_aggregate_write = is_aggregate_write
def run(self) -> T: def run(self) -> T:
"""Runs the supplied func() and attempts a retry """Runs the supplied func() and attempts a retry
@ -2776,7 +2808,13 @@ class _ClientConnectionRetryable(Generic[T]):
while True: while True:
self._check_last_error(check_csot=True) self._check_last_error(check_csot=True)
try: try:
return self._read() if self._is_read else self._write() res = self._read() if self._is_read else self._write()
# Track whether the transaction has completed a command.
# If we need to apply backpressure to the first command,
# we will need to revert back to starting state.
if self._session is not None and self._session.in_transaction:
self._session._transaction.has_completed_command = True
return res
except ServerSelectionTimeoutError: except ServerSelectionTimeoutError:
# The application may think the write was never attempted # The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry # if we raise ServerSelectionTimeoutError on the retry
@ -2787,37 +2825,76 @@ class _ClientConnectionRetryable(Generic[T]):
# most likely be a waste of time. # most likely be a waste of time.
raise raise
except PyMongoError as exc: except PyMongoError as exc:
always_retryable = False
overloaded = False
exc_to_check = exc
if self._is_run_command and not (
self._client.options.retry_reads and self._client.options.retry_writes
):
raise
if self._is_aggregate_write and not self._client.options.retry_writes:
raise
# Execute specialized catch on read # Execute specialized catch on read
if self._is_read: if self._is_read:
if isinstance(exc, (ConnectionFailure, OperationFailure)): if isinstance(exc, (ConnectionFailure, OperationFailure)):
# ConnectionFailures do not supply a code property # ConnectionFailures do not supply a code property
exc_code = getattr(exc, "code", None) exc_code = getattr(exc, "code", None)
if self._is_not_eligible_for_retry() or ( overloaded = exc.has_error_label("SystemOverloadedError")
isinstance(exc, OperationFailure) always_retryable = exc.has_error_label("RetryableError") and overloaded
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES if not self._client.options.retry_reads or (
not always_retryable
and (
self._is_not_eligible_for_retry()
or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
)
)
): ):
raise raise
self._retrying = True self._retrying = True
self._last_error = exc self._last_error = exc
self._attempt_number += 1 self._attempt_number += 1
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if (
overloaded
and self._session is not None
and self._session.in_transaction
):
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
else: else:
raise raise
# Specialized catch on write operation # Specialized catch on write operation
if not self._is_read: if not self._is_read:
if not self._retryable: if isinstance(exc, ClientBulkWriteException) and isinstance(
exc.error, PyMongoError
):
exc_to_check = exc.error
retryable_write_label = exc_to_check.has_error_label("RetryableWriteError")
overloaded = exc_to_check.has_error_label("SystemOverloadedError")
always_retryable = exc_to_check.has_error_label("RetryableError") and overloaded
# Always retry abortTransaction and commitTransaction up to once
if self._operation not in ["abortTransaction", "commitTransaction"] and (
not self._client.options.retry_writes
or not (self._retryable or always_retryable)
):
raise raise
if isinstance(exc, ClientBulkWriteException) and exc.error: if retryable_write_label or always_retryable:
retryable_write_error_exc = isinstance(
exc.error, PyMongoError
) and exc.error.has_error_label("RetryableWriteError")
else:
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
if retryable_write_error_exc:
assert self._session assert self._session
self._session._unpin() self._session._unpin()
if not retryable_write_error_exc or self._is_not_eligible_for_retry(): if not always_retryable and (
if exc.has_error_label("NoWritesPerformed") and self._last_error: not retryable_write_label or self._is_not_eligible_for_retry()
):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc raise self._last_error from exc
else: else:
raise raise
@ -2826,18 +2903,34 @@ class _ClientConnectionRetryable(Generic[T]):
self._bulk.retrying = True self._bulk.retrying = True
else: else:
self._retrying = True self._retrying = True
if not exc.has_error_label("NoWritesPerformed"): if not exc_to_check.has_error_label("NoWritesPerformed"):
self._last_error = exc self._last_error = exc
if self._last_error is None: if self._last_error is None:
self._last_error = exc self._last_error = exc
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if overloaded and self._session is not None and self._session.in_transaction:
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
if ( if self._server is not None and (
self._server is not None self._client.topology_description.topology_type_name == "Sharded"
and self._client.topology_description.topology_type_name == "Sharded" or (overloaded and self._client.options.enable_overload_retargeting)
or exc.has_error_label("SystemOverloadedError")
): ):
self._deprioritized_servers.append(self._server) self._deprioritized_servers.append(self._server)
self._always_retryable = always_retryable
if overloaded:
delay = self._retry_policy.backoff(self._attempt_number)
if not self._retry_policy.should_retry(self._attempt_number, delay):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
time.sleep(delay)
def _is_not_eligible_for_retry(self) -> bool: def _is_not_eligible_for_retry(self) -> bool:
"""Checks if the exchange is not eligible for retry""" """Checks if the exchange is not eligible for retry"""
return not self._retryable or (self._is_retrying() and not self._multiple_retries) return not self._retryable or (self._is_retrying() and not self._multiple_retries)
@ -2899,7 +2992,7 @@ class _ClientConnectionRetryable(Generic[T]):
and conn.supports_sessions and conn.supports_sessions
) )
is_mongos = conn.is_mongos is_mongos = conn.is_mongos
if not sessions_supported: if not self._always_retryable and not sessions_supported:
# A retry is not possible because this server does # A retry is not possible because this server does
# not support sessions raise the last error. # not support sessions raise the last error.
self._check_last_error() self._check_last_error()
@ -2931,7 +3024,7 @@ class _ClientConnectionRetryable(Generic[T]):
conn, conn,
read_pref, read_pref,
): ):
if self._retrying and not self._retryable: if self._retrying and not self._retryable and not self._always_retryable:
self._check_last_error() self._check_last_error()
if self._retrying: if self._retrying:
_debug_log( _debug_log(

View File

@ -19,6 +19,8 @@ import collections
import contextlib import contextlib
import logging import logging
import os import os
import socket
import ssl
import sys import sys
import time import time
import weakref import weakref
@ -49,10 +51,12 @@ from pymongo.errors import ( # type:ignore[attr-defined]
DocumentTooLarge, DocumentTooLarge,
ExecutionTimeout, ExecutionTimeout,
InvalidOperation, InvalidOperation,
NetworkTimeout,
NotPrimaryError, NotPrimaryError,
OperationFailure, OperationFailure,
PyMongoError, PyMongoError,
WaitQueueTimeoutError, WaitQueueTimeoutError,
_CertificateError,
) )
from pymongo.hello import Hello, HelloCompat from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _get_timeout_details, format_timeout_details from pymongo.helpers_shared import _get_timeout_details, format_timeout_details
@ -250,6 +254,7 @@ class Connection:
cmd = self.hello_cmd() cmd = self.hello_cmd()
performing_handshake = not self.performed_handshake performing_handshake = not self.performed_handshake
awaitable = False awaitable = False
cmd["backpressure"] = True
if performing_handshake: if performing_handshake:
self.performed_handshake = True self.performed_handshake = True
cmd["client"] = self.opts.metadata cmd["client"] = self.opts.metadata
@ -750,8 +755,8 @@ class Pool:
# Enforces: maxConnecting # Enforces: maxConnecting
# Also used for: clearing the wait queue # Also used for: clearing the wait queue
self._max_connecting_cond = _create_condition(self.lock) self._max_connecting_cond = _create_condition(self.lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0 self._pending = 0
self._max_connecting = self.opts.max_connecting
self._client_id = client_id self._client_id = client_id
if self.enabled_for_cmap: if self.enabled_for_cmap:
assert self.opts._event_listeners is not None assert self.opts._event_listeners is not None
@ -982,6 +987,21 @@ class Pool:
self.requests -= 1 self.requests -= 1
self.size_cond.notify() self.size_cond.notify()
def _handle_connection_error(self, error: BaseException) -> None:
# Handle system overload condition for non-sdam pools.
# Look for errors of type AutoReconnect and add error labels if appropriate.
if self.is_sdam or type(error) not in (AutoReconnect, NetworkTimeout):
return
assert isinstance(error, AutoReconnect) # Appease type checker.
# If the original error was a DNS, certificate, or SSL error, ignore it.
if isinstance(error.__cause__, (_CertificateError, SSLErrors, socket.gaierror)):
# End of file errors are excluded, because the server may have disconnected
# during the handshake.
if not isinstance(error.__cause__, (ssl.SSLEOFError, ssl.SSLZeroReturnError)):
return
error._add_error_label("SystemOverloadedError")
error._add_error_label("RetryableError")
def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection: def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection:
"""Connect to Mongo and return a new Connection. """Connect to Mongo and return a new Connection.
@ -1033,10 +1053,10 @@ class Pool:
reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR), reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
error=ConnectionClosedReason.ERROR, error=ConnectionClosedReason.ERROR,
) )
self._handle_connection_error(error)
if isinstance(error, (IOError, OSError, *SSLErrors)): if isinstance(error, (IOError, OSError, *SSLErrors)):
details = _get_timeout_details(self.opts) details = _get_timeout_details(self.opts)
_raise_connection_failure(self.address, error, timeout_details=details) _raise_connection_failure(self.address, error, timeout_details=details)
raise raise
conn = Connection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type] conn = Connection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
@ -1045,18 +1065,22 @@ class Pool:
self.active_contexts.discard(tmp_context) self.active_contexts.discard(tmp_context)
if tmp_context.cancelled: if tmp_context.cancelled:
conn.cancel_context.cancel() conn.cancel_context.cancel()
completed_hello = False
try: try:
if not self.is_sdam: if not self.is_sdam:
conn.hello() conn.hello()
completed_hello = True
self.is_writable = conn.is_writable self.is_writable = conn.is_writable
if handler: if handler:
handler.contribute_socket(conn, completed_handshake=False) handler.contribute_socket(conn, completed_handshake=False)
conn.authenticate() conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup. # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException: except BaseException as e:
with self.lock: with self.lock:
self.active_contexts.discard(conn.cancel_context) self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
conn.close_conn(ConnectionClosedReason.ERROR) conn.close_conn(ConnectionClosedReason.ERROR)
raise raise
@ -1385,8 +1409,8 @@ class Pool:
:class:`~pymongo.errors.AutoReconnect` exceptions on server :class:`~pymongo.errors.AutoReconnect` exceptions on server
hiccups, etc. We only check if the socket was closed by an external hiccups, etc. We only check if the socket was closed by an external
error if it has been > 1 second since the socket was checked into the error if it has been > 1 second since the socket was checked into the
pool, to keep performance reasonable - we can't avoid AutoReconnects pool to keep performance reasonable -
completely anyway. we can't avoid AutoReconnects completely anyway.
""" """
idle_time_seconds = conn.idle_time_seconds() idle_time_seconds = conn.idle_time_seconds()
# If socket is idle, open a new one. # If socket is idle, open a new one.
@ -1397,8 +1421,9 @@ class Pool:
conn.close_conn(ConnectionClosedReason.IDLE) conn.close_conn(ConnectionClosedReason.IDLE)
return True return True
if self._check_interval_seconds is not None and ( check_interval_seconds = self._check_interval_seconds
self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds if check_interval_seconds is not None and (
check_interval_seconds == 0 or idle_time_seconds > check_interval_seconds
): ):
if conn.conn_closed(): if conn.conn_closed():
conn.close_conn(ConnectionClosedReason.ERROR) conn.close_conn(ConnectionClosedReason.ERROR)

View File

@ -911,7 +911,9 @@ class Topology:
# Clear the pool. # Clear the pool.
server.reset(service_id) server.reset(service_id)
elif isinstance(error, ConnectionFailure): elif isinstance(error, ConnectionFailure):
if isinstance(error, WaitQueueTimeoutError): if isinstance(error, WaitQueueTimeoutError) or (
error.has_error_label("SystemOverloadedError")
):
return return
# "Client MUST replace the server's description with type Unknown # "Client MUST replace the server's description with type Unknown
# ... MUST NOT request an immediate check of the server." # ... MUST NOT request an immediate check of the server."

View File

@ -652,6 +652,38 @@ class AsyncClientUnitTest(AsyncUnitTest):
with self.assertWarns(UserWarning): with self.assertWarns(UserWarning):
self.simple_client(multi_host) self.simple_client(multi_host)
async def test_max_adaptive_retries(self):
# Assert that max adaptive retries defaults to 2.
c = self.simple_client(connect=False)
self.assertEqual(c.options.max_adaptive_retries, 2)
# Assert that max adaptive retries can be configured through connection or client options.
c = self.simple_client(connect=False, max_adaptive_retries=10)
self.assertEqual(c.options.max_adaptive_retries, 10)
c = self.simple_client(connect=False, maxAdaptiveRetries=10)
self.assertEqual(c.options.max_adaptive_retries, 10)
c = self.simple_client(host="mongodb://localhost/?maxAdaptiveRetries=10", connect=False)
self.assertEqual(c.options.max_adaptive_retries, 10)
async def test_enable_overload_retargeting(self):
# Assert that overload retargeting defaults to false.
c = self.simple_client(connect=False)
self.assertFalse(c.options.enable_overload_retargeting)
# Assert that overload retargeting can be enabled through connection or client options.
c = self.simple_client(connect=False, enable_overload_retargeting=True)
self.assertTrue(c.options.enable_overload_retargeting)
c = self.simple_client(connect=False, enableOverloadRetargeting=True)
self.assertTrue(c.options.enable_overload_retargeting)
c = self.simple_client(
host="mongodb://localhost/?enableOverloadRetargeting=true", connect=False
)
self.assertTrue(c.options.enable_overload_retargeting)
class TestClient(AsyncIntegrationTest): class TestClient(AsyncIntegrationTest):
def test_multiple_uris(self): def test_multiple_uris(self):

View File

@ -0,0 +1,312 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Client Backpressure spec."""
from __future__ import annotations
import os
import pathlib
import sys
from time import perf_counter
from unittest.mock import patch
from pymongo.common import MAX_ADAPTIVE_RETRIES
sys.path[0:0] = [""]
from test.asynchronous import (
AsyncIntegrationTest,
async_client_context,
unittest,
)
from test.asynchronous.unified_format import generate_test_classes
from test.utils_shared import EventListener, OvertCommandListener
from pymongo.errors import OperationFailure, PyMongoError
_IS_SYNC = False
# Mock a system overload error.
mock_overload_error = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find", "insert", "update"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def get_mock_overload_error(times: int):
error = mock_overload_error.copy()
error["mode"] = {"times": times}
return error
class TestBackpressure(AsyncIntegrationTest):
RUN_ON_LOAD_BALANCER = True
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_command(self):
await self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
async with self.fail_point(fail_many):
await self.db.command("find", "t")
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await self.db.command("find", "t")
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_find(self):
await self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
async with self.fail_point(fail_many):
await self.db.t.find_one()
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await self.db.t.find_one()
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_insert_one(self):
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
async with self.fail_point(fail_many):
await self.db.t.insert_one({"x": 1})
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await self.db.t.insert_one({"x": 1})
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_update_many(self):
# Even though update_many is not a retryable write operation, it will
# still be retried via the "RetryableError" error label.
await self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
async with self.fail_point(fail_many):
await self.db.t.update_many({}, {"$set": {"x": 2}})
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await self.db.t.update_many({}, {"$set": {"x": 2}})
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_getMore(self):
coll = self.db.t
await coll.insert_many([{"x": 1} for _ in range(10)])
# Ensure command is retried on overload error.
fail_many = {
"configureFailPoint": "failCommand",
"mode": {"times": MAX_ADAPTIVE_RETRIES},
"data": {
"failCommands": ["getMore"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
cursor = coll.find(batch_size=2)
await cursor.next()
async with self.fail_point(fail_many):
await cursor.to_list()
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = fail_many.copy()
fail_too_many["mode"] = {"times": MAX_ADAPTIVE_RETRIES + 1}
cursor = coll.find(batch_size=2)
await cursor.next()
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await cursor.to_list()
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# Prose tests.
class AsyncTestClientBackpressure(AsyncIntegrationTest):
listener: EventListener
@classmethod
def setUpClass(cls) -> None:
cls.listener = OvertCommandListener()
@async_client_context.require_connection
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.listener.reset()
self.app_name = self.__class__.__name__.lower()
self.client = await self.async_rs_or_single_client(
event_listeners=[self.listener], appName=self.app_name
)
@patch("random.random")
@async_client_context.require_failCommand_appName
async def test_01_operation_retry_uses_exponential_backoff(self, random_func):
# Drivers should test that retries do not occur immediately when a SystemOverloadedError is encountered.
# 1. let `client` be a `MongoClient`
client = self.client
# 2. let `collection` be a collection
collection = client.test.test
# 3. Now, run transactions without backoff:
# a. Configure the random number generator used for jitter to always return `0` -- this effectively disables backoff.
random_func.return_value = 0
# b. Configure the following failPoint:
fail_point = dict(
mode="alwaysOn",
data=dict(
failCommands=["insert"],
errorCode=2,
errorLabels=["SystemOverloadedError", "RetryableError"],
appName=self.app_name,
),
)
async with self.fail_point(fail_point):
# c. Execute the following command. Expect that the command errors. Measure the duration of the command execution.
start0 = perf_counter()
with self.assertRaises(OperationFailure):
await collection.insert_one({"a": 1})
end0 = perf_counter()
# d. Configure the random number generator used for jitter to always return `1`.
random_func.return_value = 1
# e. Execute step c again.
start1 = perf_counter()
with self.assertRaises(OperationFailure):
await collection.insert_one({"a": 1})
end1 = perf_counter()
# f. Compare the times between the two runs.
# The sum of 2 backoffs is 0.3 seconds. There is a 0.3-second window to account for potential variance between the two
# runs.
self.assertTrue(abs((end1 - start1) - (end0 - start0 + 0.3)) < 0.3)
@async_client_context.require_failCommand_appName
async def test_03_overload_retries_limited(self):
# Drivers should test that overload errors are retried a maximum of two times.
# 1. Let `client` be a `MongoClient`.
client = self.client
# 2. Let `coll` be a collection.
coll = client.pymongo_test.coll
# 3. Configure the following failpoint:
failpoint = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
# 4. Perform a find operation with `coll` that fails.
async with self.fail_point(failpoint):
with self.assertRaises(PyMongoError) as error:
await coll.find_one({})
# 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels.
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# 6. Assert that the total number of started commands is MAX_ADAPTIVE_RETRIES + 1.
self.assertEqual(len(self.listener.started_events), MAX_ADAPTIVE_RETRIES + 1)
@async_client_context.require_failCommand_appName
async def test_04_overload_retries_limited_configured(self):
# Drivers should test that overload errors are retried a maximum of maxAdaptiveRetries times.
max_retries = 1
# 1. Let `client` be a `MongoClient` with `maxAdaptiveRetries=1` and command event monitoring enabled.
client = await self.async_single_client(
maxAdaptiveRetries=max_retries, event_listeners=[self.listener]
)
# 2. Let `coll` be a collection.
coll = client.pymongo_test.coll
# 3. Configure the following failpoint:
failpoint = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
# 4. Perform a find operation with `coll` that fails.
async with self.fail_point(failpoint):
with self.assertRaises(PyMongoError) as error:
await coll.find_one({})
# 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels.
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# 6. Assert that the total number of started commands is max_retries + 1.
self.assertEqual(len(self.listener.started_events), max_retries + 1)
# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "client-backpressure")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "client-backpressure")
globals().update(
generate_test_classes(
_TEST_PATH,
module=__name__,
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -219,6 +219,19 @@ class TestClientMetadataProse(AsyncIntegrationTest):
# add same metadata again # add same metadata again
await self.check_metadata_added(client, "Framework", None, None) await self.check_metadata_added(client, "Framework", None, None)
async def test_handshake_documents_include_backpressure(self):
# Create a `MongoClient` that is configured to record all handshake documents sent to the server as a part of
# connection establishment.
client = await self.async_rs_or_single_client("mongodb://" + self.server.address_string)
# Send a `ping` command to the server and verify that the command succeeds. This ensure that a connection is
# established on all topologies. Note: MockupDB only supports standalone servers.
await client.admin.command("ping")
# Assert that for every handshake document intercepted:
# the document has a field `backpressure` whose value is `true`.
self.assertEqual(self.handshake_req["backpressure"], True)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -25,8 +25,10 @@ from asyncio import StreamReader, StreamWriter
from pathlib import Path from pathlib import Path
from test.asynchronous.helpers import ConcurrentRunner from test.asynchronous.helpers import ConcurrentRunner
from test.asynchronous.utils import flaky from test.asynchronous.utils import flaky
from test.utils_shared import delay
from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.pool import AsyncConnection
from pymongo.errors import ConnectionFailure
from pymongo.operations import _Op from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector from pymongo.server_selectors import writable_server_selector
@ -70,7 +72,12 @@ from pymongo.errors import (
) )
from pymongo.hello import Hello, HelloCompat from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _check_command_response, _check_write_command_response from pymongo.helpers_shared import _check_command_response, _check_write_command_response
from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent from pymongo.monitoring import (
ConnectionCheckOutFailedEvent,
PoolClearedEvent,
ServerHeartbeatFailedEvent,
ServerHeartbeatStartedEvent,
)
from pymongo.server_description import SERVER_TYPE, ServerDescription from pymongo.server_description import SERVER_TYPE, ServerDescription
from pymongo.topology_description import TOPOLOGY_TYPE from pymongo.topology_description import TOPOLOGY_TYPE
@ -131,6 +138,9 @@ async def got_app_error(topology, app_error):
raise AssertionError raise AssertionError
except (AutoReconnect, NotPrimaryError, OperationFailure) as e: except (AutoReconnect, NotPrimaryError, OperationFailure) as e:
if when == "beforeHandshakeCompletes": if when == "beforeHandshakeCompletes":
# The pool would have added the SystemOverloadedError in this case.
if isinstance(e, AutoReconnect):
e._add_error_label("SystemOverloadedError")
completed_handshake = False completed_handshake = False
elif when == "afterHandshakeCompletes": elif when == "afterHandshakeCompletes":
completed_handshake = True completed_handshake = True
@ -439,6 +449,59 @@ class TestPoolManagement(AsyncIntegrationTest):
AsyncConnection.close_conn = original_close AsyncConnection.close_conn = original_close
class TestPoolBackpressure(AsyncIntegrationTest):
@async_client_context.require_version_min(7, 0, 0)
async def test_connection_pool_is_not_cleared(self):
listener = CMAPListener()
# Create a client that listens to CMAP events, with maxConnecting=100.
client = await self.async_rs_or_single_client(maxConnecting=100, event_listeners=[listener])
# Enable the ingress rate limiter.
await client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=True
)
await client.admin.command("setParameter", 1, ingressConnectionEstablishmentRatePerSec=20)
await client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentBurstCapacitySecs=1
)
await client.admin.command("setParameter", 1, ingressConnectionEstablishmentMaxQueueDepth=1)
# Disable the ingress rate limiter on teardown.
# Sleep for 1 second before disabling to avoid the rate limiter.
async def teardown():
await asyncio.sleep(1)
await client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=False
)
self.addAsyncCleanup(teardown)
# Make sure the collection has at least one document.
await client.test.test.delete_many({})
await client.test.test.insert_one({})
# Run a slow operation to tie up the connection.
async def target():
try:
await client.test.test.find_one({"$where": delay(0.1)})
except ConnectionFailure:
pass
# Run 100 parallel operations that contend for connections.
tasks = []
for _ in range(100):
tasks.append(ConcurrentRunner(target=target))
for t in tasks:
await t.start()
for t in tasks:
await t.join()
# Verify there were at least 10 connection checkout failed event but no pool cleared events.
self.assertGreater(len(listener.events_by_type(ConnectionCheckOutFailedEvent)), 10)
self.assertEqual(len(listener.events_by_type(PoolClearedEvent)), 0)
class TestServerMonitoringMode(AsyncIntegrationTest): class TestServerMonitoringMode(AsyncIntegrationTest):
@async_client_context.require_no_load_balancer @async_client_context.require_no_load_balancer
async def asyncSetUp(self): async def asyncSetUp(self):

View File

@ -513,6 +513,39 @@ class TestPooling(_TestPoolingBase):
str(error.exception), str(error.exception),
) )
@async_client_context.require_failCommand_appName
async def test_pool_backpressure_preserves_existing_connections(self):
client = await self.async_rs_or_single_client()
coll = client.pymongo_test.t
pool = await async_get_pool(client)
await coll.insert_many([{"x": 1} for _ in range(10)])
t = SocketGetter(self.c, pool)
await t.start()
while t.state != "connection":
await asyncio.sleep(0.1)
assert not t.sock.conn_closed()
# Mock a session establishment overload.
mock_connection_fail = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"closeConnection": True,
},
}
async with self.fail_point(mock_connection_fail):
await coll.find_one({})
# Make sure the existing socket was not affected.
assert not t.sock.conn_closed()
# Cleanup
await t.release_conn()
await t.join()
await pool.close()
class TestPoolMaxSize(_TestPoolingBase): class TestPoolMaxSize(_TestPoolingBase):
async def test_max_pool_size(self): async def test_max_pool_size(self):

View File

@ -265,14 +265,17 @@ class TestRetryableReads(AsyncIntegrationTest):
@async_client_context.require_secondaries_count(1) @async_client_context.require_secondaries_count(1)
@async_client_context.require_failCommand_fail_point @async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, 0) @async_client_context.require_version_min(4, 4, 0)
async def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available( async def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available_and_overload_retargeting_is_enabled(
self self
): ):
listener = OvertCommandListener() listener = OvertCommandListener()
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled. # 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, `enableOverloadRetargeting=True`, and command event monitoring enabled.
client = await self.async_rs_or_single_client( client = await self.async_rs_or_single_client(
event_listeners=[listener], retryReads=True, readPreference="primaryPreferred" event_listeners=[listener],
retryReads=True,
readPreference="primaryPreferred",
enableOverloadRetargeting=True,
) )
# 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels. # 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels.
@ -339,6 +342,47 @@ class TestRetryableReads(AsyncIntegrationTest):
# 6. Assert that both events occurred the same server. # 6. Assert that both events occurred the same server.
assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id
@async_client_context.require_replica_set
@async_client_context.require_secondaries_count(1)
@async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, 0)
async def test_03_03_retryable_reads_caused_by_overload_errors_are_retried_on_the_same_replicaset_server_when_one_is_available_and_overload_retargeting_is_disabled(
self
):
listener = OvertCommandListener()
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled.
client = await self.async_rs_or_single_client(
event_listeners=[listener],
retryReads=True,
readPreference="primaryPreferred",
)
# 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 6,
},
}
await async_set_fail_point(client, command_args)
# 3. Reset the command event monitor to clear the fail point command from its stored events.
listener.reset()
# 4. Execute a `find` command with `client`.
await client.t.t.find_one({})
# 5. Assert that one failed command event and one successful command event occurred.
self.assertEqual(len(listener.failed_events), 1)
self.assertEqual(len(listener.succeeded_events), 1)
# 6. Assert that both events occurred on the same server.
assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -43,14 +43,17 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.int64 import Int64 from bson.int64 import Int64
from bson.raw_bson import RawBSONDocument from bson.raw_bson import RawBSONDocument
from bson.son import SON from bson.son import SON
from pymongo import MongoClient
from pymongo.errors import ( from pymongo.errors import (
AutoReconnect, AutoReconnect,
ConnectionFailure, ConnectionFailure,
OperationFailure, NotPrimaryError,
PyMongoError,
ServerSelectionTimeoutError, ServerSelectionTimeoutError,
WriteConcernError, WriteConcernError,
) )
from pymongo.monitoring import ( from pymongo.monitoring import (
CommandFailedEvent,
CommandSucceededEvent, CommandSucceededEvent,
ConnectionCheckedOutEvent, ConnectionCheckedOutEvent,
ConnectionCheckOutFailedEvent, ConnectionCheckOutFailedEvent,
@ -601,5 +604,186 @@ class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest):
self.assertEqual(sent_txn_id, final_txn_id, msg) self.assertEqual(sent_txn_id, final_txn_id, msg)
class TestErrorPropagationAfterEncounteringMultipleErrors(AsyncIntegrationTest):
# Only run against replica sets as mongos does not propagate the NoWritesPerformed label to the drivers.
@async_client_context.require_replica_set
# Run against server versions 6.0 and above.
@async_client_context.require_version_min(6, 0) # type: ignore[untyped-decorator]
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.setup_client = MongoClient(**async_client_context.default_client_options)
self.addCleanup(self.setup_client.close)
# TODO: After PYTHON-4595 we can use async event handlers and remove this workaround.
def configure_fail_point_sync(self, command_args, off=False) -> None:
cmd = {"configureFailPoint": "failCommand"}
cmd.update(command_args)
if off:
cmd["mode"] = "off"
cmd.pop("data", None)
self.setup_client.admin.command(cmd)
async def test_01_drivers_return_the_correct_error_when_receiving_only_errors_without_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Via the command monitoring CommandFailedEvent, configure a fail point with error code 10107 (NotWritablePrimary).
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 10107,
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = await self.async_rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(NotPrimaryError) as exc:
await client.test.test.insert_one({})
# Assert that the error code of the server error is 10107.
assert exc.exception.errors["code"] == 10107 # type:ignore[call-overload]
async def test_02_drivers_return_the_correct_error_when_receiving_only_errors_with_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
"errorCode": 91,
},
}
# Via the command monitoring CommandFailedEvent, configure a fail point with error code `10107` (NotWritablePrimary)
# and a NoWritesPerformed label.
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 10107,
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
},
}
def failed(event: CommandFailedEvent) -> None:
if listener.failed_events:
return
# Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2.
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = await self.async_rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(NotPrimaryError) as exc:
await client.test.test.insert_one({})
# Assert that the error code of the server error is 91.
assert exc.exception.errors["code"] == 91 # type:ignore[call-overload]
async def test_03_drivers_return_the_correct_error_when_receiving_some_errors_with_NoWritesPerformed_and_some_without_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (NotWritablePrimary) and the `NoWritesPerformed`, `RetryableError` and `SystemOverloadedError` labels.
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with the `RetryableError` and
# `SystemOverloadedError` error labels but without the `NoWritesPerformed` error label.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorCode": 91,
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = await self.async_rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(PyMongoError) as exc:
await client.test.test.insert_one({})
# Assert that the error code of the server error is 91.
assert exc.exception.errors["code"] == 91
# Assert that the error does not contain the error label `NoWritesPerformed`.
assert "NoWritesPerformed" not in exc.exception.errors["errorLabels"]
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -16,9 +16,13 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import random
import sys import sys
import time
from io import BytesIO from io import BytesIO
from unittest.mock import patch
import pymongo
from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket
from pymongo.asynchronous.pool import PoolState from pymongo.asynchronous.pool import PoolState
from pymongo.server_selectors import writable_server_selector from pymongo.server_selectors import writable_server_selector
@ -45,7 +49,9 @@ from pymongo.errors import (
CollectionInvalid, CollectionInvalid,
ConfigurationError, ConfigurationError,
ConnectionFailure, ConnectionFailure,
ExecutionTimeout,
InvalidOperation, InvalidOperation,
NetworkTimeout,
OperationFailure, OperationFailure,
) )
from pymongo.operations import IndexModel, InsertOne from pymongo.operations import IndexModel, InsertOne
@ -434,7 +440,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
await self.configure_fail_point(client, command_args) await self.configure_fail_point(client, command_args)
@async_client_context.require_transactions @async_client_context.require_transactions
async def test_callback_raises_custom_error(self): async def test_1_callback_raises_custom_error(self):
class _MyException(Exception): class _MyException(Exception):
pass pass
@ -446,7 +452,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
await s.with_transaction(raise_error) await s.with_transaction(raise_error)
@async_client_context.require_transactions @async_client_context.require_transactions
async def test_callback_returns_value(self): async def test_2_callback_returns_value(self):
async def callback(_): async def callback(_):
return "Foo" return "Foo"
@ -474,7 +480,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
self.assertEqual(await s.with_transaction(callback), "Foo") self.assertEqual(await s.with_transaction(callback), "Foo")
@async_client_context.require_transactions @async_client_context.require_transactions
async def test_callback_not_retried_after_timeout(self): async def test_3_1_callback_not_retried_after_timeout(self):
listener = OvertCommandListener() listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener]) client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test coll = client[self.db.name].test
@ -495,14 +501,16 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
listener.reset() listener.reset()
async with client.start_session() as s: async with client.start_session() as s:
with PatchSessionTimeout(0): with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure): with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback) await s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"]) self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@async_client_context.require_test_commands @async_client_context.require_test_commands
@async_client_context.require_transactions @async_client_context.require_transactions
async def test_callback_not_retried_after_commit_timeout(self): async def test_3_2_callback_not_retried_after_commit_timeout(self):
listener = OvertCommandListener() listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener]) client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test coll = client[self.db.name].test
@ -529,14 +537,16 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
async with client.start_session() as s: async with client.start_session() as s:
with PatchSessionTimeout(0): with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure): with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback) await s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"]) self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@async_client_context.require_test_commands @async_client_context.require_test_commands
@async_client_context.require_transactions @async_client_context.require_transactions
async def test_commit_not_retried_after_timeout(self): async def test_3_3_commit_not_retried_after_timeout(self):
listener = OvertCommandListener() listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener]) client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test coll = client[self.db.name].test
@ -560,7 +570,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
async with client.start_session() as s: async with client.start_session() as s:
with PatchSessionTimeout(0): with PatchSessionTimeout(0):
with self.assertRaises(ConnectionFailure): with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback) await s.with_transaction(callback)
# One insert for the callback and two commits (includes the automatic # One insert for the callback and two commits (includes the automatic
@ -568,6 +578,40 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
self.assertEqual( self.assertEqual(
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"] listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
) )
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult"))
@async_client_context.require_transactions
async def test_callback_not_retried_after_csot_timeout(self):
listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
async def callback(session):
await coll.insert_one({}, session=session)
err: dict = {
"ok": 0,
"errmsg": "Transaction 7819 has been aborted.",
"code": 251,
"codeName": "NoSuchTransaction",
"errorLabels": ["TransientTransactionError"],
}
raise OperationFailure(err["errmsg"], err["code"], err)
# Create the collection.
await coll.insert_one({})
listener.reset()
async with client.start_session() as s:
with pymongo.timeout(1.0):
with self.assertRaises(ExecutionTimeout):
await s.with_transaction(callback)
# At least two attempts: the original and one or more retries.
inserts = len([x for x in listener.started_command_names() if x == "insert"])
aborts = len([x for x in listener.started_command_names() if x == "abortTransaction"])
self.assertGreaterEqual(inserts, 2)
self.assertGreaterEqual(aborts, 2)
# Tested here because this supports Motor's convenient transactions API. # Tested here because this supports Motor's convenient transactions API.
@async_client_context.require_transactions @async_client_context.require_transactions
@ -606,6 +650,63 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
await s.with_transaction(callback) await s.with_transaction(callback)
self.assertFalse(s.in_transaction) self.assertFalse(s.in_transaction)
@async_client_context.require_test_commands
@async_client_context.require_transactions
async def test_4_retry_backoff_is_enforced(self):
client = async_client_context.client
coll = client[self.db.name].test
end = start = no_backoff_time = 0
# Make random.random always return 0 (no backoff)
with patch.object(random, "random", return_value=0):
# set fail point to trigger transaction failure and trigger backoff
await self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {"times": 13},
"data": {
"failCommands": ["commitTransaction"],
"errorCode": 251,
},
}
)
self.addAsyncCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
async def callback(session):
await coll.insert_one({}, session=session)
start = time.monotonic()
async with self.client.start_session() as s:
await s.with_transaction(callback)
end = time.monotonic()
no_backoff_time = end - start
# Make random.random always return 1 (max backoff)
with patch.object(random, "random", return_value=1):
# set fail point to trigger transaction failure and trigger backoff
await self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {
"times": 13
}, # sufficiently high enough such that the time effect of backoff is noticeable
"data": {
"failCommands": ["commitTransaction"],
"errorCode": 251,
},
}
)
self.addAsyncCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
start = time.monotonic()
async with self.client.start_session() as s:
await s.with_transaction(callback)
end = time.monotonic()
self.assertLess(abs(end - start - (no_backoff_time + 2.2)), 1) # sum of 13 backoffs is 2.2
class TestOptionsInsideTransactionProse(AsyncTransactionsBase): class TestOptionsInsideTransactionProse(AsyncTransactionsBase):
@async_client_context.require_transactions @async_client_context.require_transactions

View File

@ -0,0 +1,111 @@
{
"description": "tests that connections are returned to the pool on retry attempts for overload errors",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"topologies": [
"replicaset",
"sharded",
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client",
"useMultipleMongoses": false,
"observeEvents": [
"connectionCheckedOutEvent",
"connectionCheckedInEvent"
]
}
},
{
"client": {
"id": "fail_point_client",
"useMultipleMongoses": false
}
},
{
"database": {
"id": "database",
"client": "client",
"databaseName": "backpressure-connection-checkin"
}
},
{
"collection": {
"id": "collection",
"database": "database",
"collectionName": "coll"
}
}
],
"tests": [
{
"description": "overload error retry attempts return connections to the pool",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "fail_point_client",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"find"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 2
}
}
}
},
{
"name": "find",
"object": "collection",
"arguments": {
"filter": {}
},
"expectError": {
"isError": true,
"isClientError": false
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "cmap",
"events": [
{
"connectionCheckedOutEvent": {}
},
{
"connectionCheckedInEvent": {}
},
{
"connectionCheckedOutEvent": {}
},
{
"connectionCheckedInEvent": {}
},
{
"connectionCheckedOutEvent": {}
},
{
"connectionCheckedInEvent": {}
}
]
}
]
}
]
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,253 @@
{
"description": "getMore-retried-backpressure",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"minServerVersion": "4.4"
}
],
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": false,
"observeEvents": [
"commandStartedEvent",
"commandFailedEvent",
"commandSucceededEvent"
]
}
},
{
"client": {
"id": "failPointClient",
"useMultipleMongoses": false
}
},
{
"database": {
"id": "db",
"client": "client0",
"databaseName": "default"
}
},
{
"collection": {
"id": "coll",
"database": "db",
"collectionName": "default"
}
}
],
"initialData": [
{
"databaseName": "default",
"collectionName": "default",
"documents": [
{
"a": 1
},
{
"a": 2
},
{
"a": 3
}
]
}
],
"tests": [
{
"description": "getMores are retried",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "failPointClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"getMore"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 2
}
}
}
},
{
"name": "find",
"object": "coll",
"arguments": {
"batchSize": 2,
"filter": {},
"sort": {
"a": 1
}
},
"expectResult": [
{
"a": 1
},
{
"a": 2
},
{
"a": 3
}
]
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"commandName": "find"
}
},
{
"commandSucceededEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandSucceededEvent": {
"commandName": "getMore"
}
}
]
}
]
},
{
"description": "getMores are retried maxAttempts=2 times",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "failPointClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"getMore"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 2
}
}
}
},
{
"name": "find",
"arguments": {
"batchSize": 2,
"filter": {}
},
"object": "coll",
"expectError": {
"isError": true,
"isClientError": false
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"commandName": "find"
}
},
{
"commandSucceededEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "killCursors"
}
},
{
"commandSucceededEvent": {
"commandName": "killCursors"
}
}
]
}
]
}
]
}

View File

@ -9,9 +9,7 @@
], ],
"failPoint": { "failPoint": {
"configureFailPoint": "failCommand", "configureFailPoint": "failCommand",
"mode": { "mode": "alwaysOn",
"times": 50
},
"data": { "data": {
"failCommands": [ "failCommands": [
"isMaster", "isMaster",

View File

@ -97,14 +97,22 @@
"outcome": { "outcome": {
"servers": { "servers": {
"a:27017": { "a:27017": {
"type": "Unknown", "type": "RSPrimary",
"topologyVersion": null, "setName": "rs",
"topologyVersion": {
"processId": {
"$oid": "000000000000000000000001"
},
"counter": {
"$numberLong": "1"
}
},
"pool": { "pool": {
"generation": 1 "generation": 0
} }
} }
}, },
"topologyType": "ReplicaSetNoPrimary", "topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null, "logicalSessionTimeoutMinutes": null,
"setName": "rs" "setName": "rs"
} }

View File

@ -0,0 +1,142 @@
{
"description": "backpressure-network-error-fail-replicaset",
"schemaVersion": "1.17",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"serverless": "forbid",
"topologies": [
"replicaset"
]
}
],
"createEntities": [
{
"client": {
"id": "setupClient",
"useMultipleMongoses": false
}
}
],
"initialData": [
{
"collectionName": "backpressure-network-error-fail",
"databaseName": "sdam-tests",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
],
"tests": [
{
"description": "apply backpressure on network connection errors during connection establishment",
"operations": [
{
"name": "createEntities",
"object": "testRunner",
"arguments": {
"entities": [
{
"client": {
"id": "client",
"useMultipleMongoses": false,
"observeEvents": [
"serverDescriptionChangedEvent",
"poolClearedEvent"
],
"uriOptions": {
"retryWrites": false,
"heartbeatFrequencyMS": 1000000,
"serverMonitoringMode": "poll",
"appname": "backpressureNetworkErrorFailTest"
}
}
},
{
"database": {
"id": "database",
"client": "client",
"databaseName": "sdam-tests"
}
},
{
"collection": {
"id": "collection",
"database": "database",
"collectionName": "backpressure-network-error-fail"
}
}
]
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"client": "client",
"event": {
"serverDescriptionChangedEvent": {
"newDescription": {
"type": "RSPrimary"
}
}
},
"count": 1
}
},
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "setupClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"isMaster",
"hello"
],
"appName": "backpressureNetworkErrorFailTest",
"closeConnection": true
}
}
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"expectError": {
"isError": true,
"errorLabelsContain": [
"SystemOverloadedError",
"RetryableError"
]
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "cmap",
"events": []
}
]
}
]
}

View File

@ -0,0 +1,142 @@
{
"description": "backpressure-network-error-fail-single",
"schemaVersion": "1.17",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"serverless": "forbid",
"topologies": [
"single"
]
}
],
"createEntities": [
{
"client": {
"id": "setupClient",
"useMultipleMongoses": false
}
}
],
"initialData": [
{
"collectionName": "backpressure-network-error-fail",
"databaseName": "sdam-tests",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
],
"tests": [
{
"description": "apply backpressure on network connection errors during connection establishment",
"operations": [
{
"name": "createEntities",
"object": "testRunner",
"arguments": {
"entities": [
{
"client": {
"id": "client",
"useMultipleMongoses": false,
"observeEvents": [
"serverDescriptionChangedEvent",
"poolClearedEvent"
],
"uriOptions": {
"retryWrites": false,
"heartbeatFrequencyMS": 1000000,
"serverMonitoringMode": "poll",
"appname": "backpressureNetworkErrorFailTest"
}
}
},
{
"database": {
"id": "database",
"client": "client",
"databaseName": "sdam-tests"
}
},
{
"collection": {
"id": "collection",
"database": "database",
"collectionName": "backpressure-network-error-fail"
}
}
]
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"client": "client",
"event": {
"serverDescriptionChangedEvent": {
"newDescription": {
"type": "Standalone"
}
}
},
"count": 1
}
},
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "setupClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"isMaster",
"hello"
],
"appName": "backpressureNetworkErrorFailTest",
"closeConnection": true
}
}
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"expectError": {
"isError": true,
"errorLabelsContain": [
"SystemOverloadedError",
"RetryableError"
]
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "cmap",
"events": []
}
]
}
]
}

View File

@ -0,0 +1,145 @@
{
"description": "backpressure-network-timeout-error-replicaset",
"schemaVersion": "1.17",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"serverless": "forbid",
"topologies": [
"replicaset"
]
}
],
"createEntities": [
{
"client": {
"id": "setupClient",
"useMultipleMongoses": false
}
}
],
"initialData": [
{
"collectionName": "backpressure-network-timeout-error",
"databaseName": "sdam-tests",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
],
"tests": [
{
"description": "apply backpressure on network timeout error during connection establishment",
"operations": [
{
"name": "createEntities",
"object": "testRunner",
"arguments": {
"entities": [
{
"client": {
"id": "client",
"useMultipleMongoses": false,
"observeEvents": [
"serverDescriptionChangedEvent",
"poolClearedEvent"
],
"uriOptions": {
"retryWrites": false,
"heartbeatFrequencyMS": 1000000,
"appname": "backpressureNetworkTimeoutErrorTest",
"serverMonitoringMode": "poll",
"connectTimeoutMS": 250,
"socketTimeoutMS": 250
}
}
},
{
"database": {
"id": "database",
"client": "client",
"databaseName": "sdam-tests"
}
},
{
"collection": {
"id": "collection",
"database": "database",
"collectionName": "backpressure-network-timeout-error"
}
}
]
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"client": "client",
"event": {
"serverDescriptionChangedEvent": {
"newDescription": {
"type": "RSPrimary"
}
}
},
"count": 1
}
},
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "setupClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"isMaster",
"hello"
],
"blockConnection": true,
"blockTimeMS": 500,
"appName": "backpressureNetworkTimeoutErrorTest"
}
}
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"expectError": {
"isError": true,
"errorLabelsContain": [
"SystemOverloadedError",
"RetryableError"
]
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "cmap",
"events": []
}
]
}
]
}

View File

@ -0,0 +1,145 @@
{
"description": "backpressure-network-timeout-error-single",
"schemaVersion": "1.17",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"serverless": "forbid",
"topologies": [
"single"
]
}
],
"createEntities": [
{
"client": {
"id": "setupClient",
"useMultipleMongoses": false
}
}
],
"initialData": [
{
"collectionName": "backpressure-network-timeout-error",
"databaseName": "sdam-tests",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
],
"tests": [
{
"description": "apply backpressure on network timeout error during connection establishment",
"operations": [
{
"name": "createEntities",
"object": "testRunner",
"arguments": {
"entities": [
{
"client": {
"id": "client",
"useMultipleMongoses": false,
"observeEvents": [
"serverDescriptionChangedEvent",
"poolClearedEvent"
],
"uriOptions": {
"retryWrites": false,
"heartbeatFrequencyMS": 1000000,
"appname": "backpressureNetworkTimeoutErrorTest",
"serverMonitoringMode": "poll",
"connectTimeoutMS": 250,
"socketTimeoutMS": 250
}
}
},
{
"database": {
"id": "database",
"client": "client",
"databaseName": "sdam-tests"
}
},
{
"collection": {
"id": "collection",
"database": "database",
"collectionName": "backpressure-network-timeout-error"
}
}
]
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"client": "client",
"event": {
"serverDescriptionChangedEvent": {
"newDescription": {
"type": "Standalone"
}
}
},
"count": 1
}
},
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "setupClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"isMaster",
"hello"
],
"blockConnection": true,
"blockTimeMS": 500,
"appName": "backpressureNetworkTimeoutErrorTest"
}
}
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"expectError": {
"isError": true,
"errorLabelsContain": [
"SystemOverloadedError",
"RetryableError"
]
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "cmap",
"events": []
}
]
}
]
}

View File

@ -0,0 +1,106 @@
{
"description": "backpressure-server-description-unchanged-on-min-pool-size-population-error",
"schemaVersion": "1.17",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"serverless": "forbid",
"topologies": [
"single"
]
}
],
"createEntities": [
{
"client": {
"id": "setupClient",
"useMultipleMongoses": false
}
}
],
"tests": [
{
"description": "the server description is not changed on handshake error during minPoolSize population",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "setupClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"skip": 1
},
"data": {
"failCommands": [
"hello",
"isMaster"
],
"appName": "authErrorTest",
"closeConnection": true
}
}
}
},
{
"name": "createEntities",
"object": "testRunner",
"arguments": {
"entities": [
{
"client": {
"id": "client",
"observeEvents": [
"serverDescriptionChangedEvent",
"connectionClosedEvent"
],
"uriOptions": {
"appname": "authErrorTest",
"minPoolSize": 5,
"maxConnecting": 1,
"serverMonitoringMode": "poll",
"heartbeatFrequencyMS": 1000000
}
}
}
]
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"client": "client",
"event": {
"serverDescriptionChangedEvent": {}
},
"count": 1
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"client": "client",
"event": {
"connectionClosedEvent": {}
},
"count": 1
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "sdam",
"events": [
{
"serverDescriptionChangedEvent": {}
}
]
}
]
}
]
}

View File

@ -645,6 +645,38 @@ class ClientUnitTest(UnitTest):
with self.assertWarns(UserWarning): with self.assertWarns(UserWarning):
self.simple_client(multi_host) self.simple_client(multi_host)
def test_max_adaptive_retries(self):
# Assert that max adaptive retries defaults to 2.
c = self.simple_client(connect=False)
self.assertEqual(c.options.max_adaptive_retries, 2)
# Assert that max adaptive retries can be configured through connection or client options.
c = self.simple_client(connect=False, max_adaptive_retries=10)
self.assertEqual(c.options.max_adaptive_retries, 10)
c = self.simple_client(connect=False, maxAdaptiveRetries=10)
self.assertEqual(c.options.max_adaptive_retries, 10)
c = self.simple_client(host="mongodb://localhost/?maxAdaptiveRetries=10", connect=False)
self.assertEqual(c.options.max_adaptive_retries, 10)
def test_enable_overload_retargeting(self):
# Assert that overload retargeting defaults to false.
c = self.simple_client(connect=False)
self.assertFalse(c.options.enable_overload_retargeting)
# Assert that overload retargeting can be enabled through connection or client options.
c = self.simple_client(connect=False, enable_overload_retargeting=True)
self.assertTrue(c.options.enable_overload_retargeting)
c = self.simple_client(connect=False, enableOverloadRetargeting=True)
self.assertTrue(c.options.enable_overload_retargeting)
c = self.simple_client(
host="mongodb://localhost/?enableOverloadRetargeting=true", connect=False
)
self.assertTrue(c.options.enable_overload_retargeting)
class TestClient(IntegrationTest): class TestClient(IntegrationTest):
def test_multiple_uris(self): def test_multiple_uris(self):

View File

@ -0,0 +1,310 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Client Backpressure spec."""
from __future__ import annotations
import os
import pathlib
import sys
from time import perf_counter
from unittest.mock import patch
from pymongo.common import MAX_ADAPTIVE_RETRIES
sys.path[0:0] = [""]
from test import (
IntegrationTest,
client_context,
unittest,
)
from test.unified_format import generate_test_classes
from test.utils_shared import EventListener, OvertCommandListener
from pymongo.errors import OperationFailure, PyMongoError
_IS_SYNC = True
# Mock a system overload error.
mock_overload_error = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find", "insert", "update"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def get_mock_overload_error(times: int):
error = mock_overload_error.copy()
error["mode"] = {"times": times}
return error
class TestBackpressure(IntegrationTest):
RUN_ON_LOAD_BALANCER = True
@client_context.require_failCommand_appName
def test_retry_overload_error_command(self):
self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
with self.fail_point(fail_many):
self.db.command("find", "t")
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
self.db.command("find", "t")
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@client_context.require_failCommand_appName
def test_retry_overload_error_find(self):
self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
with self.fail_point(fail_many):
self.db.t.find_one()
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
self.db.t.find_one()
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@client_context.require_failCommand_appName
def test_retry_overload_error_insert_one(self):
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
with self.fail_point(fail_many):
self.db.t.insert_one({"x": 1})
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
self.db.t.insert_one({"x": 1})
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@client_context.require_failCommand_appName
def test_retry_overload_error_update_many(self):
# Even though update_many is not a retryable write operation, it will
# still be retried via the "RetryableError" error label.
self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
with self.fail_point(fail_many):
self.db.t.update_many({}, {"$set": {"x": 2}})
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
self.db.t.update_many({}, {"$set": {"x": 2}})
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@client_context.require_failCommand_appName
def test_retry_overload_error_getMore(self):
coll = self.db.t
coll.insert_many([{"x": 1} for _ in range(10)])
# Ensure command is retried on overload error.
fail_many = {
"configureFailPoint": "failCommand",
"mode": {"times": MAX_ADAPTIVE_RETRIES},
"data": {
"failCommands": ["getMore"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
cursor = coll.find(batch_size=2)
cursor.next()
with self.fail_point(fail_many):
cursor.to_list()
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = fail_many.copy()
fail_too_many["mode"] = {"times": MAX_ADAPTIVE_RETRIES + 1}
cursor = coll.find(batch_size=2)
cursor.next()
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
cursor.to_list()
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# Prose tests.
class TestClientBackpressure(IntegrationTest):
listener: EventListener
@classmethod
def setUpClass(cls) -> None:
cls.listener = OvertCommandListener()
@client_context.require_connection
def setUp(self) -> None:
super().setUp()
self.listener.reset()
self.app_name = self.__class__.__name__.lower()
self.client = self.rs_or_single_client(
event_listeners=[self.listener], appName=self.app_name
)
@patch("random.random")
@client_context.require_failCommand_appName
def test_01_operation_retry_uses_exponential_backoff(self, random_func):
# Drivers should test that retries do not occur immediately when a SystemOverloadedError is encountered.
# 1. let `client` be a `MongoClient`
client = self.client
# 2. let `collection` be a collection
collection = client.test.test
# 3. Now, run transactions without backoff:
# a. Configure the random number generator used for jitter to always return `0` -- this effectively disables backoff.
random_func.return_value = 0
# b. Configure the following failPoint:
fail_point = dict(
mode="alwaysOn",
data=dict(
failCommands=["insert"],
errorCode=2,
errorLabels=["SystemOverloadedError", "RetryableError"],
appName=self.app_name,
),
)
with self.fail_point(fail_point):
# c. Execute the following command. Expect that the command errors. Measure the duration of the command execution.
start0 = perf_counter()
with self.assertRaises(OperationFailure):
collection.insert_one({"a": 1})
end0 = perf_counter()
# d. Configure the random number generator used for jitter to always return `1`.
random_func.return_value = 1
# e. Execute step c again.
start1 = perf_counter()
with self.assertRaises(OperationFailure):
collection.insert_one({"a": 1})
end1 = perf_counter()
# f. Compare the times between the two runs.
# The sum of 2 backoffs is 0.3 seconds. There is a 0.3-second window to account for potential variance between the two
# runs.
self.assertTrue(abs((end1 - start1) - (end0 - start0 + 0.3)) < 0.3)
@client_context.require_failCommand_appName
def test_03_overload_retries_limited(self):
# Drivers should test that overload errors are retried a maximum of two times.
# 1. Let `client` be a `MongoClient`.
client = self.client
# 2. Let `coll` be a collection.
coll = client.pymongo_test.coll
# 3. Configure the following failpoint:
failpoint = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
# 4. Perform a find operation with `coll` that fails.
with self.fail_point(failpoint):
with self.assertRaises(PyMongoError) as error:
coll.find_one({})
# 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels.
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# 6. Assert that the total number of started commands is MAX_ADAPTIVE_RETRIES + 1.
self.assertEqual(len(self.listener.started_events), MAX_ADAPTIVE_RETRIES + 1)
@client_context.require_failCommand_appName
def test_04_overload_retries_limited_configured(self):
# Drivers should test that overload errors are retried a maximum of maxAdaptiveRetries times.
max_retries = 1
# 1. Let `client` be a `MongoClient` with `maxAdaptiveRetries=1` and command event monitoring enabled.
client = self.single_client(maxAdaptiveRetries=max_retries, event_listeners=[self.listener])
# 2. Let `coll` be a collection.
coll = client.pymongo_test.coll
# 3. Configure the following failpoint:
failpoint = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
# 4. Perform a find operation with `coll` that fails.
with self.fail_point(failpoint):
with self.assertRaises(PyMongoError) as error:
coll.find_one({})
# 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels.
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# 6. Assert that the total number of started commands is max_retries + 1.
self.assertEqual(len(self.listener.started_events), max_retries + 1)
# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "client-backpressure")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "client-backpressure")
globals().update(
generate_test_classes(
_TEST_PATH,
module=__name__,
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -219,6 +219,19 @@ class TestClientMetadataProse(IntegrationTest):
# add same metadata again # add same metadata again
self.check_metadata_added(client, "Framework", None, None) self.check_metadata_added(client, "Framework", None, None)
def test_handshake_documents_include_backpressure(self):
# Create a `MongoClient` that is configured to record all handshake documents sent to the server as a part of
# connection establishment.
client = self.rs_or_single_client("mongodb://" + self.server.address_string)
# Send a `ping` command to the server and verify that the command succeeds. This ensure that a connection is
# established on all topologies. Note: MockupDB only supports standalone servers.
client.admin.command("ping")
# Assert that for every handshake document intercepted:
# the document has a field `backpressure` whose value is `true`.
self.assertEqual(self.handshake_req["backpressure"], True)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -25,7 +25,9 @@ from asyncio import StreamReader, StreamWriter
from pathlib import Path from pathlib import Path
from test.helpers import ConcurrentRunner from test.helpers import ConcurrentRunner
from test.utils import flaky from test.utils import flaky
from test.utils_shared import delay
from pymongo.errors import ConnectionFailure
from pymongo.operations import _Op from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector from pymongo.server_selectors import writable_server_selector
from pymongo.synchronous.pool import Connection from pymongo.synchronous.pool import Connection
@ -67,7 +69,12 @@ from pymongo.errors import (
) )
from pymongo.hello import Hello, HelloCompat from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _check_command_response, _check_write_command_response from pymongo.helpers_shared import _check_command_response, _check_write_command_response
from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent from pymongo.monitoring import (
ConnectionCheckOutFailedEvent,
PoolClearedEvent,
ServerHeartbeatFailedEvent,
ServerHeartbeatStartedEvent,
)
from pymongo.server_description import SERVER_TYPE, ServerDescription from pymongo.server_description import SERVER_TYPE, ServerDescription
from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext from pymongo.synchronous.topology import Topology, _ErrorContext
@ -131,6 +138,9 @@ def got_app_error(topology, app_error):
raise AssertionError raise AssertionError
except (AutoReconnect, NotPrimaryError, OperationFailure) as e: except (AutoReconnect, NotPrimaryError, OperationFailure) as e:
if when == "beforeHandshakeCompletes": if when == "beforeHandshakeCompletes":
# The pool would have added the SystemOverloadedError in this case.
if isinstance(e, AutoReconnect):
e._add_error_label("SystemOverloadedError")
completed_handshake = False completed_handshake = False
elif when == "afterHandshakeCompletes": elif when == "afterHandshakeCompletes":
completed_handshake = True completed_handshake = True
@ -437,6 +447,57 @@ class TestPoolManagement(IntegrationTest):
Connection.close_conn = original_close Connection.close_conn = original_close
class TestPoolBackpressure(IntegrationTest):
@client_context.require_version_min(7, 0, 0)
def test_connection_pool_is_not_cleared(self):
listener = CMAPListener()
# Create a client that listens to CMAP events, with maxConnecting=100.
client = self.rs_or_single_client(maxConnecting=100, event_listeners=[listener])
# Enable the ingress rate limiter.
client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=True
)
client.admin.command("setParameter", 1, ingressConnectionEstablishmentRatePerSec=20)
client.admin.command("setParameter", 1, ingressConnectionEstablishmentBurstCapacitySecs=1)
client.admin.command("setParameter", 1, ingressConnectionEstablishmentMaxQueueDepth=1)
# Disable the ingress rate limiter on teardown.
# Sleep for 1 second before disabling to avoid the rate limiter.
def teardown():
time.sleep(1)
client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=False
)
self.addCleanup(teardown)
# Make sure the collection has at least one document.
client.test.test.delete_many({})
client.test.test.insert_one({})
# Run a slow operation to tie up the connection.
def target():
try:
client.test.test.find_one({"$where": delay(0.1)})
except ConnectionFailure:
pass
# Run 100 parallel operations that contend for connections.
tasks = []
for _ in range(100):
tasks.append(ConcurrentRunner(target=target))
for t in tasks:
t.start()
for t in tasks:
t.join()
# Verify there were at least 10 connection checkout failed event but no pool cleared events.
self.assertGreater(len(listener.events_by_type(ConnectionCheckOutFailedEvent)), 10)
self.assertEqual(len(listener.events_by_type(PoolClearedEvent)), 0)
class TestServerMonitoringMode(IntegrationTest): class TestServerMonitoringMode(IntegrationTest):
@client_context.require_no_load_balancer @client_context.require_no_load_balancer
def setUp(self): def setUp(self):

View File

@ -511,6 +511,39 @@ class TestPooling(_TestPoolingBase):
str(error.exception), str(error.exception),
) )
@client_context.require_failCommand_appName
def test_pool_backpressure_preserves_existing_connections(self):
client = self.rs_or_single_client()
coll = client.pymongo_test.t
pool = get_pool(client)
coll.insert_many([{"x": 1} for _ in range(10)])
t = SocketGetter(self.c, pool)
t.start()
while t.state != "connection":
time.sleep(0.1)
assert not t.sock.conn_closed()
# Mock a session establishment overload.
mock_connection_fail = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"closeConnection": True,
},
}
with self.fail_point(mock_connection_fail):
coll.find_one({})
# Make sure the existing socket was not affected.
assert not t.sock.conn_closed()
# Cleanup
t.release_conn()
t.join()
pool.close()
class TestPoolMaxSize(_TestPoolingBase): class TestPoolMaxSize(_TestPoolingBase):
def test_max_pool_size(self): def test_max_pool_size(self):

View File

@ -263,14 +263,17 @@ class TestRetryableReads(IntegrationTest):
@client_context.require_secondaries_count(1) @client_context.require_secondaries_count(1)
@client_context.require_failCommand_fail_point @client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, 0) @client_context.require_version_min(4, 4, 0)
def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available( def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available_and_overload_retargeting_is_enabled(
self self
): ):
listener = OvertCommandListener() listener = OvertCommandListener()
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled. # 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, `enableOverloadRetargeting=True`, and command event monitoring enabled.
client = self.rs_or_single_client( client = self.rs_or_single_client(
event_listeners=[listener], retryReads=True, readPreference="primaryPreferred" event_listeners=[listener],
retryReads=True,
readPreference="primaryPreferred",
enableOverloadRetargeting=True,
) )
# 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels. # 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels.
@ -337,6 +340,47 @@ class TestRetryableReads(IntegrationTest):
# 6. Assert that both events occurred the same server. # 6. Assert that both events occurred the same server.
assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id
@client_context.require_replica_set
@client_context.require_secondaries_count(1)
@client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, 0)
def test_03_03_retryable_reads_caused_by_overload_errors_are_retried_on_the_same_replicaset_server_when_one_is_available_and_overload_retargeting_is_disabled(
self
):
listener = OvertCommandListener()
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled.
client = self.rs_or_single_client(
event_listeners=[listener],
retryReads=True,
readPreference="primaryPreferred",
)
# 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 6,
},
}
set_fail_point(client, command_args)
# 3. Reset the command event monitor to clear the fail point command from its stored events.
listener.reset()
# 4. Execute a `find` command with `client`.
client.t.t.find_one({})
# 5. Assert that one failed command event and one successful command event occurred.
self.assertEqual(len(listener.failed_events), 1)
self.assertEqual(len(listener.succeeded_events), 1)
# 6. Assert that both events occurred on the same server.
assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -43,14 +43,17 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.int64 import Int64 from bson.int64 import Int64
from bson.raw_bson import RawBSONDocument from bson.raw_bson import RawBSONDocument
from bson.son import SON from bson.son import SON
from pymongo import MongoClient
from pymongo.errors import ( from pymongo.errors import (
AutoReconnect, AutoReconnect,
ConnectionFailure, ConnectionFailure,
OperationFailure, NotPrimaryError,
PyMongoError,
ServerSelectionTimeoutError, ServerSelectionTimeoutError,
WriteConcernError, WriteConcernError,
) )
from pymongo.monitoring import ( from pymongo.monitoring import (
CommandFailedEvent,
CommandSucceededEvent, CommandSucceededEvent,
ConnectionCheckedOutEvent, ConnectionCheckedOutEvent,
ConnectionCheckOutFailedEvent, ConnectionCheckOutFailedEvent,
@ -597,5 +600,186 @@ class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest):
self.assertEqual(sent_txn_id, final_txn_id, msg) self.assertEqual(sent_txn_id, final_txn_id, msg)
class TestErrorPropagationAfterEncounteringMultipleErrors(IntegrationTest):
# Only run against replica sets as mongos does not propagate the NoWritesPerformed label to the drivers.
@client_context.require_replica_set
# Run against server versions 6.0 and above.
@client_context.require_version_min(6, 0) # type: ignore[untyped-decorator]
def setUp(self) -> None:
super().setUp()
self.setup_client = MongoClient(**client_context.default_client_options)
self.addCleanup(self.setup_client.close)
# TODO: After PYTHON-4595 we can use async event handlers and remove this workaround.
def configure_fail_point_sync(self, command_args, off=False) -> None:
cmd = {"configureFailPoint": "failCommand"}
cmd.update(command_args)
if off:
cmd["mode"] = "off"
cmd.pop("data", None)
self.setup_client.admin.command(cmd)
def test_01_drivers_return_the_correct_error_when_receiving_only_errors_without_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Via the command monitoring CommandFailedEvent, configure a fail point with error code 10107 (NotWritablePrimary).
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 10107,
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = self.rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(NotPrimaryError) as exc:
client.test.test.insert_one({})
# Assert that the error code of the server error is 10107.
assert exc.exception.errors["code"] == 10107 # type:ignore[call-overload]
def test_02_drivers_return_the_correct_error_when_receiving_only_errors_with_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
"errorCode": 91,
},
}
# Via the command monitoring CommandFailedEvent, configure a fail point with error code `10107` (NotWritablePrimary)
# and a NoWritesPerformed label.
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 10107,
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
},
}
def failed(event: CommandFailedEvent) -> None:
if listener.failed_events:
return
# Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2.
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = self.rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(NotPrimaryError) as exc:
client.test.test.insert_one({})
# Assert that the error code of the server error is 91.
assert exc.exception.errors["code"] == 91 # type:ignore[call-overload]
def test_03_drivers_return_the_correct_error_when_receiving_some_errors_with_NoWritesPerformed_and_some_without_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (NotWritablePrimary) and the `NoWritesPerformed`, `RetryableError` and `SystemOverloadedError` labels.
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with the `RetryableError` and
# `SystemOverloadedError` error labels but without the `NoWritesPerformed` error label.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorCode": 91,
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = self.rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(PyMongoError) as exc:
client.test.test.insert_one({})
# Assert that the error code of the server error is 91.
assert exc.exception.errors["code"] == 91
# Assert that the error does not contain the error label `NoWritesPerformed`.
assert "NoWritesPerformed" not in exc.exception.errors["errorLabels"]
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -16,9 +16,13 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import random
import sys import sys
import time
from io import BytesIO from io import BytesIO
from unittest.mock import patch
import pymongo
from gridfs.synchronous.grid_file import GridFS, GridFSBucket from gridfs.synchronous.grid_file import GridFS, GridFSBucket
from pymongo.server_selectors import writable_server_selector from pymongo.server_selectors import writable_server_selector
from pymongo.synchronous.pool import PoolState from pymongo.synchronous.pool import PoolState
@ -40,7 +44,9 @@ from pymongo.errors import (
CollectionInvalid, CollectionInvalid,
ConfigurationError, ConfigurationError,
ConnectionFailure, ConnectionFailure,
ExecutionTimeout,
InvalidOperation, InvalidOperation,
NetworkTimeout,
OperationFailure, OperationFailure,
) )
from pymongo.operations import IndexModel, InsertOne from pymongo.operations import IndexModel, InsertOne
@ -426,7 +432,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
self.configure_fail_point(client, command_args) self.configure_fail_point(client, command_args)
@client_context.require_transactions @client_context.require_transactions
def test_callback_raises_custom_error(self): def test_1_callback_raises_custom_error(self):
class _MyException(Exception): class _MyException(Exception):
pass pass
@ -438,7 +444,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
s.with_transaction(raise_error) s.with_transaction(raise_error)
@client_context.require_transactions @client_context.require_transactions
def test_callback_returns_value(self): def test_2_callback_returns_value(self):
def callback(_): def callback(_):
return "Foo" return "Foo"
@ -466,7 +472,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
self.assertEqual(s.with_transaction(callback), "Foo") self.assertEqual(s.with_transaction(callback), "Foo")
@client_context.require_transactions @client_context.require_transactions
def test_callback_not_retried_after_timeout(self): def test_3_1_callback_not_retried_after_timeout(self):
listener = OvertCommandListener() listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener]) client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test coll = client[self.db.name].test
@ -487,14 +493,16 @@ class TestTransactionsConvenientAPI(TransactionsBase):
listener.reset() listener.reset()
with client.start_session() as s: with client.start_session() as s:
with PatchSessionTimeout(0): with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure): with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback) s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"]) self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@client_context.require_test_commands @client_context.require_test_commands
@client_context.require_transactions @client_context.require_transactions
def test_callback_not_retried_after_commit_timeout(self): def test_3_2_callback_not_retried_after_commit_timeout(self):
listener = OvertCommandListener() listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener]) client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test coll = client[self.db.name].test
@ -519,14 +527,16 @@ class TestTransactionsConvenientAPI(TransactionsBase):
with client.start_session() as s: with client.start_session() as s:
with PatchSessionTimeout(0): with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure): with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback) s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"]) self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@client_context.require_test_commands @client_context.require_test_commands
@client_context.require_transactions @client_context.require_transactions
def test_commit_not_retried_after_timeout(self): def test_3_3_commit_not_retried_after_timeout(self):
listener = OvertCommandListener() listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener]) client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test coll = client[self.db.name].test
@ -548,7 +558,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
with client.start_session() as s: with client.start_session() as s:
with PatchSessionTimeout(0): with PatchSessionTimeout(0):
with self.assertRaises(ConnectionFailure): with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback) s.with_transaction(callback)
# One insert for the callback and two commits (includes the automatic # One insert for the callback and two commits (includes the automatic
@ -556,6 +566,40 @@ class TestTransactionsConvenientAPI(TransactionsBase):
self.assertEqual( self.assertEqual(
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"] listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
) )
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult"))
@client_context.require_transactions
def test_callback_not_retried_after_csot_timeout(self):
listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test
def callback(session):
coll.insert_one({}, session=session)
err: dict = {
"ok": 0,
"errmsg": "Transaction 7819 has been aborted.",
"code": 251,
"codeName": "NoSuchTransaction",
"errorLabels": ["TransientTransactionError"],
}
raise OperationFailure(err["errmsg"], err["code"], err)
# Create the collection.
coll.insert_one({})
listener.reset()
with client.start_session() as s:
with pymongo.timeout(1.0):
with self.assertRaises(ExecutionTimeout):
s.with_transaction(callback)
# At least two attempts: the original and one or more retries.
inserts = len([x for x in listener.started_command_names() if x == "insert"])
aborts = len([x for x in listener.started_command_names() if x == "abortTransaction"])
self.assertGreaterEqual(inserts, 2)
self.assertGreaterEqual(aborts, 2)
# Tested here because this supports Motor's convenient transactions API. # Tested here because this supports Motor's convenient transactions API.
@client_context.require_transactions @client_context.require_transactions
@ -594,6 +638,63 @@ class TestTransactionsConvenientAPI(TransactionsBase):
s.with_transaction(callback) s.with_transaction(callback)
self.assertFalse(s.in_transaction) self.assertFalse(s.in_transaction)
@client_context.require_test_commands
@client_context.require_transactions
def test_4_retry_backoff_is_enforced(self):
client = client_context.client
coll = client[self.db.name].test
end = start = no_backoff_time = 0
# Make random.random always return 0 (no backoff)
with patch.object(random, "random", return_value=0):
# set fail point to trigger transaction failure and trigger backoff
self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {"times": 13},
"data": {
"failCommands": ["commitTransaction"],
"errorCode": 251,
},
}
)
self.addCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
def callback(session):
coll.insert_one({}, session=session)
start = time.monotonic()
with self.client.start_session() as s:
s.with_transaction(callback)
end = time.monotonic()
no_backoff_time = end - start
# Make random.random always return 1 (max backoff)
with patch.object(random, "random", return_value=1):
# set fail point to trigger transaction failure and trigger backoff
self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {
"times": 13
}, # sufficiently high enough such that the time effect of backoff is noticeable
"data": {
"failCommands": ["commitTransaction"],
"errorCode": 251,
},
}
)
self.addCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
start = time.monotonic()
with self.client.start_session() as s:
s.with_transaction(callback)
end = time.monotonic()
self.assertLess(abs(end - start - (no_backoff_time + 2.2)), 1) # sum of 13 backoffs is 2.2
class TestOptionsInsideTransactionProse(TransactionsBase): class TestOptionsInsideTransactionProse(TransactionsBase):
@client_context.require_transactions @client_context.require_transactions

View File

@ -0,0 +1,342 @@
{
"description": "backpressure-retryable-abort",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"topologies": [
"replicaset",
"sharded",
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": false,
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "transaction-tests"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "test"
}
},
{
"session": {
"id": "session0",
"client": "client0"
}
}
],
"initialData": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
],
"tests": [
{
"description": "abortTransaction retries if backpressure labels are added",
"operations": [
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"abortTransaction"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "session0",
"name": "abortTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 1
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"abortTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "abortTransaction",
"databaseName": "admin"
}
},
{
"commandStartedEvent": {
"command": {
"abortTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "abortTransaction",
"databaseName": "admin"
}
},
{
"commandStartedEvent": {
"command": {
"abortTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "abortTransaction",
"databaseName": "admin"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
]
},
{
"description": "abortTransaction is retried maxAttempts=2 times if backpressure labels are added",
"operations": [
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"abortTransaction"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "session0",
"name": "abortTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 1
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"abortTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "abortTransaction",
"databaseName": "admin"
}
},
{
"commandStartedEvent": {
"commandName": "abortTransaction"
}
},
{
"commandStartedEvent": {
"commandName": "abortTransaction"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
]
}
]
}

View File

@ -0,0 +1,359 @@
{
"description": "backpressure-retryable-commit",
"schemaVersion": "1.4",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"topologies": [
"sharded",
"replicaset",
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": false,
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "transaction-tests"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "test"
}
},
{
"session": {
"id": "session0",
"client": "client0"
}
}
],
"initialData": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
],
"tests": [
{
"description": "commitTransaction retries if backpressure labels are added",
"runOnRequirements": [
{
"serverless": "forbid"
}
],
"operations": [
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"commitTransaction"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "session0",
"name": "commitTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 1
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"commitTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "commitTransaction",
"databaseName": "admin"
}
},
{
"commandStartedEvent": {
"command": {
"commitTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "commitTransaction",
"databaseName": "admin"
}
},
{
"commandStartedEvent": {
"command": {
"commitTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "commitTransaction",
"databaseName": "admin"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": [
{
"_id": 1
}
]
}
]
},
{
"description": "commitTransaction is retried maxAttempts=2 times if backpressure labels are added",
"runOnRequirements": [
{
"serverless": "forbid"
}
],
"operations": [
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"commitTransaction"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "session0",
"name": "commitTransaction",
"expectError": {
"isError": true
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 1
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"commitTransaction": 1,
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "commitTransaction",
"databaseName": "admin"
}
},
{
"commandStartedEvent": {
"commandName": "commitTransaction"
}
},
{
"commandStartedEvent": {
"commandName": "commitTransaction"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
]
}
]
}

View File

@ -0,0 +1,313 @@
{
"description": "backpressure-retryable-reads",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"topologies": [
"replicaset",
"sharded",
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": false,
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "transaction-tests"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "test"
}
},
{
"session": {
"id": "session0",
"client": "client0"
}
}
],
"initialData": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
],
"tests": [
{
"description": "reads are retried if backpressure labels are added",
"operations": [
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"find"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "collection0",
"name": "find",
"arguments": {
"filter": {},
"session": "session0"
}
},
{
"object": "session0",
"name": "commitTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 1
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"find": "test",
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "find",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"find": "test",
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "find",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"abortTransaction": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "commitTransaction",
"databaseName": "admin"
}
}
]
}
]
},
{
"description": "reads are retried maxAttempts=2 times if backpressure labels are added",
"operations": [
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"find"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "collection0",
"name": "find",
"arguments": {
"filter": {},
"session": "session0"
},
"expectError": {
"isError": true
}
},
{
"object": "session0",
"name": "abortTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"commandName": "abortTransaction"
}
}
]
}
]
}
]
}

View File

@ -0,0 +1,439 @@
{
"description": "backpressure-retryable-writes",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"topologies": [
"replicaset",
"sharded",
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": false,
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "transaction-tests"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "test"
}
},
{
"session": {
"id": "session0",
"client": "client0"
}
}
],
"initialData": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
],
"tests": [
{
"description": "writes are retried if backpressure labels are added",
"operations": [
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"insert"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 2
}
}
},
{
"object": "session0",
"name": "commitTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 1
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": true,
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 2
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"insert": "test",
"documents": [
{
"_id": 2
}
],
"ordered": true,
"readConcern": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"abortTransaction": {
"$$exists": false
},
"lsid": {
"$$sessionLsid": "session0"
},
"txnNumber": {
"$numberLong": "1"
},
"startTransaction": {
"$$exists": false
},
"autocommit": false,
"writeConcern": {
"$$exists": false
}
},
"commandName": "commitTransaction",
"databaseName": "admin"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
]
},
{
"description": "writes are retried maxAttempts=2 times if backpressure labels are added",
"operations": [
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 1
}
},
"expectResult": {
"$$unsetOrMatches": {
"insertedId": {
"$$unsetOrMatches": 1
}
}
}
},
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"insert"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 2
}
},
"expectError": {
"isError": true
}
},
{
"object": "session0",
"name": "abortTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"commandName": "abortTransaction"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
]
},
{
"description": "retry succeeds if backpressure labels are added to the first operation in a transaction",
"operations": [
{
"object": "session0",
"name": "startTransaction"
},
{
"object": "testRunner",
"name": "failPoint",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"insert"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 112
}
}
}
},
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"session": "session0",
"document": {
"_id": 2
}
}
},
{
"object": "session0",
"name": "abortTransaction"
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"startTransaction": true
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"startTransaction": true
},
"commandName": "insert",
"databaseName": "transaction-tests"
}
},
{
"commandStartedEvent": {
"command": {
"startTransaction": {
"$$exists": false
}
},
"commandName": "abortTransaction",
"databaseName": "admin"
}
}
]
}
],
"outcome": [
{
"collectionName": "test",
"databaseName": "transaction-tests",
"documents": []
}
]
}
]
}

View File

@ -0,0 +1,66 @@
{
"tests": [
{
"description": "maxAdaptiveRetries is parsed correctly",
"uri": "mongodb://example.com/?maxAdaptiveRetries=3",
"valid": true,
"warning": false,
"hosts": null,
"auth": null,
"options": {
"maxAdaptiveRetries": 3
}
},
{
"description": "maxAdaptiveRetries=0 is parsed correctly",
"uri": "mongodb://example.com/?maxAdaptiveRetries=0",
"valid": true,
"warning": false,
"hosts": null,
"auth": null,
"options": {
"maxAdaptiveRetries": 0
}
},
{
"description": "maxAdaptiveRetries with invalid value causes a warning",
"uri": "mongodb://example.com/?maxAdaptiveRetries=-5",
"valid": true,
"warning": true,
"hosts": null,
"auth": null,
"options": null
},
{
"description": "enableOverloadRetargeting is parsed correctly",
"uri": "mongodb://example.com/?enableOverloadRetargeting=true",
"valid": true,
"warning": false,
"hosts": null,
"auth": null,
"options": {
"enableOverloadRetargeting": true
}
},
{
"description": "enableOverloadRetargeting=false is parsed correctly",
"uri": "mongodb://example.com/?enableOverloadRetargeting=false",
"valid": true,
"warning": false,
"hosts": null,
"auth": null,
"options": {
"enableOverloadRetargeting": false
}
},
{
"description": "enableOverloadRetargeting with invalid value causes a warning",
"uri": "mongodb://example.com/?enableOverloadRetargeting=invalid",
"valid": true,
"warning": true,
"hosts": null,
"auth": null,
"options": null
}
]
}

View File

@ -213,6 +213,7 @@ converted_tests = [
"test_bulk.py", "test_bulk.py",
"test_change_stream.py", "test_change_stream.py",
"test_client.py", "test_client.py",
"test_client_backpressure.py",
"test_client_bulk_write.py", "test_client_bulk_write.py",
"test_client_context.py", "test_client_context.py",
"test_client_metadata.py", "test_client_metadata.py",
@ -350,7 +351,7 @@ def translate_async_sleeps(lines: list[str]) -> list[str]:
sleeps = [line for line in lines if "asyncio.sleep" in line] sleeps = [line for line in lines if "asyncio.sleep" in line]
for line in sleeps: for line in sleeps:
res = re.search(r"asyncio.sleep\(([^()]*)\)", line) res = re.search(r"asyncio\.sleep\(\s*(.*?)\)", line)
if res: if res:
old = res[0] old = res[0]
index = lines.index(line) index = lines.index(line)