Merge branch 'master' into spec-resync-04-13-2026

This commit is contained in:
Noah Stapp 2026-04-15 12:05:13 -04:00 committed by GitHub
commit d14dfba36d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
52 changed files with 10798 additions and 1809 deletions

View File

@ -94,6 +94,9 @@ do
change-streams|change_streams)
cpjson change-streams/tests/ change_streams/
;;
client-backpressure|client_backpressure)
cpjson client-backpressure/tests client-backpressure
;;
client-side-encryption|csfle|fle)
cpjson client-side-encryption/tests/ client-side-encryption/spec
cpjson client-side-encryption/corpus/ client-side-encryption/corpus

File diff suppressed because it is too large Load Diff

View File

@ -216,4 +216,4 @@ pip install -e ".[test]"
pytest
```
For more advanced testing scenarios, see the [contributing guide](./CONTRIBUTING.md#running-tests-locally).
For more advanced testing scenarios, see the [contributing guide](https://github.com/mongodb/mongo-python-driver/blob/master/CONTRIBUTING.md#running-tests-locally).

View File

@ -14,6 +14,9 @@ PyMongo 4.17 brings a number of changes including:
- Added the :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.bind` and :meth:`~pymongo.client_session.ClientSession.bind` methods
that allow users to bind a session to all database operations within the scope of a context manager instead of having to explicitly pass the session to each individual operation.
See <PLACEHOLDER> for examples and more information.
- Added support for MongoDB's Intelligent Workload Management (IWM) and ingress connection rate limiting features.
The driver now gracefully handles write-blocking scenarios and optimizes connection establishment during high-load conditions to maintain application availability.
See <DOCSP-55426> and <DOCSP-57078> for more information.
Changes in Version 4.16.0 (2026/01/07)
--------------------------------------

View File

@ -59,6 +59,7 @@ from pymongo.errors import (
InvalidOperation,
NotPrimaryError,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
@ -563,9 +564,17 @@ class _AsyncClientBulk:
error, ConnectionFailure
) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError))
retryable_label_error = isinstance(
error, PyMongoError
) and error.has_error_label("RetryableError")
# Synthesize the full bulk result without modifying the
# current one because this write operation may be retried.
if retryable and (retryable_top_level_error or retryable_network_error):
if retryable and (
retryable_top_level_error
or retryable_network_error
or retryable_label_error
):
full = copy.deepcopy(full_result)
_merge_command(self.ops, self.idx_offset, full, result)
_throw_client_bulk_write_exception(full, self.verbose_results)

View File

@ -135,7 +135,9 @@ Classes
from __future__ import annotations
import asyncio
import collections
import random
import time
import uuid
from collections.abc import Mapping as _Mapping
@ -162,7 +164,9 @@ from pymongo.asynchronous.cursor_base import _ConnectionManager
from pymongo.errors import (
ConfigurationError,
ConnectionFailure,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
OperationFailure,
PyMongoError,
WTimeoutError,
@ -427,6 +431,7 @@ class _Transaction:
self.recovery_token = None
self.attempt = 0
self.client = client
self.has_completed_command = False
def active(self) -> bool:
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
@ -434,6 +439,9 @@ class _Transaction:
def starting(self) -> bool:
return self.state == _TxnState.STARTING
def set_starting(self) -> None:
self.state = _TxnState.STARTING
@property
def pinned_conn(self) -> Optional[AsyncConnection]:
if self.active() and self.conn_mgr:
@ -459,6 +467,7 @@ class _Transaction:
self.sharded = False
self.recovery_token = None
self.attempt = 0
self.has_completed_command = False
def __del__(self) -> None:
if self.conn_mgr:
@ -493,11 +502,29 @@ _UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( #
# This limit is non-configurable and was chosen to be twice the 60 second
# default value of MongoDB's `transactionLifetimeLimitSeconds` parameter.
_WITH_TRANSACTION_RETRY_TIME_LIMIT = 120
_BACKOFF_MAX = 0.500 # 500ms max backoff
_BACKOFF_INITIAL = 0.005 # 5ms initial backoff
def _within_time_limit(start_time: float) -> bool:
def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
"""Are we within the with_transaction retry limit?"""
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
remaining = _csot.remaining()
if remaining is not None and remaining <= 0:
return False
return time.monotonic() + backoff - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
def _make_timeout_error(error: BaseException) -> PyMongoError:
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
if _csot.remaining() is not None:
timeout_error: PyMongoError = ExecutionTimeout(
str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}
)
else:
timeout_error = NetworkTimeout(str(error))
if isinstance(error, PyMongoError):
timeout_error._error_labels = error._error_labels.copy()
return timeout_error
_T = TypeVar("_T")
@ -744,7 +771,17 @@ class AsyncClientSession:
https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback
"""
start_time = time.monotonic()
retry = 0
last_error: Optional[BaseException] = None
while True:
if retry: # Implement exponential backoff on retry.
jitter = random.random() # noqa: S311
backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX)
if not _within_time_limit(start_time, backoff):
assert last_error is not None
raise _make_timeout_error(last_error) from last_error
await asyncio.sleep(backoff)
retry += 1
await self.start_transaction(
read_concern, write_concern, read_preference, max_commit_time_ms
)
@ -752,15 +789,16 @@ class AsyncClientSession:
ret = await callback(self)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as exc:
last_error = exc
if self.in_transaction:
await self.abort_transaction()
if (
isinstance(exc, PyMongoError)
and exc.has_error_label("TransientTransactionError")
and _within_time_limit(start_time)
if isinstance(exc, PyMongoError) and exc.has_error_label(
"TransientTransactionError"
):
# Retry the entire transaction.
continue
if _within_time_limit(start_time):
# Retry the entire transaction.
continue
raise _make_timeout_error(last_error) from exc
raise
if not self.in_transaction:
@ -771,17 +809,18 @@ class AsyncClientSession:
try:
await self.commit_transaction()
except PyMongoError as exc:
if (
exc.has_error_label("UnknownTransactionCommitResult")
and _within_time_limit(start_time)
and not _max_time_expired_error(exc)
):
last_error = exc
if exc.has_error_label(
"UnknownTransactionCommitResult"
) and not _max_time_expired_error(exc):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the commit.
continue
if exc.has_error_label("TransientTransactionError") and _within_time_limit(
start_time
):
if exc.has_error_label("TransientTransactionError"):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the entire transaction.
break
raise

View File

@ -20,7 +20,6 @@ from collections import abc
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Callable,
Coroutine,
Generic,
@ -571,11 +570,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
await change_stream._initialize_cursor()
return change_stream
async def _conn_for_writes(
self, session: Optional[AsyncClientSession], operation: str
) -> AsyncContextManager[AsyncConnection]:
return await self._database.client._conn_for_writes(session, operation)
async def _command(
self,
conn: AsyncConnection,
@ -652,7 +646,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
if "size" in options:
options["size"] = float(options["size"])
cmd.update(options)
async with await self._conn_for_writes(session, operation=_Op.CREATE) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
if qev2_required and conn.max_wire_version < 21:
raise ConfigurationError(
"Driver support of Queryable Encryption is incompatible with server. "
@ -669,6 +666,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
session=session,
)
await self.database.client._retryable_write(False, inner, session, _Op.CREATE)
async def _create(
self,
options: MutableMapping[str, Any],
@ -2240,7 +2239,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
command (like maxTimeMS) can be passed as keyword arguments.
"""
names = []
async with await self._conn_for_writes(session, operation=_Op.CREATE_INDEXES) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> list[str]:
supports_quorum = conn.max_wire_version >= 9
def gen_indexes() -> Iterator[Mapping[str, Any]]:
@ -2269,7 +2271,11 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
write_concern=self._write_concern_for(session),
session=session,
)
return names
return names
return await self.database.client._retryable_write(
False, inner, session, _Op.CREATE_INDEXES
)
async def create_index(
self,
@ -2422,7 +2428,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
kwargs["comment"] = comment
await self._drop_index("*", session=session, **kwargs)
@_csot.apply
async def drop_index(
self,
index_or_name: _IndexKeyHint,
@ -2490,7 +2495,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.DROP_INDEXES) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command(
conn,
cmd,
@ -2500,6 +2508,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
session=session,
)
await self.database.client._retryable_write(False, inner, session, _Op.DROP_INDEXES)
async def list_indexes(
self,
session: Optional[AsyncClientSession] = None,
@ -2763,17 +2773,22 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd = {"createSearchIndexes": self.name, "indexes": list(gen_indexes())}
cmd.update(kwargs)
async with await self._conn_for_writes(
session, operation=_Op.CREATE_SEARCH_INDEXES
) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> list[str]:
resp = await self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
return [index["name"] for index in resp["indexesCreated"]]
return await self.database.client._retryable_write(
False, inner, session, _Op.CREATE_SEARCH_INDEXES
)
async def drop_search_index(
self,
name: str,
@ -2799,15 +2814,21 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.DROP_SEARCH_INDEXES) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
await self.database.client._retryable_write(False, inner, session, _Op.DROP_SEARCH_INDEXES)
async def update_search_index(
self,
name: str,
@ -2835,15 +2856,21 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
async with await self._conn_for_writes(session, operation=_Op.UPDATE_SEARCH_INDEX) as conn:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> None:
await self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
await self.database.client._retryable_write(False, inner, session, _Op.UPDATE_SEARCH_INDEX)
async def options(
self,
session: Optional[AsyncClientSession] = None,
@ -2918,6 +2945,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
session,
retryable=not cmd._performs_write,
operation=_Op.AGGREGATE,
is_aggregate_write=cmd._performs_write,
)
async def aggregate(
@ -3123,17 +3151,21 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
cmd["comment"] = comment
write_concern = self._write_concern_for_cmd(cmd, session)
client = self._database.client
async with await self._conn_for_writes(session, operation=_Op.RENAME) as conn:
async with self._database.client._tmp_session(session) as s:
return await conn.command(
"admin",
cmd,
write_concern=write_concern,
parse_write_concern_error=True,
session=s,
client=self._database.client,
)
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> MutableMapping[str, Any]:
return await conn.command(
"admin",
cmd,
write_concern=write_concern,
parse_write_concern_error=True,
session=session,
client=client,
)
return await client._retryable_write(False, inner, session, _Op.RENAME)
async def distinct(
self,

View File

@ -931,14 +931,15 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
if read_preference is None:
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
async with await self._client._conn_for_reads(
read_preference, session, operation=command_name
) as (
connection,
read_preference,
):
async def inner(
session: Optional[AsyncClientSession],
_server: Server,
conn: AsyncConnection,
read_preference: _ServerMode,
) -> Union[dict[str, Any], _CodecDocumentType]:
return await self._command(
connection,
conn,
command,
value,
check,
@ -949,6 +950,10 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
**kwargs,
)
return await self._client._retryable_read(
inner, read_preference, session, command_name, None, False, is_run_command=True
)
@_csot.apply
async def cursor_command(
self,
@ -1016,17 +1021,17 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
async with self._client._tmp_session(session) as tmp_session:
opts = codec_options or DEFAULT_CODEC_OPTIONS
if read_preference is None:
read_preference = (
tmp_session and tmp_session._txn_read_preference()
) or ReadPreference.PRIMARY
async with await self._client._conn_for_reads(
read_preference, tmp_session, command_name
) as (
conn,
read_preference,
):
async def inner(
session: Optional[AsyncClientSession],
_server: Server,
conn: AsyncConnection,
read_preference: _ServerMode,
) -> AsyncCommandCursor[_DocumentType]:
response = await self._command(
conn,
command,
@ -1035,7 +1040,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
None,
read_preference,
opts,
session=tmp_session,
session=session,
**kwargs,
)
coll = self.get_collection("$cmd", read_preference=read_preference)
@ -1045,7 +1050,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
response["cursor"],
conn.address,
max_await_time_ms=max_await_time_ms,
session=tmp_session,
session=session,
comment=comment,
)
await cmd_cursor._maybe_pin_connection(conn)
@ -1053,6 +1058,10 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
else:
raise InvalidOperation("Command does not return a cursor.")
return await self.client._retryable_read(
inner, read_preference, tmp_session, command_name, None, False
)
async def _retryable_read_command(
self,
command: Union[str, MutableMapping[str, Any]],
@ -1254,9 +1263,11 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
command["comment"] = comment
async with await self._client._conn_for_writes(session, operation=_Op.DROP) as connection:
async def inner(
session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool
) -> dict[str, Any]:
return await self._command(
connection,
conn,
command,
allowable_errors=["ns not found", 26],
write_concern=self._write_concern_for(session),
@ -1264,6 +1275,8 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
session=session,
)
return await self.client._retryable_write(False, inner, session, _Op.DROP)
@_csot.apply
async def drop_collection(
self,

View File

@ -17,8 +17,11 @@ from __future__ import annotations
import asyncio
import builtins
import functools
import random
import socket
import sys
import time as time # noqa: PLC0414 # needed in sync version
from typing import (
Any,
Callable,
@ -26,6 +29,8 @@ from typing import (
cast,
)
from pymongo import _csot
from pymongo.common import MAX_ADAPTIVE_RETRIES
from pymongo.errors import (
OperationFailure,
)
@ -38,6 +43,7 @@ F = TypeVar("F", bound=Callable[..., Any])
def _handle_reauth(func: F) -> F:
@functools.wraps(func)
async def inner(*args: Any, **kwargs: Any) -> Any:
no_reauth = kwargs.pop("no_reauth", False)
from pymongo.asynchronous.pool import AsyncConnection
@ -70,6 +76,46 @@ def _handle_reauth(func: F) -> F:
return cast(F, inner)
_BACKOFF_INITIAL = 0.1
_BACKOFF_MAX = 10
def _backoff(
attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
) -> float:
jitter = random.random() # noqa: S311
return jitter * min(initial_delay * (2**attempt), max_delay)
class _RetryPolicy:
"""A retry limiter that performs exponential backoff with jitter."""
def __init__(
self,
attempts: int = MAX_ADAPTIVE_RETRIES,
backoff_initial: float = _BACKOFF_INITIAL,
backoff_max: float = _BACKOFF_MAX,
):
self.attempts = attempts
self.backoff_initial = backoff_initial
self.backoff_max = backoff_max
def backoff(self, attempt: int) -> float:
"""Return the backoff duration for the given attempt."""
return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)
async def should_retry(self, attempt: int, delay: float) -> bool:
"""Return if we have retry attempts remaining and the next backoff would not exceed a timeout."""
if attempt > self.attempts:
return False
if _csot.get_timeout():
if time.monotonic() + delay > _csot.get_deadline():
return False
return True
async def _getaddrinfo(
host: Any, port: Any, **kwargs: Any
) -> list[

View File

@ -35,6 +35,7 @@ from __future__ import annotations
import asyncio
import contextlib
import os
import time as time # noqa: PLC0414 # needed in sync version
import warnings
import weakref
from collections import defaultdict
@ -67,6 +68,9 @@ from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterCh
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.helpers import (
_RetryPolicy,
)
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
from pymongo.client_options import ClientOptions
@ -610,8 +614,18 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
client to use Stable API. See `versioned API <https://www.mongodb.com/docs/manual/reference/stable-api/#what-is-the-stable-api--and-should-you-use-it->`_ for
details.
| **Overload retry options:**
- `max_adaptive_retries`: (int) How many retries to allow for overload errors. Defaults to ``2``.
- `enable_overload_retargeting`: (boolean) Whether overload retargeting is enabled for this client.
If enabled, server overload errors will cause retry attempts to select a server that has not yet returned an overload error, if possible.
Defaults to ``False``.
.. seealso:: The MongoDB documentation on `connections <https://dochub.mongodb.org/core/connections>`_.
.. versionchanged:: 4.17
Added the ``max_adaptive_retries`` and ``enable_overload_retargeting`` URI and keyword arguments.
.. versionchanged:: 4.5
Added the ``serverMonitoringMode`` keyword argument.
@ -879,11 +893,14 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self._options.read_concern,
)
self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries)
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
self._opened = False
self._closed = False
self._loop: Optional[asyncio.AbstractEventLoop] = None
if not is_srv:
self._init_background()
@ -1991,6 +2008,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
read_pref: Optional[_ServerMode] = None,
retryable: bool = False,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T:
"""Internal retryable helper for all client transactions.
@ -2002,6 +2021,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Server Address, defaults to None
:param read_pref: Topology of read operation, defaults to None
:param retryable: If the operation should be retried once, defaults to None
:param is_run_command: If this is a runCommand operation, defaults to False
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
:return: Output of the calling func()
"""
@ -2016,6 +2037,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address=address,
retryable=retryable,
operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
).run()
async def _retryable_read(
@ -2027,6 +2050,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address: Optional[_Address] = None,
retryable: bool = True,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T:
"""Execute an operation with consecutive retries if possible
@ -2042,6 +2067,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Optional address when sending a message, defaults to None
:param retryable: if we should attempt retries
(may not always be supported even if supplied), defaults to False
:param is_run_command: If this is a runCommand operation, defaults to False.
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
"""
# Ensure that the client supports retrying on reads and there is no session in
@ -2060,6 +2087,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
read_pref=read_pref,
retryable=retryable,
operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
)
async def _retryable_write(
@ -2454,15 +2483,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
f"name_or_database must be an instance of str or a AsyncDatabase, not {type(name)}"
)
async with await self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn:
await self[name]._command(
conn,
{"dropDatabase": 1, "comment": comment},
read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session),
parse_write_concern_error=True,
session=session,
)
await self[name].command(
{"dropDatabase": 1, "comment": comment},
read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session),
parse_write_concern_error=True,
session=session,
)
@_csot.apply
async def bulk_write(
@ -2746,12 +2773,15 @@ class _ClientConnectionRetryable(Generic[T]):
address: Optional[_Address] = None,
retryable: bool = False,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
):
self._last_error: Optional[Exception] = None
self._retrying = False
self._always_retryable = False
self._multiple_retries = _csot.get_timeout() is not None
self._client = mongo_client
self._retry_policy = mongo_client._retry_policy
self._func = func
self._bulk = bulk
self._session = session
@ -2767,6 +2797,8 @@ class _ClientConnectionRetryable(Generic[T]):
self._operation = operation
self._operation_id = operation_id
self._attempt_number = 0
self._is_run_command = is_run_command
self._is_aggregate_write = is_aggregate_write
async def run(self) -> T:
"""Runs the supplied func() and attempts a retry
@ -2786,7 +2818,13 @@ class _ClientConnectionRetryable(Generic[T]):
while True:
self._check_last_error(check_csot=True)
try:
return await self._read() if self._is_read else await self._write()
res = await self._read() if self._is_read else await self._write()
# Track whether the transaction has completed a command.
# If we need to apply backpressure to the first command,
# we will need to revert back to starting state.
if self._session is not None and self._session.in_transaction:
self._session._transaction.has_completed_command = True
return res
except ServerSelectionTimeoutError:
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
@ -2797,37 +2835,76 @@ class _ClientConnectionRetryable(Generic[T]):
# most likely be a waste of time.
raise
except PyMongoError as exc:
always_retryable = False
overloaded = False
exc_to_check = exc
if self._is_run_command and not (
self._client.options.retry_reads and self._client.options.retry_writes
):
raise
if self._is_aggregate_write and not self._client.options.retry_writes:
raise
# Execute specialized catch on read
if self._is_read:
if isinstance(exc, (ConnectionFailure, OperationFailure)):
# ConnectionFailures do not supply a code property
exc_code = getattr(exc, "code", None)
if self._is_not_eligible_for_retry() or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
overloaded = exc.has_error_label("SystemOverloadedError")
always_retryable = exc.has_error_label("RetryableError") and overloaded
if not self._client.options.retry_reads or (
not always_retryable
and (
self._is_not_eligible_for_retry()
or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
)
)
):
raise
self._retrying = True
self._last_error = exc
self._attempt_number += 1
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if (
overloaded
and self._session is not None
and self._session.in_transaction
):
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
else:
raise
# Specialized catch on write operation
if not self._is_read:
if not self._retryable:
if isinstance(exc, ClientBulkWriteException) and isinstance(
exc.error, PyMongoError
):
exc_to_check = exc.error
retryable_write_label = exc_to_check.has_error_label("RetryableWriteError")
overloaded = exc_to_check.has_error_label("SystemOverloadedError")
always_retryable = exc_to_check.has_error_label("RetryableError") and overloaded
# Always retry abortTransaction and commitTransaction up to once
if self._operation not in ["abortTransaction", "commitTransaction"] and (
not self._client.options.retry_writes
or not (self._retryable or always_retryable)
):
raise
if isinstance(exc, ClientBulkWriteException) and exc.error:
retryable_write_error_exc = isinstance(
exc.error, PyMongoError
) and exc.error.has_error_label("RetryableWriteError")
else:
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
if retryable_write_error_exc:
if retryable_write_label or always_retryable:
assert self._session
await self._session._unpin()
if not retryable_write_error_exc or self._is_not_eligible_for_retry():
if exc.has_error_label("NoWritesPerformed") and self._last_error:
if not always_retryable and (
not retryable_write_label or self._is_not_eligible_for_retry()
):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
@ -2836,18 +2913,34 @@ class _ClientConnectionRetryable(Generic[T]):
self._bulk.retrying = True
else:
self._retrying = True
if not exc.has_error_label("NoWritesPerformed"):
if not exc_to_check.has_error_label("NoWritesPerformed"):
self._last_error = exc
if self._last_error is None:
self._last_error = exc
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if overloaded and self._session is not None and self._session.in_transaction:
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
if (
self._server is not None
and self._client.topology_description.topology_type_name == "Sharded"
or exc.has_error_label("SystemOverloadedError")
if self._server is not None and (
self._client.topology_description.topology_type_name == "Sharded"
or (overloaded and self._client.options.enable_overload_retargeting)
):
self._deprioritized_servers.append(self._server)
self._always_retryable = always_retryable
if overloaded:
delay = self._retry_policy.backoff(self._attempt_number)
if not await self._retry_policy.should_retry(self._attempt_number, delay):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
await asyncio.sleep(delay)
def _is_not_eligible_for_retry(self) -> bool:
"""Checks if the exchange is not eligible for retry"""
return not self._retryable or (self._is_retrying() and not self._multiple_retries)
@ -2909,7 +3002,7 @@ class _ClientConnectionRetryable(Generic[T]):
and conn.supports_sessions
)
is_mongos = conn.is_mongos
if not sessions_supported:
if not self._always_retryable and not sessions_supported:
# A retry is not possible because this server does
# not support sessions raise the last error.
self._check_last_error()
@ -2941,7 +3034,7 @@ class _ClientConnectionRetryable(Generic[T]):
conn,
read_pref,
):
if self._retrying and not self._retryable:
if self._retrying and not self._retryable and not self._always_retryable:
self._check_last_error()
if self._retrying:
_debug_log(

View File

@ -19,6 +19,8 @@ import collections
import contextlib
import logging
import os
import socket
import ssl
import sys
import time
import weakref
@ -52,10 +54,12 @@ from pymongo.errors import ( # type:ignore[attr-defined]
DocumentTooLarge,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
NotPrimaryError,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _get_timeout_details, format_timeout_details
@ -250,6 +254,7 @@ class AsyncConnection:
cmd = self.hello_cmd()
performing_handshake = not self.performed_handshake
awaitable = False
cmd["backpressure"] = True
if performing_handshake:
self.performed_handshake = True
cmd["client"] = self.opts.metadata
@ -752,8 +757,8 @@ class Pool:
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = _async_create_condition(self.lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._max_connecting = self.opts.max_connecting
self._client_id = client_id
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
@ -986,6 +991,21 @@ class Pool:
self.requests -= 1
self.size_cond.notify()
def _handle_connection_error(self, error: BaseException) -> None:
# Handle system overload condition for non-sdam pools.
# Look for errors of type AutoReconnect and add error labels if appropriate.
if self.is_sdam or type(error) not in (AutoReconnect, NetworkTimeout):
return
assert isinstance(error, AutoReconnect) # Appease type checker.
# If the original error was a DNS, certificate, or SSL error, ignore it.
if isinstance(error.__cause__, (_CertificateError, SSLErrors, socket.gaierror)):
# End of file errors are excluded, because the server may have disconnected
# during the handshake.
if not isinstance(error.__cause__, (ssl.SSLEOFError, ssl.SSLZeroReturnError)):
return
error._add_error_label("SystemOverloadedError")
error._add_error_label("RetryableError")
async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnection:
"""Connect to Mongo and return a new AsyncConnection.
@ -1037,10 +1057,10 @@ class Pool:
reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
error=ConnectionClosedReason.ERROR,
)
self._handle_connection_error(error)
if isinstance(error, (IOError, OSError, *SSLErrors)):
details = _get_timeout_details(self.opts)
_raise_connection_failure(self.address, error, timeout_details=details)
raise
conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
@ -1049,18 +1069,22 @@ class Pool:
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
await conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)
await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
except BaseException as e:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
await conn.close_conn(ConnectionClosedReason.ERROR)
raise
@ -1389,8 +1413,8 @@ class Pool:
:class:`~pymongo.errors.AutoReconnect` exceptions on server
hiccups, etc. We only check if the socket was closed by an external
error if it has been > 1 second since the socket was checked into the
pool, to keep performance reasonable - we can't avoid AutoReconnects
completely anyway.
pool to keep performance reasonable -
we can't avoid AutoReconnects completely anyway.
"""
idle_time_seconds = conn.idle_time_seconds()
# If socket is idle, open a new one.
@ -1401,8 +1425,9 @@ class Pool:
await conn.close_conn(ConnectionClosedReason.IDLE)
return True
if self._check_interval_seconds is not None and (
self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds
check_interval_seconds = self._check_interval_seconds
if check_interval_seconds is not None and (
check_interval_seconds == 0 or idle_time_seconds > check_interval_seconds
):
if conn.conn_closed():
await conn.close_conn(ConnectionClosedReason.ERROR)

View File

@ -913,7 +913,9 @@ class Topology:
# Clear the pool.
await server.reset(service_id)
elif isinstance(error, ConnectionFailure):
if isinstance(error, WaitQueueTimeoutError):
if isinstance(error, WaitQueueTimeoutError) or (
error.has_error_label("SystemOverloadedError")
):
return
# "Client MUST replace the server's description with type Unknown
# ... MUST NOT request an immediate check of the server."

View File

@ -235,6 +235,16 @@ class ClientOptions:
self.__server_monitoring_mode = options.get(
"servermonitoringmode", common.SERVER_MONITORING_MODE
)
self.__max_adaptive_retries = (
options.get("max_adaptive_retries", common.MAX_ADAPTIVE_RETRIES)
if "max_adaptive_retries" in options
else options.get("maxadaptiveretries", common.MAX_ADAPTIVE_RETRIES)
)
self.__enable_overload_retargeting = (
options.get("enable_overload_retargeting", common.ENABLE_OVERLOAD_RETARGETING)
if "enable_overload_retargeting" in options
else options.get("enableoverloadretargeting", common.ENABLE_OVERLOAD_RETARGETING)
)
@property
def _options(self) -> Mapping[str, Any]:
@ -346,3 +356,19 @@ class ClientOptions:
.. versionadded:: 4.5
"""
return self.__server_monitoring_mode
@property
def max_adaptive_retries(self) -> int:
"""The configured maxAdaptiveRetries option.
.. versionadded:: 4.17
"""
return self.__max_adaptive_retries
@property
def enable_overload_retargeting(self) -> bool:
"""The configured enableOverloadRetargeting option.
.. versionadded:: 4.17
"""
return self.__enable_overload_retargeting

View File

@ -140,6 +140,12 @@ SRV_SERVICE_NAME = "mongodb"
# Default value for serverMonitoringMode
SERVER_MONITORING_MODE = "auto" # poll/stream/auto
# Default value for max adaptive retries
MAX_ADAPTIVE_RETRIES = 2
# Default value for enableOverloadRetargeting
ENABLE_OVERLOAD_RETARGETING = False
# Auth mechanism properties that must raise an error instead of warning if they invalidate.
_MECH_PROP_MUST_RAISE = ["CANONICALIZE_HOST_NAME"]
@ -717,6 +723,8 @@ URI_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = {
"srvmaxhosts": validate_non_negative_integer,
"timeoutms": validate_timeoutms,
"servermonitoringmode": validate_server_monitoring_mode,
"maxadaptiveretries": validate_non_negative_integer,
"enableoverloadretargeting": validate_boolean_or_string,
}
# Dictionary where keys are the names of URI options specific to pymongo,
@ -750,6 +758,8 @@ KW_VALIDATORS: dict[str, Callable[[Any, Any], Any]] = {
"server_selector": validate_is_callable_or_none,
"auto_encryption_opts": validate_auto_encryption_opts_or_none,
"authoidcallowedhosts": validate_list,
"max_adaptive_retries": validate_non_negative_integer,
"enable_overload_retargeting": validate_boolean_or_string,
}
# Dictionary where keys are any URI option name, and values are the

View File

@ -59,6 +59,7 @@ from pymongo.errors import (
InvalidOperation,
NotPrimaryError,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
@ -561,9 +562,17 @@ class _ClientBulk:
error, ConnectionFailure
) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError))
retryable_label_error = isinstance(
error, PyMongoError
) and error.has_error_label("RetryableError")
# Synthesize the full bulk result without modifying the
# current one because this write operation may be retried.
if retryable and (retryable_top_level_error or retryable_network_error):
if retryable and (
retryable_top_level_error
or retryable_network_error
or retryable_label_error
):
full = copy.deepcopy(full_result)
_merge_command(self.ops, self.idx_offset, full, result)
_throw_client_bulk_write_exception(full, self.verbose_results)

View File

@ -136,6 +136,7 @@ Classes
from __future__ import annotations
import collections
import random
import time
import uuid
from collections.abc import Mapping as _Mapping
@ -160,7 +161,9 @@ from pymongo import _csot
from pymongo.errors import (
ConfigurationError,
ConnectionFailure,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
OperationFailure,
PyMongoError,
WTimeoutError,
@ -426,6 +429,7 @@ class _Transaction:
self.recovery_token = None
self.attempt = 0
self.client = client
self.has_completed_command = False
def active(self) -> bool:
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
@ -433,6 +437,9 @@ class _Transaction:
def starting(self) -> bool:
return self.state == _TxnState.STARTING
def set_starting(self) -> None:
self.state = _TxnState.STARTING
@property
def pinned_conn(self) -> Optional[Connection]:
if self.active() and self.conn_mgr:
@ -458,6 +465,7 @@ class _Transaction:
self.sharded = False
self.recovery_token = None
self.attempt = 0
self.has_completed_command = False
def __del__(self) -> None:
if self.conn_mgr:
@ -492,11 +500,29 @@ _UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( #
# This limit is non-configurable and was chosen to be twice the 60 second
# default value of MongoDB's `transactionLifetimeLimitSeconds` parameter.
_WITH_TRANSACTION_RETRY_TIME_LIMIT = 120
_BACKOFF_MAX = 0.500 # 500ms max backoff
_BACKOFF_INITIAL = 0.005 # 5ms initial backoff
def _within_time_limit(start_time: float) -> bool:
def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
"""Are we within the with_transaction retry limit?"""
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
remaining = _csot.remaining()
if remaining is not None and remaining <= 0:
return False
return time.monotonic() + backoff - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
def _make_timeout_error(error: BaseException) -> PyMongoError:
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
if _csot.remaining() is not None:
timeout_error: PyMongoError = ExecutionTimeout(
str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}
)
else:
timeout_error = NetworkTimeout(str(error))
if isinstance(error, PyMongoError):
timeout_error._error_labels = error._error_labels.copy()
return timeout_error
_T = TypeVar("_T")
@ -743,21 +769,32 @@ class ClientSession:
https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback
"""
start_time = time.monotonic()
retry = 0
last_error: Optional[BaseException] = None
while True:
if retry: # Implement exponential backoff on retry.
jitter = random.random() # noqa: S311
backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX)
if not _within_time_limit(start_time, backoff):
assert last_error is not None
raise _make_timeout_error(last_error) from last_error
time.sleep(backoff)
retry += 1
self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms)
try:
ret = callback(self)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as exc:
last_error = exc
if self.in_transaction:
self.abort_transaction()
if (
isinstance(exc, PyMongoError)
and exc.has_error_label("TransientTransactionError")
and _within_time_limit(start_time)
if isinstance(exc, PyMongoError) and exc.has_error_label(
"TransientTransactionError"
):
# Retry the entire transaction.
continue
if _within_time_limit(start_time):
# Retry the entire transaction.
continue
raise _make_timeout_error(last_error) from exc
raise
if not self.in_transaction:
@ -768,17 +805,18 @@ class ClientSession:
try:
self.commit_transaction()
except PyMongoError as exc:
if (
exc.has_error_label("UnknownTransactionCommitResult")
and _within_time_limit(start_time)
and not _max_time_expired_error(exc)
):
last_error = exc
if exc.has_error_label(
"UnknownTransactionCommitResult"
) and not _max_time_expired_error(exc):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the commit.
continue
if exc.has_error_label("TransientTransactionError") and _within_time_limit(
start_time
):
if exc.has_error_label("TransientTransactionError"):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the entire transaction.
break
raise

View File

@ -21,7 +21,6 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Generic,
Iterable,
Iterator,
@ -572,11 +571,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
change_stream._initialize_cursor()
return change_stream
def _conn_for_writes(
self, session: Optional[ClientSession], operation: str
) -> ContextManager[Connection]:
return self._database.client._conn_for_writes(session, operation)
def _command(
self,
conn: Connection,
@ -653,7 +647,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
if "size" in options:
options["size"] = float(options["size"])
cmd.update(options)
with self._conn_for_writes(session, operation=_Op.CREATE) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> None:
if qev2_required and conn.max_wire_version < 21:
raise ConfigurationError(
"Driver support of Queryable Encryption is incompatible with server. "
@ -670,6 +667,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
session=session,
)
self.database.client._retryable_write(False, inner, session, _Op.CREATE)
def _create(
self,
options: MutableMapping[str, Any],
@ -2237,7 +2236,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
command (like maxTimeMS) can be passed as keyword arguments.
"""
names = []
with self._conn_for_writes(session, operation=_Op.CREATE_INDEXES) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> list[str]:
supports_quorum = conn.max_wire_version >= 9
def gen_indexes() -> Iterator[Mapping[str, Any]]:
@ -2266,7 +2268,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
write_concern=self._write_concern_for(session),
session=session,
)
return names
return names
return self.database.client._retryable_write(False, inner, session, _Op.CREATE_INDEXES)
def create_index(
self,
@ -2419,7 +2423,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
kwargs["comment"] = comment
self._drop_index("*", session=session, **kwargs)
@_csot.apply
def drop_index(
self,
index_or_name: _IndexKeyHint,
@ -2487,7 +2490,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
with self._conn_for_writes(session, operation=_Op.DROP_INDEXES) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> None:
self._command(
conn,
cmd,
@ -2497,6 +2503,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
session=session,
)
self.database.client._retryable_write(False, inner, session, _Op.DROP_INDEXES)
def list_indexes(
self,
session: Optional[ClientSession] = None,
@ -2760,15 +2768,22 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd = {"createSearchIndexes": self.name, "indexes": list(gen_indexes())}
cmd.update(kwargs)
with self._conn_for_writes(session, operation=_Op.CREATE_SEARCH_INDEXES) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> list[str]:
resp = self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
return [index["name"] for index in resp["indexesCreated"]]
return self.database.client._retryable_write(
False, inner, session, _Op.CREATE_SEARCH_INDEXES
)
def drop_search_index(
self,
name: str,
@ -2794,15 +2809,21 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
with self._conn_for_writes(session, operation=_Op.DROP_SEARCH_INDEXES) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> None:
self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
self.database.client._retryable_write(False, inner, session, _Op.DROP_SEARCH_INDEXES)
def update_search_index(
self,
name: str,
@ -2830,15 +2851,21 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
cmd.update(kwargs)
if comment is not None:
cmd["comment"] = comment
with self._conn_for_writes(session, operation=_Op.UPDATE_SEARCH_INDEX) as conn:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> None:
self._command(
conn,
cmd,
read_preference=ReadPreference.PRIMARY,
allowable_errors=["ns not found", 26],
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
session=session,
)
self.database.client._retryable_write(False, inner, session, _Op.UPDATE_SEARCH_INDEX)
def options(
self,
session: Optional[ClientSession] = None,
@ -2911,6 +2938,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
session,
retryable=not cmd._performs_write,
operation=_Op.AGGREGATE,
is_aggregate_write=cmd._performs_write,
)
def aggregate(
@ -3116,17 +3144,21 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
cmd["comment"] = comment
write_concern = self._write_concern_for_cmd(cmd, session)
client = self._database.client
with self._conn_for_writes(session, operation=_Op.RENAME) as conn:
with self._database.client._tmp_session(session) as s:
return conn.command(
"admin",
cmd,
write_concern=write_concern,
parse_write_concern_error=True,
session=s,
client=self._database.client,
)
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> MutableMapping[str, Any]:
return conn.command(
"admin",
cmd,
write_concern=write_concern,
parse_write_concern_error=True,
session=session,
client=client,
)
return client._retryable_write(False, inner, session, _Op.RENAME)
def distinct(
self,

View File

@ -931,12 +931,15 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if read_preference is None:
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
with self._client._conn_for_reads(read_preference, session, operation=command_name) as (
connection,
read_preference,
):
def inner(
session: Optional[ClientSession],
_server: Server,
conn: Connection,
read_preference: _ServerMode,
) -> Union[dict[str, Any], _CodecDocumentType]:
return self._command(
connection,
conn,
command,
value,
check,
@ -947,6 +950,10 @@ class Database(common.BaseObject, Generic[_DocumentType]):
**kwargs,
)
return self._client._retryable_read(
inner, read_preference, session, command_name, None, False, is_run_command=True
)
@_csot.apply
def cursor_command(
self,
@ -1014,15 +1021,17 @@ class Database(common.BaseObject, Generic[_DocumentType]):
with self._client._tmp_session(session) as tmp_session:
opts = codec_options or DEFAULT_CODEC_OPTIONS
if read_preference is None:
read_preference = (
tmp_session and tmp_session._txn_read_preference()
) or ReadPreference.PRIMARY
with self._client._conn_for_reads(read_preference, tmp_session, command_name) as (
conn,
read_preference,
):
def inner(
session: Optional[ClientSession],
_server: Server,
conn: Connection,
read_preference: _ServerMode,
) -> CommandCursor[_DocumentType]:
response = self._command(
conn,
command,
@ -1031,7 +1040,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
None,
read_preference,
opts,
session=tmp_session,
session=session,
**kwargs,
)
coll = self.get_collection("$cmd", read_preference=read_preference)
@ -1041,7 +1050,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
response["cursor"],
conn.address,
max_await_time_ms=max_await_time_ms,
session=tmp_session,
session=session,
comment=comment,
)
cmd_cursor._maybe_pin_connection(conn)
@ -1049,6 +1058,10 @@ class Database(common.BaseObject, Generic[_DocumentType]):
else:
raise InvalidOperation("Command does not return a cursor.")
return self.client._retryable_read(
inner, read_preference, tmp_session, command_name, None, False
)
def _retryable_read_command(
self,
command: Union[str, MutableMapping[str, Any]],
@ -1247,9 +1260,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
command["comment"] = comment
with self._client._conn_for_writes(session, operation=_Op.DROP) as connection:
def inner(
session: Optional[ClientSession], conn: Connection, _retryable_write: bool
) -> dict[str, Any]:
return self._command(
connection,
conn,
command,
allowable_errors=["ns not found", 26],
write_concern=self._write_concern_for(session),
@ -1257,6 +1272,8 @@ class Database(common.BaseObject, Generic[_DocumentType]):
session=session,
)
return self.client._retryable_write(False, inner, session, _Op.DROP)
@_csot.apply
def drop_collection(
self,

View File

@ -17,8 +17,11 @@ from __future__ import annotations
import asyncio
import builtins
import functools
import random
import socket
import sys
import time as time # noqa: PLC0414 # needed in sync version
from typing import (
Any,
Callable,
@ -26,6 +29,8 @@ from typing import (
cast,
)
from pymongo import _csot
from pymongo.common import MAX_ADAPTIVE_RETRIES
from pymongo.errors import (
OperationFailure,
)
@ -38,6 +43,7 @@ F = TypeVar("F", bound=Callable[..., Any])
def _handle_reauth(func: F) -> F:
@functools.wraps(func)
def inner(*args: Any, **kwargs: Any) -> Any:
no_reauth = kwargs.pop("no_reauth", False)
from pymongo.message import _BulkWriteContext
@ -70,6 +76,46 @@ def _handle_reauth(func: F) -> F:
return cast(F, inner)
_BACKOFF_INITIAL = 0.1
_BACKOFF_MAX = 10
def _backoff(
attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
) -> float:
jitter = random.random() # noqa: S311
return jitter * min(initial_delay * (2**attempt), max_delay)
class _RetryPolicy:
"""A retry limiter that performs exponential backoff with jitter."""
def __init__(
self,
attempts: int = MAX_ADAPTIVE_RETRIES,
backoff_initial: float = _BACKOFF_INITIAL,
backoff_max: float = _BACKOFF_MAX,
):
self.attempts = attempts
self.backoff_initial = backoff_initial
self.backoff_max = backoff_max
def backoff(self, attempt: int) -> float:
"""Return the backoff duration for the given attempt."""
return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)
def should_retry(self, attempt: int, delay: float) -> bool:
"""Return if we have retry attempts remaining and the next backoff would not exceed a timeout."""
if attempt > self.attempts:
return False
if _csot.get_timeout():
if time.monotonic() + delay > _csot.get_deadline():
return False
return True
def _getaddrinfo(
host: Any, port: Any, **kwargs: Any
) -> list[

View File

@ -35,6 +35,7 @@ from __future__ import annotations
import asyncio
import contextlib
import os
import time as time # noqa: PLC0414 # needed in sync version
import warnings
import weakref
from collections import defaultdict
@ -110,6 +111,9 @@ from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.helpers import (
_RetryPolicy,
)
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
@ -610,8 +614,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
client to use Stable API. See `versioned API <https://www.mongodb.com/docs/manual/reference/stable-api/#what-is-the-stable-api--and-should-you-use-it->`_ for
details.
| **Overload retry options:**
- `max_adaptive_retries`: (int) How many retries to allow for overload errors. Defaults to ``2``.
- `enable_overload_retargeting`: (boolean) Whether overload retargeting is enabled for this client.
If enabled, server overload errors will cause retry attempts to select a server that has not yet returned an overload error, if possible.
Defaults to ``False``.
.. seealso:: The MongoDB documentation on `connections <https://dochub.mongodb.org/core/connections>`_.
.. versionchanged:: 4.17
Added the ``max_adaptive_retries`` and ``enable_overload_retargeting`` URI and keyword arguments.
.. versionchanged:: 4.5
Added the ``serverMonitoringMode`` keyword argument.
@ -879,11 +893,14 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self._options.read_concern,
)
self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries)
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
self._opened = False
self._closed = False
self._loop: Optional[asyncio.AbstractEventLoop] = None
if not is_srv:
self._init_background()
@ -1987,6 +2004,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
read_pref: Optional[_ServerMode] = None,
retryable: bool = False,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T:
"""Internal retryable helper for all client transactions.
@ -1998,6 +2017,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Server Address, defaults to None
:param read_pref: Topology of read operation, defaults to None
:param retryable: If the operation should be retried once, defaults to None
:param is_run_command: If this is a runCommand operation, defaults to False
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
:return: Output of the calling func()
"""
@ -2012,6 +2033,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
address=address,
retryable=retryable,
operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
).run()
def _retryable_read(
@ -2023,6 +2046,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
address: Optional[_Address] = None,
retryable: bool = True,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
) -> T:
"""Execute an operation with consecutive retries if possible
@ -2038,6 +2063,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
:param address: Optional address when sending a message, defaults to None
:param retryable: if we should attempt retries
(may not always be supported even if supplied), defaults to False
:param is_run_command: If this is a runCommand operation, defaults to False.
:param is_aggregate_write: If this is a aggregate operation with a write, defaults to False.
"""
# Ensure that the client supports retrying on reads and there is no session in
@ -2056,6 +2083,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
read_pref=read_pref,
retryable=retryable,
operation_id=operation_id,
is_run_command=is_run_command,
is_aggregate_write=is_aggregate_write,
)
def _retryable_write(
@ -2444,15 +2473,13 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
f"name_or_database must be an instance of str or a Database, not {type(name)}"
)
with self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn:
self[name]._command(
conn,
{"dropDatabase": 1, "comment": comment},
read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session),
parse_write_concern_error=True,
session=session,
)
self[name].command(
{"dropDatabase": 1, "comment": comment},
read_preference=ReadPreference.PRIMARY,
write_concern=self._write_concern_for(session),
parse_write_concern_error=True,
session=session,
)
@_csot.apply
def bulk_write(
@ -2736,12 +2763,15 @@ class _ClientConnectionRetryable(Generic[T]):
address: Optional[_Address] = None,
retryable: bool = False,
operation_id: Optional[int] = None,
is_run_command: bool = False,
is_aggregate_write: bool = False,
):
self._last_error: Optional[Exception] = None
self._retrying = False
self._always_retryable = False
self._multiple_retries = _csot.get_timeout() is not None
self._client = mongo_client
self._retry_policy = mongo_client._retry_policy
self._func = func
self._bulk = bulk
self._session = session
@ -2757,6 +2787,8 @@ class _ClientConnectionRetryable(Generic[T]):
self._operation = operation
self._operation_id = operation_id
self._attempt_number = 0
self._is_run_command = is_run_command
self._is_aggregate_write = is_aggregate_write
def run(self) -> T:
"""Runs the supplied func() and attempts a retry
@ -2776,7 +2808,13 @@ class _ClientConnectionRetryable(Generic[T]):
while True:
self._check_last_error(check_csot=True)
try:
return self._read() if self._is_read else self._write()
res = self._read() if self._is_read else self._write()
# Track whether the transaction has completed a command.
# If we need to apply backpressure to the first command,
# we will need to revert back to starting state.
if self._session is not None and self._session.in_transaction:
self._session._transaction.has_completed_command = True
return res
except ServerSelectionTimeoutError:
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
@ -2787,37 +2825,76 @@ class _ClientConnectionRetryable(Generic[T]):
# most likely be a waste of time.
raise
except PyMongoError as exc:
always_retryable = False
overloaded = False
exc_to_check = exc
if self._is_run_command and not (
self._client.options.retry_reads and self._client.options.retry_writes
):
raise
if self._is_aggregate_write and not self._client.options.retry_writes:
raise
# Execute specialized catch on read
if self._is_read:
if isinstance(exc, (ConnectionFailure, OperationFailure)):
# ConnectionFailures do not supply a code property
exc_code = getattr(exc, "code", None)
if self._is_not_eligible_for_retry() or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
overloaded = exc.has_error_label("SystemOverloadedError")
always_retryable = exc.has_error_label("RetryableError") and overloaded
if not self._client.options.retry_reads or (
not always_retryable
and (
self._is_not_eligible_for_retry()
or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
)
)
):
raise
self._retrying = True
self._last_error = exc
self._attempt_number += 1
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if (
overloaded
and self._session is not None
and self._session.in_transaction
):
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
else:
raise
# Specialized catch on write operation
if not self._is_read:
if not self._retryable:
if isinstance(exc, ClientBulkWriteException) and isinstance(
exc.error, PyMongoError
):
exc_to_check = exc.error
retryable_write_label = exc_to_check.has_error_label("RetryableWriteError")
overloaded = exc_to_check.has_error_label("SystemOverloadedError")
always_retryable = exc_to_check.has_error_label("RetryableError") and overloaded
# Always retry abortTransaction and commitTransaction up to once
if self._operation not in ["abortTransaction", "commitTransaction"] and (
not self._client.options.retry_writes
or not (self._retryable or always_retryable)
):
raise
if isinstance(exc, ClientBulkWriteException) and exc.error:
retryable_write_error_exc = isinstance(
exc.error, PyMongoError
) and exc.error.has_error_label("RetryableWriteError")
else:
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
if retryable_write_error_exc:
if retryable_write_label or always_retryable:
assert self._session
self._session._unpin()
if not retryable_write_error_exc or self._is_not_eligible_for_retry():
if exc.has_error_label("NoWritesPerformed") and self._last_error:
if not always_retryable and (
not retryable_write_label or self._is_not_eligible_for_retry()
):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
@ -2826,18 +2903,34 @@ class _ClientConnectionRetryable(Generic[T]):
self._bulk.retrying = True
else:
self._retrying = True
if not exc.has_error_label("NoWritesPerformed"):
if not exc_to_check.has_error_label("NoWritesPerformed"):
self._last_error = exc
if self._last_error is None:
self._last_error = exc
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
if overloaded and self._session is not None and self._session.in_transaction:
transaction = self._session._transaction
if not transaction.has_completed_command:
transaction.set_starting()
transaction.attempt = 0
if (
self._server is not None
and self._client.topology_description.topology_type_name == "Sharded"
or exc.has_error_label("SystemOverloadedError")
if self._server is not None and (
self._client.topology_description.topology_type_name == "Sharded"
or (overloaded and self._client.options.enable_overload_retargeting)
):
self._deprioritized_servers.append(self._server)
self._always_retryable = always_retryable
if overloaded:
delay = self._retry_policy.backoff(self._attempt_number)
if not self._retry_policy.should_retry(self._attempt_number, delay):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
time.sleep(delay)
def _is_not_eligible_for_retry(self) -> bool:
"""Checks if the exchange is not eligible for retry"""
return not self._retryable or (self._is_retrying() and not self._multiple_retries)
@ -2899,7 +2992,7 @@ class _ClientConnectionRetryable(Generic[T]):
and conn.supports_sessions
)
is_mongos = conn.is_mongos
if not sessions_supported:
if not self._always_retryable and not sessions_supported:
# A retry is not possible because this server does
# not support sessions raise the last error.
self._check_last_error()
@ -2931,7 +3024,7 @@ class _ClientConnectionRetryable(Generic[T]):
conn,
read_pref,
):
if self._retrying and not self._retryable:
if self._retrying and not self._retryable and not self._always_retryable:
self._check_last_error()
if self._retrying:
_debug_log(

View File

@ -19,6 +19,8 @@ import collections
import contextlib
import logging
import os
import socket
import ssl
import sys
import time
import weakref
@ -49,10 +51,12 @@ from pymongo.errors import ( # type:ignore[attr-defined]
DocumentTooLarge,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
NotPrimaryError,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _get_timeout_details, format_timeout_details
@ -250,6 +254,7 @@ class Connection:
cmd = self.hello_cmd()
performing_handshake = not self.performed_handshake
awaitable = False
cmd["backpressure"] = True
if performing_handshake:
self.performed_handshake = True
cmd["client"] = self.opts.metadata
@ -750,8 +755,8 @@ class Pool:
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = _create_condition(self.lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._max_connecting = self.opts.max_connecting
self._client_id = client_id
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
@ -982,6 +987,21 @@ class Pool:
self.requests -= 1
self.size_cond.notify()
def _handle_connection_error(self, error: BaseException) -> None:
# Handle system overload condition for non-sdam pools.
# Look for errors of type AutoReconnect and add error labels if appropriate.
if self.is_sdam or type(error) not in (AutoReconnect, NetworkTimeout):
return
assert isinstance(error, AutoReconnect) # Appease type checker.
# If the original error was a DNS, certificate, or SSL error, ignore it.
if isinstance(error.__cause__, (_CertificateError, SSLErrors, socket.gaierror)):
# End of file errors are excluded, because the server may have disconnected
# during the handshake.
if not isinstance(error.__cause__, (ssl.SSLEOFError, ssl.SSLZeroReturnError)):
return
error._add_error_label("SystemOverloadedError")
error._add_error_label("RetryableError")
def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection:
"""Connect to Mongo and return a new Connection.
@ -1033,10 +1053,10 @@ class Pool:
reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
error=ConnectionClosedReason.ERROR,
)
self._handle_connection_error(error)
if isinstance(error, (IOError, OSError, *SSLErrors)):
details = _get_timeout_details(self.opts)
_raise_connection_failure(self.address, error, timeout_details=details)
raise
conn = Connection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
@ -1045,18 +1065,22 @@ class Pool:
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)
conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
except BaseException as e:
with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
conn.close_conn(ConnectionClosedReason.ERROR)
raise
@ -1385,8 +1409,8 @@ class Pool:
:class:`~pymongo.errors.AutoReconnect` exceptions on server
hiccups, etc. We only check if the socket was closed by an external
error if it has been > 1 second since the socket was checked into the
pool, to keep performance reasonable - we can't avoid AutoReconnects
completely anyway.
pool to keep performance reasonable -
we can't avoid AutoReconnects completely anyway.
"""
idle_time_seconds = conn.idle_time_seconds()
# If socket is idle, open a new one.
@ -1397,8 +1421,9 @@ class Pool:
conn.close_conn(ConnectionClosedReason.IDLE)
return True
if self._check_interval_seconds is not None and (
self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds
check_interval_seconds = self._check_interval_seconds
if check_interval_seconds is not None and (
check_interval_seconds == 0 or idle_time_seconds > check_interval_seconds
):
if conn.conn_closed():
conn.close_conn(ConnectionClosedReason.ERROR)

View File

@ -911,7 +911,9 @@ class Topology:
# Clear the pool.
server.reset(service_id)
elif isinstance(error, ConnectionFailure):
if isinstance(error, WaitQueueTimeoutError):
if isinstance(error, WaitQueueTimeoutError) or (
error.has_error_label("SystemOverloadedError")
):
return
# "Client MUST replace the server's description with type Unknown
# ... MUST NOT request an immediate check of the server."

View File

@ -652,6 +652,38 @@ class AsyncClientUnitTest(AsyncUnitTest):
with self.assertWarns(UserWarning):
self.simple_client(multi_host)
async def test_max_adaptive_retries(self):
# Assert that max adaptive retries defaults to 2.
c = self.simple_client(connect=False)
self.assertEqual(c.options.max_adaptive_retries, 2)
# Assert that max adaptive retries can be configured through connection or client options.
c = self.simple_client(connect=False, max_adaptive_retries=10)
self.assertEqual(c.options.max_adaptive_retries, 10)
c = self.simple_client(connect=False, maxAdaptiveRetries=10)
self.assertEqual(c.options.max_adaptive_retries, 10)
c = self.simple_client(host="mongodb://localhost/?maxAdaptiveRetries=10", connect=False)
self.assertEqual(c.options.max_adaptive_retries, 10)
async def test_enable_overload_retargeting(self):
# Assert that overload retargeting defaults to false.
c = self.simple_client(connect=False)
self.assertFalse(c.options.enable_overload_retargeting)
# Assert that overload retargeting can be enabled through connection or client options.
c = self.simple_client(connect=False, enable_overload_retargeting=True)
self.assertTrue(c.options.enable_overload_retargeting)
c = self.simple_client(connect=False, enableOverloadRetargeting=True)
self.assertTrue(c.options.enable_overload_retargeting)
c = self.simple_client(
host="mongodb://localhost/?enableOverloadRetargeting=true", connect=False
)
self.assertTrue(c.options.enable_overload_retargeting)
class TestClient(AsyncIntegrationTest):
def test_multiple_uris(self):
@ -1034,7 +1066,7 @@ class TestClient(AsyncIntegrationTest):
db_names = await self.client.list_database_names()
self.assertIn("pymongo_test", db_names)
self.assertIn("pymongo_test_mike", db_names)
self.assertEqual(db_names, cmd_names)
self.assertCountEqual(db_names, cmd_names)
async def test_drop_database(self):
with self.assertRaises(TypeError):

View File

@ -0,0 +1,312 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Client Backpressure spec."""
from __future__ import annotations
import os
import pathlib
import sys
from time import perf_counter
from unittest.mock import patch
from pymongo.common import MAX_ADAPTIVE_RETRIES
sys.path[0:0] = [""]
from test.asynchronous import (
AsyncIntegrationTest,
async_client_context,
unittest,
)
from test.asynchronous.unified_format import generate_test_classes
from test.utils_shared import EventListener, OvertCommandListener
from pymongo.errors import OperationFailure, PyMongoError
_IS_SYNC = False
# Mock a system overload error.
mock_overload_error = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find", "insert", "update"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def get_mock_overload_error(times: int):
error = mock_overload_error.copy()
error["mode"] = {"times": times}
return error
class TestBackpressure(AsyncIntegrationTest):
RUN_ON_LOAD_BALANCER = True
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_command(self):
await self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
async with self.fail_point(fail_many):
await self.db.command("find", "t")
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await self.db.command("find", "t")
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_find(self):
await self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
async with self.fail_point(fail_many):
await self.db.t.find_one()
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await self.db.t.find_one()
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_insert_one(self):
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
async with self.fail_point(fail_many):
await self.db.t.insert_one({"x": 1})
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await self.db.t.insert_one({"x": 1})
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_update_many(self):
# Even though update_many is not a retryable write operation, it will
# still be retried via the "RetryableError" error label.
await self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
async with self.fail_point(fail_many):
await self.db.t.update_many({}, {"$set": {"x": 2}})
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await self.db.t.update_many({}, {"$set": {"x": 2}})
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_getMore(self):
coll = self.db.t
await coll.insert_many([{"x": 1} for _ in range(10)])
# Ensure command is retried on overload error.
fail_many = {
"configureFailPoint": "failCommand",
"mode": {"times": MAX_ADAPTIVE_RETRIES},
"data": {
"failCommands": ["getMore"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
cursor = coll.find(batch_size=2)
await cursor.next()
async with self.fail_point(fail_many):
await cursor.to_list()
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = fail_many.copy()
fail_too_many["mode"] = {"times": MAX_ADAPTIVE_RETRIES + 1}
cursor = coll.find(batch_size=2)
await cursor.next()
async with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
await cursor.to_list()
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# Prose tests.
class AsyncTestClientBackpressure(AsyncIntegrationTest):
listener: EventListener
@classmethod
def setUpClass(cls) -> None:
cls.listener = OvertCommandListener()
@async_client_context.require_connection
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.listener.reset()
self.app_name = self.__class__.__name__.lower()
self.client = await self.async_rs_or_single_client(
event_listeners=[self.listener], appName=self.app_name
)
@patch("random.random")
@async_client_context.require_failCommand_appName
async def test_01_operation_retry_uses_exponential_backoff(self, random_func):
# Drivers should test that retries do not occur immediately when a SystemOverloadedError is encountered.
# 1. let `client` be a `MongoClient`
client = self.client
# 2. let `collection` be a collection
collection = client.test.test
# 3. Now, run transactions without backoff:
# a. Configure the random number generator used for jitter to always return `0` -- this effectively disables backoff.
random_func.return_value = 0
# b. Configure the following failPoint:
fail_point = dict(
mode="alwaysOn",
data=dict(
failCommands=["insert"],
errorCode=2,
errorLabels=["SystemOverloadedError", "RetryableError"],
appName=self.app_name,
),
)
async with self.fail_point(fail_point):
# c. Execute the following command. Expect that the command errors. Measure the duration of the command execution.
start0 = perf_counter()
with self.assertRaises(OperationFailure):
await collection.insert_one({"a": 1})
end0 = perf_counter()
# d. Configure the random number generator used for jitter to always return `1`.
random_func.return_value = 1
# e. Execute step c again.
start1 = perf_counter()
with self.assertRaises(OperationFailure):
await collection.insert_one({"a": 1})
end1 = perf_counter()
# f. Compare the times between the two runs.
# The sum of 2 backoffs is 0.3 seconds. There is a 0.3-second window to account for potential variance between the two
# runs.
self.assertTrue(abs((end1 - start1) - (end0 - start0 + 0.3)) < 0.3)
@async_client_context.require_failCommand_appName
async def test_03_overload_retries_limited(self):
# Drivers should test that overload errors are retried a maximum of two times.
# 1. Let `client` be a `MongoClient`.
client = self.client
# 2. Let `coll` be a collection.
coll = client.pymongo_test.coll
# 3. Configure the following failpoint:
failpoint = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
# 4. Perform a find operation with `coll` that fails.
async with self.fail_point(failpoint):
with self.assertRaises(PyMongoError) as error:
await coll.find_one({})
# 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels.
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# 6. Assert that the total number of started commands is MAX_ADAPTIVE_RETRIES + 1.
self.assertEqual(len(self.listener.started_events), MAX_ADAPTIVE_RETRIES + 1)
@async_client_context.require_failCommand_appName
async def test_04_overload_retries_limited_configured(self):
# Drivers should test that overload errors are retried a maximum of maxAdaptiveRetries times.
max_retries = 1
# 1. Let `client` be a `MongoClient` with `maxAdaptiveRetries=1` and command event monitoring enabled.
client = await self.async_single_client(
maxAdaptiveRetries=max_retries, event_listeners=[self.listener]
)
# 2. Let `coll` be a collection.
coll = client.pymongo_test.coll
# 3. Configure the following failpoint:
failpoint = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
# 4. Perform a find operation with `coll` that fails.
async with self.fail_point(failpoint):
with self.assertRaises(PyMongoError) as error:
await coll.find_one({})
# 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels.
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# 6. Assert that the total number of started commands is max_retries + 1.
self.assertEqual(len(self.listener.started_events), max_retries + 1)
# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "client-backpressure")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "client-backpressure")
globals().update(
generate_test_classes(
_TEST_PATH,
module=__name__,
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -219,6 +219,19 @@ class TestClientMetadataProse(AsyncIntegrationTest):
# add same metadata again
await self.check_metadata_added(client, "Framework", None, None)
async def test_handshake_documents_include_backpressure(self):
# Create a `MongoClient` that is configured to record all handshake documents sent to the server as a part of
# connection establishment.
client = await self.async_rs_or_single_client("mongodb://" + self.server.address_string)
# Send a `ping` command to the server and verify that the command succeeds. This ensure that a connection is
# established on all topologies. Note: MockupDB only supports standalone servers.
await client.admin.command("ping")
# Assert that for every handshake document intercepted:
# the document has a field `backpressure` whose value is `true`.
self.assertEqual(self.handshake_req["backpressure"], True)
if __name__ == "__main__":
unittest.main()

View File

@ -25,8 +25,10 @@ from asyncio import StreamReader, StreamWriter
from pathlib import Path
from test.asynchronous.helpers import ConcurrentRunner
from test.asynchronous.utils import flaky
from test.utils_shared import delay
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.errors import ConnectionFailure
from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector
@ -70,7 +72,12 @@ from pymongo.errors import (
)
from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _check_command_response, _check_write_command_response
from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent
from pymongo.monitoring import (
ConnectionCheckOutFailedEvent,
PoolClearedEvent,
ServerHeartbeatFailedEvent,
ServerHeartbeatStartedEvent,
)
from pymongo.server_description import SERVER_TYPE, ServerDescription
from pymongo.topology_description import TOPOLOGY_TYPE
@ -131,6 +138,9 @@ async def got_app_error(topology, app_error):
raise AssertionError
except (AutoReconnect, NotPrimaryError, OperationFailure) as e:
if when == "beforeHandshakeCompletes":
# The pool would have added the SystemOverloadedError in this case.
if isinstance(e, AutoReconnect):
e._add_error_label("SystemOverloadedError")
completed_handshake = False
elif when == "afterHandshakeCompletes":
completed_handshake = True
@ -439,6 +449,59 @@ class TestPoolManagement(AsyncIntegrationTest):
AsyncConnection.close_conn = original_close
class TestPoolBackpressure(AsyncIntegrationTest):
@async_client_context.require_version_min(7, 0, 0)
async def test_connection_pool_is_not_cleared(self):
listener = CMAPListener()
# Create a client that listens to CMAP events, with maxConnecting=100.
client = await self.async_rs_or_single_client(maxConnecting=100, event_listeners=[listener])
# Enable the ingress rate limiter.
await client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=True
)
await client.admin.command("setParameter", 1, ingressConnectionEstablishmentRatePerSec=20)
await client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentBurstCapacitySecs=1
)
await client.admin.command("setParameter", 1, ingressConnectionEstablishmentMaxQueueDepth=1)
# Disable the ingress rate limiter on teardown.
# Sleep for 1 second before disabling to avoid the rate limiter.
async def teardown():
await asyncio.sleep(1)
await client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=False
)
self.addAsyncCleanup(teardown)
# Make sure the collection has at least one document.
await client.test.test.delete_many({})
await client.test.test.insert_one({})
# Run a slow operation to tie up the connection.
async def target():
try:
await client.test.test.find_one({"$where": delay(0.1)})
except ConnectionFailure:
pass
# Run 100 parallel operations that contend for connections.
tasks = []
for _ in range(100):
tasks.append(ConcurrentRunner(target=target))
for t in tasks:
await t.start()
for t in tasks:
await t.join()
# Verify there were at least 10 connection checkout failed event but no pool cleared events.
self.assertGreater(len(listener.events_by_type(ConnectionCheckOutFailedEvent)), 10)
self.assertEqual(len(listener.events_by_type(PoolClearedEvent)), 0)
class TestServerMonitoringMode(AsyncIntegrationTest):
@async_client_context.require_no_load_balancer
async def asyncSetUp(self):

View File

@ -513,6 +513,39 @@ class TestPooling(_TestPoolingBase):
str(error.exception),
)
@async_client_context.require_failCommand_appName
async def test_pool_backpressure_preserves_existing_connections(self):
client = await self.async_rs_or_single_client()
coll = client.pymongo_test.t
pool = await async_get_pool(client)
await coll.insert_many([{"x": 1} for _ in range(10)])
t = SocketGetter(self.c, pool)
await t.start()
while t.state != "connection":
await asyncio.sleep(0.1)
assert not t.sock.conn_closed()
# Mock a session establishment overload.
mock_connection_fail = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"closeConnection": True,
},
}
async with self.fail_point(mock_connection_fail):
await coll.find_one({})
# Make sure the existing socket was not affected.
assert not t.sock.conn_closed()
# Cleanup
await t.release_conn()
await t.join()
await pool.close()
class TestPoolMaxSize(_TestPoolingBase):
async def test_max_pool_size(self):

View File

@ -265,14 +265,17 @@ class TestRetryableReads(AsyncIntegrationTest):
@async_client_context.require_secondaries_count(1)
@async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, 0)
async def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available(
async def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available_and_overload_retargeting_is_enabled(
self
):
listener = OvertCommandListener()
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled.
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, `enableOverloadRetargeting=True`, and command event monitoring enabled.
client = await self.async_rs_or_single_client(
event_listeners=[listener], retryReads=True, readPreference="primaryPreferred"
event_listeners=[listener],
retryReads=True,
readPreference="primaryPreferred",
enableOverloadRetargeting=True,
)
# 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels.
@ -339,6 +342,47 @@ class TestRetryableReads(AsyncIntegrationTest):
# 6. Assert that both events occurred the same server.
assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id
@async_client_context.require_replica_set
@async_client_context.require_secondaries_count(1)
@async_client_context.require_failCommand_fail_point
@async_client_context.require_version_min(4, 4, 0)
async def test_03_03_retryable_reads_caused_by_overload_errors_are_retried_on_the_same_replicaset_server_when_one_is_available_and_overload_retargeting_is_disabled(
self
):
listener = OvertCommandListener()
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled.
client = await self.async_rs_or_single_client(
event_listeners=[listener],
retryReads=True,
readPreference="primaryPreferred",
)
# 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 6,
},
}
await async_set_fail_point(client, command_args)
# 3. Reset the command event monitor to clear the fail point command from its stored events.
listener.reset()
# 4. Execute a `find` command with `client`.
await client.t.t.find_one({})
# 5. Assert that one failed command event and one successful command event occurred.
self.assertEqual(len(listener.failed_events), 1)
self.assertEqual(len(listener.succeeded_events), 1)
# 6. Assert that both events occurred on the same server.
assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id
if __name__ == "__main__":
unittest.main()

View File

@ -43,14 +43,17 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.int64 import Int64
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from pymongo import MongoClient
from pymongo.errors import (
AutoReconnect,
ConnectionFailure,
OperationFailure,
NotPrimaryError,
PyMongoError,
ServerSelectionTimeoutError,
WriteConcernError,
)
from pymongo.monitoring import (
CommandFailedEvent,
CommandSucceededEvent,
ConnectionCheckedOutEvent,
ConnectionCheckOutFailedEvent,
@ -601,5 +604,186 @@ class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest):
self.assertEqual(sent_txn_id, final_txn_id, msg)
class TestErrorPropagationAfterEncounteringMultipleErrors(AsyncIntegrationTest):
# Only run against replica sets as mongos does not propagate the NoWritesPerformed label to the drivers.
@async_client_context.require_replica_set
# Run against server versions 6.0 and above.
@async_client_context.require_version_min(6, 0) # type: ignore[untyped-decorator]
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.setup_client = MongoClient(**async_client_context.default_client_options)
self.addCleanup(self.setup_client.close)
# TODO: After PYTHON-4595 we can use async event handlers and remove this workaround.
def configure_fail_point_sync(self, command_args, off=False) -> None:
cmd = {"configureFailPoint": "failCommand"}
cmd.update(command_args)
if off:
cmd["mode"] = "off"
cmd.pop("data", None)
self.setup_client.admin.command(cmd)
async def test_01_drivers_return_the_correct_error_when_receiving_only_errors_without_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Via the command monitoring CommandFailedEvent, configure a fail point with error code 10107 (NotWritablePrimary).
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 10107,
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = await self.async_rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(NotPrimaryError) as exc:
await client.test.test.insert_one({})
# Assert that the error code of the server error is 10107.
assert exc.exception.errors["code"] == 10107 # type:ignore[call-overload]
async def test_02_drivers_return_the_correct_error_when_receiving_only_errors_with_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
"errorCode": 91,
},
}
# Via the command monitoring CommandFailedEvent, configure a fail point with error code `10107` (NotWritablePrimary)
# and a NoWritesPerformed label.
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 10107,
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
},
}
def failed(event: CommandFailedEvent) -> None:
if listener.failed_events:
return
# Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2.
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = await self.async_rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(NotPrimaryError) as exc:
await client.test.test.insert_one({})
# Assert that the error code of the server error is 91.
assert exc.exception.errors["code"] == 91 # type:ignore[call-overload]
async def test_03_drivers_return_the_correct_error_when_receiving_some_errors_with_NoWritesPerformed_and_some_without_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (NotWritablePrimary) and the `NoWritesPerformed`, `RetryableError` and `SystemOverloadedError` labels.
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with the `RetryableError` and
# `SystemOverloadedError` error labels but without the `NoWritesPerformed` error label.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorCode": 91,
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = await self.async_rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(PyMongoError) as exc:
await client.test.test.insert_one({})
# Assert that the error code of the server error is 91.
assert exc.exception.errors["code"] == 91
# Assert that the error does not contain the error label `NoWritesPerformed`.
assert "NoWritesPerformed" not in exc.exception.errors["errorLabels"]
if __name__ == "__main__":
unittest.main()

View File

@ -16,9 +16,13 @@
from __future__ import annotations
import asyncio
import random
import sys
import time
from io import BytesIO
from unittest.mock import patch
import pymongo
from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket
from pymongo.asynchronous.pool import PoolState
from pymongo.server_selectors import writable_server_selector
@ -45,7 +49,9 @@ from pymongo.errors import (
CollectionInvalid,
ConfigurationError,
ConnectionFailure,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
OperationFailure,
)
from pymongo.operations import IndexModel, InsertOne
@ -434,7 +440,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
await self.configure_fail_point(client, command_args)
@async_client_context.require_transactions
async def test_callback_raises_custom_error(self):
async def test_1_callback_raises_custom_error(self):
class _MyException(Exception):
pass
@ -446,7 +452,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
await s.with_transaction(raise_error)
@async_client_context.require_transactions
async def test_callback_returns_value(self):
async def test_2_callback_returns_value(self):
async def callback(_):
return "Foo"
@ -474,7 +480,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
self.assertEqual(await s.with_transaction(callback), "Foo")
@async_client_context.require_transactions
async def test_callback_not_retried_after_timeout(self):
async def test_3_1_callback_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
@ -495,14 +501,16 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
listener.reset()
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@async_client_context.require_test_commands
@async_client_context.require_transactions
async def test_callback_not_retried_after_commit_timeout(self):
async def test_3_2_callback_not_retried_after_commit_timeout(self):
listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
@ -529,14 +537,16 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@async_client_context.require_test_commands
@async_client_context.require_transactions
async def test_commit_not_retried_after_timeout(self):
async def test_3_3_commit_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
@ -560,7 +570,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(ConnectionFailure):
with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback)
# One insert for the callback and two commits (includes the automatic
@ -568,6 +578,40 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
self.assertEqual(
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
)
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult"))
@async_client_context.require_transactions
async def test_callback_not_retried_after_csot_timeout(self):
listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
async def callback(session):
await coll.insert_one({}, session=session)
err: dict = {
"ok": 0,
"errmsg": "Transaction 7819 has been aborted.",
"code": 251,
"codeName": "NoSuchTransaction",
"errorLabels": ["TransientTransactionError"],
}
raise OperationFailure(err["errmsg"], err["code"], err)
# Create the collection.
await coll.insert_one({})
listener.reset()
async with client.start_session() as s:
with pymongo.timeout(1.0):
with self.assertRaises(ExecutionTimeout):
await s.with_transaction(callback)
# At least two attempts: the original and one or more retries.
inserts = len([x for x in listener.started_command_names() if x == "insert"])
aborts = len([x for x in listener.started_command_names() if x == "abortTransaction"])
self.assertGreaterEqual(inserts, 2)
self.assertGreaterEqual(aborts, 2)
# Tested here because this supports Motor's convenient transactions API.
@async_client_context.require_transactions
@ -606,6 +650,63 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
await s.with_transaction(callback)
self.assertFalse(s.in_transaction)
@async_client_context.require_test_commands
@async_client_context.require_transactions
async def test_4_retry_backoff_is_enforced(self):
client = async_client_context.client
coll = client[self.db.name].test
end = start = no_backoff_time = 0
# Make random.random always return 0 (no backoff)
with patch.object(random, "random", return_value=0):
# set fail point to trigger transaction failure and trigger backoff
await self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {"times": 13},
"data": {
"failCommands": ["commitTransaction"],
"errorCode": 251,
},
}
)
self.addAsyncCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
async def callback(session):
await coll.insert_one({}, session=session)
start = time.monotonic()
async with self.client.start_session() as s:
await s.with_transaction(callback)
end = time.monotonic()
no_backoff_time = end - start
# Make random.random always return 1 (max backoff)
with patch.object(random, "random", return_value=1):
# set fail point to trigger transaction failure and trigger backoff
await self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {
"times": 13
}, # sufficiently high enough such that the time effect of backoff is noticeable
"data": {
"failCommands": ["commitTransaction"],
"errorCode": 251,
},
}
)
self.addAsyncCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
start = time.monotonic()
async with self.client.start_session() as s:
await s.with_transaction(callback)
end = time.monotonic()
self.assertLess(abs(end - start - (no_backoff_time + 2.2)), 1) # sum of 13 backoffs is 2.2
class TestOptionsInsideTransactionProse(AsyncTransactionsBase):
@async_client_context.require_transactions

View File

@ -0,0 +1,111 @@
{
"description": "tests that connections are returned to the pool on retry attempts for overload errors",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"topologies": [
"replicaset",
"sharded",
"load-balanced"
]
}
],
"createEntities": [
{
"client": {
"id": "client",
"useMultipleMongoses": false,
"observeEvents": [
"connectionCheckedOutEvent",
"connectionCheckedInEvent"
]
}
},
{
"client": {
"id": "fail_point_client",
"useMultipleMongoses": false
}
},
{
"database": {
"id": "database",
"client": "client",
"databaseName": "backpressure-connection-checkin"
}
},
{
"collection": {
"id": "collection",
"database": "database",
"collectionName": "coll"
}
}
],
"tests": [
{
"description": "overload error retry attempts return connections to the pool",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "fail_point_client",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"find"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 2
}
}
}
},
{
"name": "find",
"object": "collection",
"arguments": {
"filter": {}
},
"expectError": {
"isError": true,
"isClientError": false
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "cmap",
"events": [
{
"connectionCheckedOutEvent": {}
},
{
"connectionCheckedInEvent": {}
},
{
"connectionCheckedOutEvent": {}
},
{
"connectionCheckedInEvent": {}
},
{
"connectionCheckedOutEvent": {}
},
{
"connectionCheckedInEvent": {}
}
]
}
]
}
]
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,253 @@
{
"description": "getMore-retried-backpressure",
"schemaVersion": "1.3",
"runOnRequirements": [
{
"minServerVersion": "4.4"
}
],
"createEntities": [
{
"client": {
"id": "client0",
"useMultipleMongoses": false,
"observeEvents": [
"commandStartedEvent",
"commandFailedEvent",
"commandSucceededEvent"
]
}
},
{
"client": {
"id": "failPointClient",
"useMultipleMongoses": false
}
},
{
"database": {
"id": "db",
"client": "client0",
"databaseName": "default"
}
},
{
"collection": {
"id": "coll",
"database": "db",
"collectionName": "default"
}
}
],
"initialData": [
{
"databaseName": "default",
"collectionName": "default",
"documents": [
{
"a": 1
},
{
"a": 2
},
{
"a": 3
}
]
}
],
"tests": [
{
"description": "getMores are retried",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "failPointClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"getMore"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 2
}
}
}
},
{
"name": "find",
"object": "coll",
"arguments": {
"batchSize": 2,
"filter": {},
"sort": {
"a": 1
}
},
"expectResult": [
{
"a": 1
},
{
"a": 2
},
{
"a": 3
}
]
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"commandName": "find"
}
},
{
"commandSucceededEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandSucceededEvent": {
"commandName": "getMore"
}
}
]
}
]
},
{
"description": "getMores are retried maxAttempts=2 times",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "failPointClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"getMore"
],
"errorLabels": [
"RetryableError",
"SystemOverloadedError"
],
"errorCode": 2
}
}
}
},
{
"name": "find",
"arguments": {
"batchSize": 2,
"filter": {}
},
"object": "coll",
"expectError": {
"isError": true,
"isClientError": false
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"commandName": "find"
}
},
{
"commandSucceededEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "getMore"
}
},
{
"commandFailedEvent": {
"commandName": "getMore"
}
},
{
"commandStartedEvent": {
"commandName": "killCursors"
}
},
{
"commandSucceededEvent": {
"commandName": "killCursors"
}
}
]
}
]
}
]
}

View File

@ -9,9 +9,7 @@
],
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 50
},
"mode": "alwaysOn",
"data": {
"failCommands": [
"isMaster",

View File

@ -97,14 +97,22 @@
"outcome": {
"servers": {
"a:27017": {
"type": "Unknown",
"topologyVersion": null,
"type": "RSPrimary",
"setName": "rs",
"topologyVersion": {
"processId": {
"$oid": "000000000000000000000001"
},
"counter": {
"$numberLong": "1"
}
},
"pool": {
"generation": 1
"generation": 0
}
}
},
"topologyType": "ReplicaSetNoPrimary",
"topologyType": "ReplicaSetWithPrimary",
"logicalSessionTimeoutMinutes": null,
"setName": "rs"
}

View File

@ -0,0 +1,142 @@
{
"description": "backpressure-network-error-fail-replicaset",
"schemaVersion": "1.17",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"serverless": "forbid",
"topologies": [
"replicaset"
]
}
],
"createEntities": [
{
"client": {
"id": "setupClient",
"useMultipleMongoses": false
}
}
],
"initialData": [
{
"collectionName": "backpressure-network-error-fail",
"databaseName": "sdam-tests",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
],
"tests": [
{
"description": "apply backpressure on network connection errors during connection establishment",
"operations": [
{
"name": "createEntities",
"object": "testRunner",
"arguments": {
"entities": [
{
"client": {
"id": "client",
"useMultipleMongoses": false,
"observeEvents": [
"serverDescriptionChangedEvent",
"poolClearedEvent"
],
"uriOptions": {
"retryWrites": false,
"heartbeatFrequencyMS": 1000000,
"serverMonitoringMode": "poll",
"appname": "backpressureNetworkErrorFailTest"
}
}
},
{
"database": {
"id": "database",
"client": "client",
"databaseName": "sdam-tests"
}
},
{
"collection": {
"id": "collection",
"database": "database",
"collectionName": "backpressure-network-error-fail"
}
}
]
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"client": "client",
"event": {
"serverDescriptionChangedEvent": {
"newDescription": {
"type": "RSPrimary"
}
}
},
"count": 1
}
},
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "setupClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"isMaster",
"hello"
],
"appName": "backpressureNetworkErrorFailTest",
"closeConnection": true
}
}
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"expectError": {
"isError": true,
"errorLabelsContain": [
"SystemOverloadedError",
"RetryableError"
]
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "cmap",
"events": []
}
]
}
]
}

View File

@ -0,0 +1,142 @@
{
"description": "backpressure-network-error-fail-single",
"schemaVersion": "1.17",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"serverless": "forbid",
"topologies": [
"single"
]
}
],
"createEntities": [
{
"client": {
"id": "setupClient",
"useMultipleMongoses": false
}
}
],
"initialData": [
{
"collectionName": "backpressure-network-error-fail",
"databaseName": "sdam-tests",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
],
"tests": [
{
"description": "apply backpressure on network connection errors during connection establishment",
"operations": [
{
"name": "createEntities",
"object": "testRunner",
"arguments": {
"entities": [
{
"client": {
"id": "client",
"useMultipleMongoses": false,
"observeEvents": [
"serverDescriptionChangedEvent",
"poolClearedEvent"
],
"uriOptions": {
"retryWrites": false,
"heartbeatFrequencyMS": 1000000,
"serverMonitoringMode": "poll",
"appname": "backpressureNetworkErrorFailTest"
}
}
},
{
"database": {
"id": "database",
"client": "client",
"databaseName": "sdam-tests"
}
},
{
"collection": {
"id": "collection",
"database": "database",
"collectionName": "backpressure-network-error-fail"
}
}
]
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"client": "client",
"event": {
"serverDescriptionChangedEvent": {
"newDescription": {
"type": "Standalone"
}
}
},
"count": 1
}
},
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "setupClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"isMaster",
"hello"
],
"appName": "backpressureNetworkErrorFailTest",
"closeConnection": true
}
}
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"expectError": {
"isError": true,
"errorLabelsContain": [
"SystemOverloadedError",
"RetryableError"
]
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "cmap",
"events": []
}
]
}
]
}

View File

@ -0,0 +1,145 @@
{
"description": "backpressure-network-timeout-error-replicaset",
"schemaVersion": "1.17",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"serverless": "forbid",
"topologies": [
"replicaset"
]
}
],
"createEntities": [
{
"client": {
"id": "setupClient",
"useMultipleMongoses": false
}
}
],
"initialData": [
{
"collectionName": "backpressure-network-timeout-error",
"databaseName": "sdam-tests",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
],
"tests": [
{
"description": "apply backpressure on network timeout error during connection establishment",
"operations": [
{
"name": "createEntities",
"object": "testRunner",
"arguments": {
"entities": [
{
"client": {
"id": "client",
"useMultipleMongoses": false,
"observeEvents": [
"serverDescriptionChangedEvent",
"poolClearedEvent"
],
"uriOptions": {
"retryWrites": false,
"heartbeatFrequencyMS": 1000000,
"appname": "backpressureNetworkTimeoutErrorTest",
"serverMonitoringMode": "poll",
"connectTimeoutMS": 250,
"socketTimeoutMS": 250
}
}
},
{
"database": {
"id": "database",
"client": "client",
"databaseName": "sdam-tests"
}
},
{
"collection": {
"id": "collection",
"database": "database",
"collectionName": "backpressure-network-timeout-error"
}
}
]
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"client": "client",
"event": {
"serverDescriptionChangedEvent": {
"newDescription": {
"type": "RSPrimary"
}
}
},
"count": 1
}
},
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "setupClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"isMaster",
"hello"
],
"blockConnection": true,
"blockTimeMS": 500,
"appName": "backpressureNetworkTimeoutErrorTest"
}
}
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"expectError": {
"isError": true,
"errorLabelsContain": [
"SystemOverloadedError",
"RetryableError"
]
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "cmap",
"events": []
}
]
}
]
}

View File

@ -0,0 +1,145 @@
{
"description": "backpressure-network-timeout-error-single",
"schemaVersion": "1.17",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"serverless": "forbid",
"topologies": [
"single"
]
}
],
"createEntities": [
{
"client": {
"id": "setupClient",
"useMultipleMongoses": false
}
}
],
"initialData": [
{
"collectionName": "backpressure-network-timeout-error",
"databaseName": "sdam-tests",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
],
"tests": [
{
"description": "apply backpressure on network timeout error during connection establishment",
"operations": [
{
"name": "createEntities",
"object": "testRunner",
"arguments": {
"entities": [
{
"client": {
"id": "client",
"useMultipleMongoses": false,
"observeEvents": [
"serverDescriptionChangedEvent",
"poolClearedEvent"
],
"uriOptions": {
"retryWrites": false,
"heartbeatFrequencyMS": 1000000,
"appname": "backpressureNetworkTimeoutErrorTest",
"serverMonitoringMode": "poll",
"connectTimeoutMS": 250,
"socketTimeoutMS": 250
}
}
},
{
"database": {
"id": "database",
"client": "client",
"databaseName": "sdam-tests"
}
},
{
"collection": {
"id": "collection",
"database": "database",
"collectionName": "backpressure-network-timeout-error"
}
}
]
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"client": "client",
"event": {
"serverDescriptionChangedEvent": {
"newDescription": {
"type": "Standalone"
}
}
},
"count": 1
}
},
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "setupClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": [
"isMaster",
"hello"
],
"blockConnection": true,
"blockTimeMS": 500,
"appName": "backpressureNetworkTimeoutErrorTest"
}
}
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"expectError": {
"isError": true,
"errorLabelsContain": [
"SystemOverloadedError",
"RetryableError"
]
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "cmap",
"events": []
}
]
}
]
}

View File

@ -0,0 +1,106 @@
{
"description": "backpressure-server-description-unchanged-on-min-pool-size-population-error",
"schemaVersion": "1.17",
"runOnRequirements": [
{
"minServerVersion": "4.4",
"serverless": "forbid",
"topologies": [
"single"
]
}
],
"createEntities": [
{
"client": {
"id": "setupClient",
"useMultipleMongoses": false
}
}
],
"tests": [
{
"description": "the server description is not changed on handshake error during minPoolSize population",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "setupClient",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"skip": 1
},
"data": {
"failCommands": [
"hello",
"isMaster"
],
"appName": "authErrorTest",
"closeConnection": true
}
}
}
},
{
"name": "createEntities",
"object": "testRunner",
"arguments": {
"entities": [
{
"client": {
"id": "client",
"observeEvents": [
"serverDescriptionChangedEvent",
"connectionClosedEvent"
],
"uriOptions": {
"appname": "authErrorTest",
"minPoolSize": 5,
"maxConnecting": 1,
"serverMonitoringMode": "poll",
"heartbeatFrequencyMS": 1000000
}
}
}
]
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"client": "client",
"event": {
"serverDescriptionChangedEvent": {}
},
"count": 1
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"client": "client",
"event": {
"connectionClosedEvent": {}
},
"count": 1
}
}
],
"expectEvents": [
{
"client": "client",
"eventType": "sdam",
"events": [
{
"serverDescriptionChangedEvent": {}
}
]
}
]
}
]
}

155
test/test_azure_helpers.py Normal file
View File

@ -0,0 +1,155 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for _azure_helpers.py.
These tests mock urlopen to avoid requiring a live Azure IMDS endpoint.
Integration tests that exercise the real endpoint are gated by environment
variables in test_on_demand_csfle.py and test_auth_oidc.py.
"""
from __future__ import annotations
import json
import sys
import unittest
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
sys.path[0:0] = [""]
from pymongo._azure_helpers import _get_azure_response
@contextmanager
def _mock_urlopen(status: int, body: str):
"""Context manager that patches ``urllib.request.urlopen`` with a fake response."""
mock_response = MagicMock()
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
mock_response.status = status
mock_response.read.return_value = body.encode("utf8")
with patch("urllib.request.urlopen", return_value=mock_response) as mock_open:
yield mock_open
class TestGetAzureResponse(unittest.TestCase):
def _call(self, resource="https://example.com/", client_id=None, timeout=5):
return _get_azure_response(resource, client_id=client_id, timeout=timeout)
def test_success_without_client_id(self):
body = json.dumps({"access_token": "tok", "expires_in": "3600"})
with _mock_urlopen(200, body) as mock_open:
result = self._call()
self.assertEqual(result["access_token"], "tok")
self.assertEqual(result["expires_in"], "3600")
# Verify client_id was NOT added to the URL
url = mock_open.call_args[0][0].full_url
self.assertNotIn("client_id", url)
def test_success_with_client_id(self):
body = json.dumps({"access_token": "tok", "expires_in": "3600"})
with _mock_urlopen(200, body) as mock_open:
result = self._call(client_id="my-client-id")
self.assertEqual(result["access_token"], "tok")
url = mock_open.call_args[0][0].full_url
self.assertIn("client_id=my-client-id", url)
def test_url_contains_resource_and_api_version(self):
body = json.dumps({"access_token": "tok", "expires_in": "3600"})
with _mock_urlopen(200, body) as mock_open:
self._call(resource="https://test-resource.example.com")
url = mock_open.call_args[0][0].full_url
self.assertIn("api-version=2018-02-01", url)
self.assertIn("resource=https://test-resource.example.com", url)
def test_request_headers(self):
body = json.dumps({"access_token": "tok", "expires_in": "3600"})
with _mock_urlopen(200, body) as mock_open:
self._call()
request = mock_open.call_args[0][0]
self.assertEqual(request.get_header("Metadata"), "true")
self.assertEqual(request.get_header("Accept"), "application/json")
def test_urlopen_exception_raises_value_error(self):
with patch("urllib.request.urlopen", side_effect=OSError("connection refused")):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("Failed to acquire IMDS access token", str(ctx.exception))
def test_non_200_status_raises_value_error(self):
body = json.dumps({"error": "something went wrong"})
with _mock_urlopen(400, body):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("Failed to acquire IMDS access token", str(ctx.exception))
def test_non_json_body_raises_value_error(self):
with _mock_urlopen(200, "not-json"):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("Azure IMDS response must be in JSON format", str(ctx.exception))
def test_missing_access_token_raises_value_error(self):
body = json.dumps({"expires_in": "3600"})
with _mock_urlopen(200, body):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("access_token", str(ctx.exception))
def test_missing_expires_in_raises_value_error(self):
body = json.dumps({"access_token": "tok"})
with _mock_urlopen(200, body):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("expires_in", str(ctx.exception))
def test_empty_access_token_raises_value_error(self):
body = json.dumps({"access_token": "", "expires_in": "3600"})
with _mock_urlopen(200, body):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("access_token", str(ctx.exception))
def test_empty_expires_in_raises_value_error(self):
body = json.dumps({"access_token": "tok", "expires_in": ""})
with _mock_urlopen(200, body):
with self.assertRaises(ValueError) as ctx:
self._call()
self.assertIn("expires_in", str(ctx.exception))
def test_timeout_passed_to_urlopen(self):
body = json.dumps({"access_token": "tok", "expires_in": "3600"})
with _mock_urlopen(200, body) as mock_open:
self._call(timeout=42)
_, kwargs = mock_open.call_args
self.assertEqual(kwargs["timeout"], 42)
if __name__ == "__main__":
unittest.main()

View File

@ -645,6 +645,38 @@ class ClientUnitTest(UnitTest):
with self.assertWarns(UserWarning):
self.simple_client(multi_host)
def test_max_adaptive_retries(self):
# Assert that max adaptive retries defaults to 2.
c = self.simple_client(connect=False)
self.assertEqual(c.options.max_adaptive_retries, 2)
# Assert that max adaptive retries can be configured through connection or client options.
c = self.simple_client(connect=False, max_adaptive_retries=10)
self.assertEqual(c.options.max_adaptive_retries, 10)
c = self.simple_client(connect=False, maxAdaptiveRetries=10)
self.assertEqual(c.options.max_adaptive_retries, 10)
c = self.simple_client(host="mongodb://localhost/?maxAdaptiveRetries=10", connect=False)
self.assertEqual(c.options.max_adaptive_retries, 10)
def test_enable_overload_retargeting(self):
# Assert that overload retargeting defaults to false.
c = self.simple_client(connect=False)
self.assertFalse(c.options.enable_overload_retargeting)
# Assert that overload retargeting can be enabled through connection or client options.
c = self.simple_client(connect=False, enable_overload_retargeting=True)
self.assertTrue(c.options.enable_overload_retargeting)
c = self.simple_client(connect=False, enableOverloadRetargeting=True)
self.assertTrue(c.options.enable_overload_retargeting)
c = self.simple_client(
host="mongodb://localhost/?enableOverloadRetargeting=true", connect=False
)
self.assertTrue(c.options.enable_overload_retargeting)
class TestClient(IntegrationTest):
def test_multiple_uris(self):
@ -1007,7 +1039,7 @@ class TestClient(IntegrationTest):
db_names = self.client.list_database_names()
self.assertIn("pymongo_test", db_names)
self.assertIn("pymongo_test_mike", db_names)
self.assertEqual(db_names, cmd_names)
self.assertCountEqual(db_names, cmd_names)
def test_drop_database(self):
with self.assertRaises(TypeError):

View File

@ -0,0 +1,310 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Client Backpressure spec."""
from __future__ import annotations
import os
import pathlib
import sys
from time import perf_counter
from unittest.mock import patch
from pymongo.common import MAX_ADAPTIVE_RETRIES
sys.path[0:0] = [""]
from test import (
IntegrationTest,
client_context,
unittest,
)
from test.unified_format import generate_test_classes
from test.utils_shared import EventListener, OvertCommandListener
from pymongo.errors import OperationFailure, PyMongoError
_IS_SYNC = True
# Mock a system overload error.
mock_overload_error = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find", "insert", "update"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def get_mock_overload_error(times: int):
error = mock_overload_error.copy()
error["mode"] = {"times": times}
return error
class TestBackpressure(IntegrationTest):
RUN_ON_LOAD_BALANCER = True
@client_context.require_failCommand_appName
def test_retry_overload_error_command(self):
self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
with self.fail_point(fail_many):
self.db.command("find", "t")
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
self.db.command("find", "t")
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@client_context.require_failCommand_appName
def test_retry_overload_error_find(self):
self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
with self.fail_point(fail_many):
self.db.t.find_one()
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
self.db.t.find_one()
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@client_context.require_failCommand_appName
def test_retry_overload_error_insert_one(self):
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
with self.fail_point(fail_many):
self.db.t.insert_one({"x": 1})
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
self.db.t.insert_one({"x": 1})
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@client_context.require_failCommand_appName
def test_retry_overload_error_update_many(self):
# Even though update_many is not a retryable write operation, it will
# still be retried via the "RetryableError" error label.
self.db.t.insert_one({"x": 1})
# Ensure command is retried on overload error.
fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES)
with self.fail_point(fail_many):
self.db.t.update_many({}, {"$set": {"x": 2}})
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1)
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
self.db.t.update_many({}, {"$set": {"x": 2}})
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
@client_context.require_failCommand_appName
def test_retry_overload_error_getMore(self):
coll = self.db.t
coll.insert_many([{"x": 1} for _ in range(10)])
# Ensure command is retried on overload error.
fail_many = {
"configureFailPoint": "failCommand",
"mode": {"times": MAX_ADAPTIVE_RETRIES},
"data": {
"failCommands": ["getMore"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
cursor = coll.find(batch_size=2)
cursor.next()
with self.fail_point(fail_many):
cursor.to_list()
# Ensure command stops retrying after MAX_ADAPTIVE_RETRIES.
fail_too_many = fail_many.copy()
fail_too_many["mode"] = {"times": MAX_ADAPTIVE_RETRIES + 1}
cursor = coll.find(batch_size=2)
cursor.next()
with self.fail_point(fail_too_many):
with self.assertRaises(PyMongoError) as error:
cursor.to_list()
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# Prose tests.
class TestClientBackpressure(IntegrationTest):
listener: EventListener
@classmethod
def setUpClass(cls) -> None:
cls.listener = OvertCommandListener()
@client_context.require_connection
def setUp(self) -> None:
super().setUp()
self.listener.reset()
self.app_name = self.__class__.__name__.lower()
self.client = self.rs_or_single_client(
event_listeners=[self.listener], appName=self.app_name
)
@patch("random.random")
@client_context.require_failCommand_appName
def test_01_operation_retry_uses_exponential_backoff(self, random_func):
# Drivers should test that retries do not occur immediately when a SystemOverloadedError is encountered.
# 1. let `client` be a `MongoClient`
client = self.client
# 2. let `collection` be a collection
collection = client.test.test
# 3. Now, run transactions without backoff:
# a. Configure the random number generator used for jitter to always return `0` -- this effectively disables backoff.
random_func.return_value = 0
# b. Configure the following failPoint:
fail_point = dict(
mode="alwaysOn",
data=dict(
failCommands=["insert"],
errorCode=2,
errorLabels=["SystemOverloadedError", "RetryableError"],
appName=self.app_name,
),
)
with self.fail_point(fail_point):
# c. Execute the following command. Expect that the command errors. Measure the duration of the command execution.
start0 = perf_counter()
with self.assertRaises(OperationFailure):
collection.insert_one({"a": 1})
end0 = perf_counter()
# d. Configure the random number generator used for jitter to always return `1`.
random_func.return_value = 1
# e. Execute step c again.
start1 = perf_counter()
with self.assertRaises(OperationFailure):
collection.insert_one({"a": 1})
end1 = perf_counter()
# f. Compare the times between the two runs.
# The sum of 2 backoffs is 0.3 seconds. There is a 0.3-second window to account for potential variance between the two
# runs.
self.assertTrue(abs((end1 - start1) - (end0 - start0 + 0.3)) < 0.3)
@client_context.require_failCommand_appName
def test_03_overload_retries_limited(self):
# Drivers should test that overload errors are retried a maximum of two times.
# 1. Let `client` be a `MongoClient`.
client = self.client
# 2. Let `coll` be a collection.
coll = client.pymongo_test.coll
# 3. Configure the following failpoint:
failpoint = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
# 4. Perform a find operation with `coll` that fails.
with self.fail_point(failpoint):
with self.assertRaises(PyMongoError) as error:
coll.find_one({})
# 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels.
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# 6. Assert that the total number of started commands is MAX_ADAPTIVE_RETRIES + 1.
self.assertEqual(len(self.listener.started_events), MAX_ADAPTIVE_RETRIES + 1)
@client_context.require_failCommand_appName
def test_04_overload_retries_limited_configured(self):
# Drivers should test that overload errors are retried a maximum of maxAdaptiveRetries times.
max_retries = 1
# 1. Let `client` be a `MongoClient` with `maxAdaptiveRetries=1` and command event monitoring enabled.
client = self.single_client(maxAdaptiveRetries=max_retries, event_listeners=[self.listener])
# 2. Let `coll` be a collection.
coll = client.pymongo_test.coll
# 3. Configure the following failpoint:
failpoint = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["find"],
"errorCode": 462, # IngressRequestRateLimitExceeded
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
# 4. Perform a find operation with `coll` that fails.
with self.fail_point(failpoint):
with self.assertRaises(PyMongoError) as error:
coll.find_one({})
# 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels.
self.assertIn("RetryableError", str(error.exception))
self.assertIn("SystemOverloadedError", str(error.exception))
# 6. Assert that the total number of started commands is max_retries + 1.
self.assertEqual(len(self.listener.started_events), max_retries + 1)
# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "client-backpressure")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "client-backpressure")
globals().update(
generate_test_classes(
_TEST_PATH,
module=__name__,
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -219,6 +219,19 @@ class TestClientMetadataProse(IntegrationTest):
# add same metadata again
self.check_metadata_added(client, "Framework", None, None)
def test_handshake_documents_include_backpressure(self):
# Create a `MongoClient` that is configured to record all handshake documents sent to the server as a part of
# connection establishment.
client = self.rs_or_single_client("mongodb://" + self.server.address_string)
# Send a `ping` command to the server and verify that the command succeeds. This ensure that a connection is
# established on all topologies. Note: MockupDB only supports standalone servers.
client.admin.command("ping")
# Assert that for every handshake document intercepted:
# the document has a field `backpressure` whose value is `true`.
self.assertEqual(self.handshake_req["backpressure"], True)
if __name__ == "__main__":
unittest.main()

View File

@ -25,7 +25,9 @@ from asyncio import StreamReader, StreamWriter
from pathlib import Path
from test.helpers import ConcurrentRunner
from test.utils import flaky
from test.utils_shared import delay
from pymongo.errors import ConnectionFailure
from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector
from pymongo.synchronous.pool import Connection
@ -67,7 +69,12 @@ from pymongo.errors import (
)
from pymongo.hello import Hello, HelloCompat
from pymongo.helpers_shared import _check_command_response, _check_write_command_response
from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent
from pymongo.monitoring import (
ConnectionCheckOutFailedEvent,
PoolClearedEvent,
ServerHeartbeatFailedEvent,
ServerHeartbeatStartedEvent,
)
from pymongo.server_description import SERVER_TYPE, ServerDescription
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext
@ -131,6 +138,9 @@ def got_app_error(topology, app_error):
raise AssertionError
except (AutoReconnect, NotPrimaryError, OperationFailure) as e:
if when == "beforeHandshakeCompletes":
# The pool would have added the SystemOverloadedError in this case.
if isinstance(e, AutoReconnect):
e._add_error_label("SystemOverloadedError")
completed_handshake = False
elif when == "afterHandshakeCompletes":
completed_handshake = True
@ -437,6 +447,57 @@ class TestPoolManagement(IntegrationTest):
Connection.close_conn = original_close
class TestPoolBackpressure(IntegrationTest):
@client_context.require_version_min(7, 0, 0)
def test_connection_pool_is_not_cleared(self):
listener = CMAPListener()
# Create a client that listens to CMAP events, with maxConnecting=100.
client = self.rs_or_single_client(maxConnecting=100, event_listeners=[listener])
# Enable the ingress rate limiter.
client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=True
)
client.admin.command("setParameter", 1, ingressConnectionEstablishmentRatePerSec=20)
client.admin.command("setParameter", 1, ingressConnectionEstablishmentBurstCapacitySecs=1)
client.admin.command("setParameter", 1, ingressConnectionEstablishmentMaxQueueDepth=1)
# Disable the ingress rate limiter on teardown.
# Sleep for 1 second before disabling to avoid the rate limiter.
def teardown():
time.sleep(1)
client.admin.command(
"setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=False
)
self.addCleanup(teardown)
# Make sure the collection has at least one document.
client.test.test.delete_many({})
client.test.test.insert_one({})
# Run a slow operation to tie up the connection.
def target():
try:
client.test.test.find_one({"$where": delay(0.1)})
except ConnectionFailure:
pass
# Run 100 parallel operations that contend for connections.
tasks = []
for _ in range(100):
tasks.append(ConcurrentRunner(target=target))
for t in tasks:
t.start()
for t in tasks:
t.join()
# Verify there were at least 10 connection checkout failed event but no pool cleared events.
self.assertGreater(len(listener.events_by_type(ConnectionCheckOutFailedEvent)), 10)
self.assertEqual(len(listener.events_by_type(PoolClearedEvent)), 0)
class TestServerMonitoringMode(IntegrationTest):
@client_context.require_no_load_balancer
def setUp(self):

116
test/test_gcp_helpers.py Normal file
View File

@ -0,0 +1,116 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for pymongo/_gcp_helpers.py."""
from __future__ import annotations
import sys
import unittest
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
sys.path[0:0] = [""]
from pymongo._gcp_helpers import _get_gcp_response
@contextmanager
def _mock_urlopen(status: int, body: str):
"""Context manager that patches ``urllib.request.urlopen`` with a fake response."""
mock_response = MagicMock()
mock_response.__enter__ = MagicMock(return_value=mock_response)
mock_response.__exit__ = MagicMock(return_value=False)
mock_response.status = status
mock_response.read.return_value = body.encode("utf8")
with patch("urllib.request.urlopen", return_value=mock_response) as mock_open:
yield mock_open
class TestGetGcpResponse(unittest.TestCase):
"""Tests for :func:`pymongo._gcp_helpers._get_gcp_response`."""
def test_successful_response_returns_access_token(self):
"""A 200 response yields ``{"access_token": <body>}``."""
token = "ya29.some-gcp-token"
with _mock_urlopen(200, token):
result = _get_gcp_response("https://example.com")
self.assertEqual(result, {"access_token": token})
def test_non_200_status_raises_value_error(self):
"""A non-200 HTTP status raises :class:`ValueError`."""
for status in (400, 401, 403, 500, 503):
with self.subTest(status=status):
with _mock_urlopen(status, "error"):
with self.assertRaises(ValueError) as ctx:
_get_gcp_response("https://example.com")
self.assertIn("Failed to acquire IMDS access token", str(ctx.exception))
def test_urlopen_exception_raises_value_error(self):
"""An exception from ``urlopen`` is wrapped in :class:`ValueError`."""
with patch("urllib.request.urlopen", side_effect=OSError("connection refused")):
with self.assertRaises(ValueError) as ctx:
_get_gcp_response("https://example.com")
self.assertIn("Failed to acquire IMDS access token", str(ctx.exception))
self.assertIn("connection refused", str(ctx.exception))
def test_url_contains_resource_as_audience(self):
"""The ``resource`` argument is appended as ``?audience=`` in the URL."""
resource = "https://my-service.example.com"
with _mock_urlopen(200, "token") as mock_open:
_get_gcp_response(resource)
request_obj = mock_open.call_args[0][0]
self.assertIn(f"?audience={resource}", request_obj.full_url)
def test_request_has_metadata_flavor_google_header(self):
"""The request must include the ``Metadata-Flavor: Google`` header."""
with _mock_urlopen(200, "token") as mock_open:
_get_gcp_response("https://example.com")
request_obj = mock_open.call_args[0][0]
self.assertEqual(request_obj.get_header("Metadata-flavor"), "Google")
def test_default_timeout_is_five_seconds(self):
"""Without an explicit timeout, ``urlopen`` is called with ``timeout=5``."""
with _mock_urlopen(200, "token") as mock_open:
_get_gcp_response("https://example.com")
_, kwargs = mock_open.call_args
self.assertEqual(kwargs.get("timeout"), 5)
def test_custom_timeout_is_forwarded(self):
"""An explicit ``timeout`` value is passed through to ``urlopen``."""
with _mock_urlopen(200, "token") as mock_open:
_get_gcp_response("https://example.com", timeout=30)
_, kwargs = mock_open.call_args
self.assertEqual(kwargs.get("timeout"), 30)
def test_urlopen_exception_does_not_chain_original(self):
"""The raised ``ValueError`` suppresses the original exception (``from None``)."""
with patch("urllib.request.urlopen", side_effect=RuntimeError("network error")):
with self.assertRaises(ValueError) as ctx:
_get_gcp_response("https://example.com")
# ``raise ... from None`` sets __cause__ to None and __suppress_context__ to True.
self.assertIs(ctx.exception.__cause__, None)
self.assertIs(ctx.exception.__suppress_context__, True)
if __name__ == "__main__":
unittest.main()

View File

@ -511,6 +511,39 @@ class TestPooling(_TestPoolingBase):
str(error.exception),
)
@client_context.require_failCommand_appName
def test_pool_backpressure_preserves_existing_connections(self):
client = self.rs_or_single_client()
coll = client.pymongo_test.t
pool = get_pool(client)
coll.insert_many([{"x": 1} for _ in range(10)])
t = SocketGetter(self.c, pool)
t.start()
while t.state != "connection":
time.sleep(0.1)
assert not t.sock.conn_closed()
# Mock a session establishment overload.
mock_connection_fail = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"closeConnection": True,
},
}
with self.fail_point(mock_connection_fail):
coll.find_one({})
# Make sure the existing socket was not affected.
assert not t.sock.conn_closed()
# Cleanup
t.release_conn()
t.join()
pool.close()
class TestPoolMaxSize(_TestPoolingBase):
def test_max_pool_size(self):

View File

@ -263,14 +263,17 @@ class TestRetryableReads(IntegrationTest):
@client_context.require_secondaries_count(1)
@client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, 0)
def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available(
def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available_and_overload_retargeting_is_enabled(
self
):
listener = OvertCommandListener()
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled.
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, `enableOverloadRetargeting=True`, and command event monitoring enabled.
client = self.rs_or_single_client(
event_listeners=[listener], retryReads=True, readPreference="primaryPreferred"
event_listeners=[listener],
retryReads=True,
readPreference="primaryPreferred",
enableOverloadRetargeting=True,
)
# 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels.
@ -337,6 +340,47 @@ class TestRetryableReads(IntegrationTest):
# 6. Assert that both events occurred the same server.
assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id
@client_context.require_replica_set
@client_context.require_secondaries_count(1)
@client_context.require_failCommand_fail_point
@client_context.require_version_min(4, 4, 0)
def test_03_03_retryable_reads_caused_by_overload_errors_are_retried_on_the_same_replicaset_server_when_one_is_available_and_overload_retargeting_is_disabled(
self
):
listener = OvertCommandListener()
# 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled.
client = self.rs_or_single_client(
event_listeners=[listener],
retryReads=True,
readPreference="primaryPreferred",
)
# 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["find"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 6,
},
}
set_fail_point(client, command_args)
# 3. Reset the command event monitor to clear the fail point command from its stored events.
listener.reset()
# 4. Execute a `find` command with `client`.
client.t.t.find_one({})
# 5. Assert that one failed command event and one successful command event occurred.
self.assertEqual(len(listener.failed_events), 1)
self.assertEqual(len(listener.succeeded_events), 1)
# 6. Assert that both events occurred on the same server.
assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id
if __name__ == "__main__":
unittest.main()

View File

@ -43,14 +43,17 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.int64 import Int64
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from pymongo import MongoClient
from pymongo.errors import (
AutoReconnect,
ConnectionFailure,
OperationFailure,
NotPrimaryError,
PyMongoError,
ServerSelectionTimeoutError,
WriteConcernError,
)
from pymongo.monitoring import (
CommandFailedEvent,
CommandSucceededEvent,
ConnectionCheckedOutEvent,
ConnectionCheckOutFailedEvent,
@ -597,5 +600,186 @@ class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest):
self.assertEqual(sent_txn_id, final_txn_id, msg)
class TestErrorPropagationAfterEncounteringMultipleErrors(IntegrationTest):
# Only run against replica sets as mongos does not propagate the NoWritesPerformed label to the drivers.
@client_context.require_replica_set
# Run against server versions 6.0 and above.
@client_context.require_version_min(6, 0) # type: ignore[untyped-decorator]
def setUp(self) -> None:
super().setUp()
self.setup_client = MongoClient(**client_context.default_client_options)
self.addCleanup(self.setup_client.close)
# TODO: After PYTHON-4595 we can use async event handlers and remove this workaround.
def configure_fail_point_sync(self, command_args, off=False) -> None:
cmd = {"configureFailPoint": "failCommand"}
cmd.update(command_args)
if off:
cmd["mode"] = "off"
cmd.pop("data", None)
self.setup_client.admin.command(cmd)
def test_01_drivers_return_the_correct_error_when_receiving_only_errors_without_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError"],
"errorCode": 91,
},
}
# Via the command monitoring CommandFailedEvent, configure a fail point with error code 10107 (NotWritablePrimary).
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 10107,
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = self.rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(NotPrimaryError) as exc:
client.test.test.insert_one({})
# Assert that the error code of the server error is 10107.
assert exc.exception.errors["code"] == 10107 # type:ignore[call-overload]
def test_02_drivers_return_the_correct_error_when_receiving_only_errors_with_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
"errorCode": 91,
},
}
# Via the command monitoring CommandFailedEvent, configure a fail point with error code `10107` (NotWritablePrimary)
# and a NoWritesPerformed label.
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorCode": 10107,
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
},
}
def failed(event: CommandFailedEvent) -> None:
if listener.failed_events:
return
# Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2.
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = self.rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(NotPrimaryError) as exc:
client.test.test.insert_one({})
# Assert that the error code of the server error is 91.
assert exc.exception.errors["code"] == 91 # type:ignore[call-overload]
def test_03_drivers_return_the_correct_error_when_receiving_some_errors_with_NoWritesPerformed_and_some_without_NoWritesPerformed(
self
) -> None:
# Create a client with retryWrites=true.
listener = OvertCommandListener()
# Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
# code `91` (NotWritablePrimary) and the `NoWritesPerformed`, `RetryableError` and `SystemOverloadedError` labels.
command_args_inner = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {
"failCommands": ["insert"],
"errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
"errorCode": 91,
},
}
# Configure a fail point with error code `91` (ShutdownInProgress) with the `RetryableError` and
# `SystemOverloadedError` error labels but without the `NoWritesPerformed` error label.
command_args = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["insert"],
"errorCode": 91,
"errorLabels": ["RetryableError", "SystemOverloadedError"],
},
}
def failed(event: CommandFailedEvent) -> None:
# Configure the fail point command only if the failed event is for the 91 error configured in step 2.
if listener.failed_events:
return
assert event.failure["code"] == 91
self.configure_fail_point_sync(command_args_inner)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
listener.failed_events.append(event)
listener.failed = failed
client = self.rs_client(retryWrites=True, event_listeners=[listener])
self.configure_fail_point_sync(command_args)
self.addCleanup(self.configure_fail_point_sync, {}, off=True)
# Attempt an insertOne operation on any record for any database and collection.
# Expect the insertOne to fail with a server error.
with self.assertRaises(PyMongoError) as exc:
client.test.test.insert_one({})
# Assert that the error code of the server error is 91.
assert exc.exception.errors["code"] == 91
# Assert that the error does not contain the error label `NoWritesPerformed`.
assert "NoWritesPerformed" not in exc.exception.errors["errorLabels"]
if __name__ == "__main__":
unittest.main()

View File

@ -16,9 +16,13 @@
from __future__ import annotations
import asyncio
import random
import sys
import time
from io import BytesIO
from unittest.mock import patch
import pymongo
from gridfs.synchronous.grid_file import GridFS, GridFSBucket
from pymongo.server_selectors import writable_server_selector
from pymongo.synchronous.pool import PoolState
@ -40,7 +44,9 @@ from pymongo.errors import (
CollectionInvalid,
ConfigurationError,
ConnectionFailure,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
OperationFailure,
)
from pymongo.operations import IndexModel, InsertOne
@ -426,7 +432,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
self.configure_fail_point(client, command_args)
@client_context.require_transactions
def test_callback_raises_custom_error(self):
def test_1_callback_raises_custom_error(self):
class _MyException(Exception):
pass
@ -438,7 +444,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
s.with_transaction(raise_error)
@client_context.require_transactions
def test_callback_returns_value(self):
def test_2_callback_returns_value(self):
def callback(_):
return "Foo"
@ -466,7 +472,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
self.assertEqual(s.with_transaction(callback), "Foo")
@client_context.require_transactions
def test_callback_not_retried_after_timeout(self):
def test_3_1_callback_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test
@ -487,14 +493,16 @@ class TestTransactionsConvenientAPI(TransactionsBase):
listener.reset()
with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@client_context.require_test_commands
@client_context.require_transactions
def test_callback_not_retried_after_commit_timeout(self):
def test_3_2_callback_not_retried_after_commit_timeout(self):
listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test
@ -519,14 +527,16 @@ class TestTransactionsConvenientAPI(TransactionsBase):
with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@client_context.require_test_commands
@client_context.require_transactions
def test_commit_not_retried_after_timeout(self):
def test_3_3_commit_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test
@ -548,7 +558,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(ConnectionFailure):
with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback)
# One insert for the callback and two commits (includes the automatic
@ -556,6 +566,40 @@ class TestTransactionsConvenientAPI(TransactionsBase):
self.assertEqual(
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
)
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult"))
@client_context.require_transactions
def test_callback_not_retried_after_csot_timeout(self):
listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener])
coll = client[self.db.name].test
def callback(session):
coll.insert_one({}, session=session)
err: dict = {
"ok": 0,
"errmsg": "Transaction 7819 has been aborted.",
"code": 251,
"codeName": "NoSuchTransaction",
"errorLabels": ["TransientTransactionError"],
}
raise OperationFailure(err["errmsg"], err["code"], err)
# Create the collection.
coll.insert_one({})
listener.reset()
with client.start_session() as s:
with pymongo.timeout(1.0):
with self.assertRaises(ExecutionTimeout):
s.with_transaction(callback)
# At least two attempts: the original and one or more retries.
inserts = len([x for x in listener.started_command_names() if x == "insert"])
aborts = len([x for x in listener.started_command_names() if x == "abortTransaction"])
self.assertGreaterEqual(inserts, 2)
self.assertGreaterEqual(aborts, 2)
# Tested here because this supports Motor's convenient transactions API.
@client_context.require_transactions
@ -594,6 +638,63 @@ class TestTransactionsConvenientAPI(TransactionsBase):
s.with_transaction(callback)
self.assertFalse(s.in_transaction)
@client_context.require_test_commands
@client_context.require_transactions
def test_4_retry_backoff_is_enforced(self):
client = client_context.client
coll = client[self.db.name].test
end = start = no_backoff_time = 0
# Make random.random always return 0 (no backoff)
with patch.object(random, "random", return_value=0):
# set fail point to trigger transaction failure and trigger backoff
self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {"times": 13},
"data": {
"failCommands": ["commitTransaction"],
"errorCode": 251,
},
}
)
self.addCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
def callback(session):
coll.insert_one({}, session=session)
start = time.monotonic()
with self.client.start_session() as s:
s.with_transaction(callback)
end = time.monotonic()
no_backoff_time = end - start
# Make random.random always return 1 (max backoff)
with patch.object(random, "random", return_value=1):
# set fail point to trigger transaction failure and trigger backoff
self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {
"times": 13
}, # sufficiently high enough such that the time effect of backoff is noticeable
"data": {
"failCommands": ["commitTransaction"],
"errorCode": 251,
},
}
)
self.addCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
start = time.monotonic()
with self.client.start_session() as s:
s.with_transaction(callback)
end = time.monotonic()
self.assertLess(abs(end - start - (no_backoff_time + 2.2)), 1) # sum of 13 backoffs is 2.2
class TestOptionsInsideTransactionProse(TransactionsBase):
@client_context.require_transactions

View File

@ -213,6 +213,7 @@ converted_tests = [
"test_bulk.py",
"test_change_stream.py",
"test_client.py",
"test_client_backpressure.py",
"test_client_bulk_write.py",
"test_client_context.py",
"test_client_metadata.py",
@ -350,7 +351,7 @@ def translate_async_sleeps(lines: list[str]) -> list[str]:
sleeps = [line for line in lines if "asyncio.sleep" in line]
for line in sleeps:
res = re.search(r"asyncio.sleep\(([^()]*)\)", line)
res = re.search(r"asyncio\.sleep\(\s*(.*?)\)", line)
if res:
old = res[0]
index = lines.index(line)