diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 6ab3b3998..697de81b1 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -135,7 +135,9 @@ Classes from __future__ import annotations +import asyncio import collections +import random import time import uuid from collections.abc import Mapping as _Mapping @@ -470,6 +472,8 @@ _UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( # # This limit is non-configurable and was chosen to be twice the 60 second # default value of MongoDB's `transactionLifetimeLimitSeconds` parameter. _WITH_TRANSACTION_RETRY_TIME_LIMIT = 120 +_BACKOFF_MAX = 1 +_BACKOFF_INITIAL = 0.050 # 50ms initial backoff def _within_time_limit(start_time: float) -> bool: @@ -703,7 +707,13 @@ class AsyncClientSession: https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback """ start_time = time.monotonic() + retry = 0 while True: + if retry: # Implement exponential backoff on retry. + jitter = random.random() # noqa: S311 + backoff = jitter * min(_BACKOFF_INITIAL * (2**retry), _BACKOFF_MAX) + await asyncio.sleep(backoff) + retry += 1 await self.start_transaction( read_concern, write_concern, read_preference, max_commit_time_ms ) diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index e7e2f5803..c6cb69af2 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -58,6 +58,7 @@ from pymongo.asynchronous.cursor import ( AsyncCursor, AsyncRawBatchCursor, ) +from pymongo.asynchronous.helpers import _retry_overload from pymongo.collation import validate_collation_or_none from pymongo.common import _ecoc_coll_name, _esc_coll_name from pymongo.errors import ( @@ -252,6 +253,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): unicode_decode_error_handler="replace", document_class=dict ) self._timeout = database.client.options.timeout + self._retry_policy = database.client._retry_policy if create or kwargs: if _IS_SYNC: @@ -2227,6 +2229,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): return await self._create_indexes(indexes, session, **kwargs) @_csot.apply + @_retry_overload async def _create_indexes( self, indexes: Sequence[IndexModel], session: Optional[AsyncClientSession], **kwargs: Any ) -> list[str]: @@ -2422,7 +2425,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): kwargs["comment"] = comment await self._drop_index("*", session=session, **kwargs) - @_csot.apply async def drop_index( self, index_or_name: _IndexKeyHint, @@ -2472,6 +2474,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): await self._drop_index(index_or_name, session, comment, **kwargs) @_csot.apply + @_retry_overload async def _drop_index( self, index_or_name: _IndexKeyHint, @@ -3072,6 +3075,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): ) @_csot.apply + @_retry_overload async def rename( self, new_name: str, diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index 8e0afc9dc..9b8486931 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -38,6 +38,7 @@ from pymongo.asynchronous.aggregation import _DatabaseAggregationCommand from pymongo.asynchronous.change_stream import AsyncDatabaseChangeStream from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.helpers import _retry_overload from pymongo.common import _ecoc_coll_name, _esc_coll_name from pymongo.database_shared import _check_name, _CodecDocumentType from pymongo.errors import CollectionInvalid, InvalidOperation @@ -135,6 +136,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): self._name = name self._client: AsyncMongoClient[_DocumentType] = client self._timeout = client.options.timeout + self._retry_policy = client._retry_policy @property def client(self) -> AsyncMongoClient[_DocumentType]: @@ -477,6 +479,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): return change_stream @_csot.apply + @_retry_overload async def create_collection( self, name: str, @@ -819,6 +822,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): ... @_csot.apply + @_retry_overload async def command( self, command: Union[str, MutableMapping[str, Any]], @@ -950,6 +954,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): ) @_csot.apply + @_retry_overload async def cursor_command( self, command: Union[str, MutableMapping[str, Any]], @@ -1265,6 +1270,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): ) @_csot.apply + @_retry_overload async def drop_collection( self, name_or_collection: Union[str, AsyncCollection[_DocumentTypeArg]], diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 4a8c91813..4ff3caa10 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -16,7 +16,12 @@ from __future__ import annotations import asyncio +import builtins +import functools +import random import socket +import time +import time as time # noqa: PLC0414 # needed in sync version from typing import ( Any, Callable, @@ -24,10 +29,13 @@ from typing import ( cast, ) +from pymongo import _csot from pymongo.errors import ( OperationFailure, + PyMongoError, ) from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE +from pymongo.lock import _async_create_lock _IS_SYNC = False @@ -36,6 +44,7 @@ F = TypeVar("F", bound=Callable[..., Any]) def _handle_reauth(func: F) -> F: + @functools.wraps(func) async def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) from pymongo.asynchronous.pool import AsyncConnection @@ -68,6 +77,123 @@ def _handle_reauth(func: F) -> F: return cast(F, inner) +_MAX_RETRIES = 3 +_BACKOFF_INITIAL = 0.05 +_BACKOFF_MAX = 10 +# DRIVERS-3240 will determine these defaults. +DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0 +DEFAULT_RETRY_TOKEN_RETURN = 0.1 + + +def _backoff( + attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX +) -> float: + jitter = random.random() # noqa: S311 + return jitter * min(initial_delay * (2**attempt), max_delay) + + +class _TokenBucket: + """A token bucket implementation for rate limiting.""" + + def __init__( + self, + capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY, + return_rate: float = DEFAULT_RETRY_TOKEN_RETURN, + ): + self.lock = _async_create_lock() + self.capacity = capacity + # DRIVERS-3240 will determine how full the bucket should start. + self.tokens = capacity + self.return_rate = return_rate + + async def consume(self) -> bool: + """Consume a token from the bucket if available.""" + async with self.lock: + if self.tokens >= 1: + self.tokens -= 1 + return True + return False + + async def deposit(self, retry: bool = False) -> None: + """Deposit a token back into the bucket.""" + retry_token = 1 if retry else 0 + async with self.lock: + self.tokens = min(self.capacity, self.tokens + retry_token + self.return_rate) + + +class _RetryPolicy: + """A retry limiter that performs exponential backoff with jitter. + + Retry attempts are limited by a token bucket to prevent overwhelming the server during + a prolonged outage or high load. + """ + + def __init__( + self, + token_bucket: _TokenBucket, + attempts: int = _MAX_RETRIES, + backoff_initial: float = _BACKOFF_INITIAL, + backoff_max: float = _BACKOFF_MAX, + ): + self.token_bucket = token_bucket + self.attempts = attempts + self.backoff_initial = backoff_initial + self.backoff_max = backoff_max + + async def record_success(self, retry: bool) -> None: + """Record a successful operation.""" + await self.token_bucket.deposit(retry) + + def backoff(self, attempt: int) -> float: + """Return the backoff duration for the given .""" + return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) + + async def should_retry(self, attempt: int, delay: float) -> bool: + """Return if we have budget to retry and how long to backoff.""" + if attempt > self.attempts: + return False + + # If the delay would exceed the deadline, bail early before consuming a token. + if _csot.get_timeout(): + if time.monotonic() + delay > _csot.get_deadline(): + return False + + # Check token bucket last since we only want to consume a token if we actually retry. + if not await self.token_bucket.consume(): + # DRIVERS-3246 Improve diagnostics when this case happens. + # We could add info to the exception and log. + return False + return True + + +def _retry_overload(func: F) -> F: + @functools.wraps(func) + async def inner(self: Any, *args: Any, **kwargs: Any) -> Any: + retry_policy = self._retry_policy + attempt = 0 + while True: + try: + res = await func(self, *args, **kwargs) + await retry_policy.record_success(retry=attempt > 0) + return res + except PyMongoError as exc: + if not exc.has_error_label("RetryableError"): + raise + attempt += 1 + delay = 0 + if exc.has_error_label("SystemOverloadedError"): + delay = retry_policy.backoff(attempt) + if not await retry_policy.should_retry(attempt, delay): + raise + + # Implement exponential backoff on retry. + if delay: + await asyncio.sleep(delay) + continue + + return cast(F, inner) + + async def _getaddrinfo( host: Any, port: Any, **kwargs: Any ) -> list[ diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 2a8ff4339..486e00ae4 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -35,6 +35,7 @@ from __future__ import annotations import asyncio import contextlib import os +import time as time # noqa: PLC0414 # needed in sync version import warnings import weakref from collections import defaultdict @@ -67,6 +68,11 @@ from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterCh from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.asynchronous.client_session import _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.helpers import ( + _retry_overload, + _RetryPolicy, + _TokenBucket, +) from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext from pymongo.client_options import ClientOptions @@ -773,6 +779,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): self._timeout: float | None = None self._topology_settings: TopologySettings = None # type: ignore[assignment] self._event_listeners: _EventListeners | None = None + self._retry_policy = _RetryPolicy(_TokenBucket()) # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -2396,6 +2403,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): return [doc["name"] async for doc in res] @_csot.apply + @_retry_overload async def drop_database( self, name_or_database: Union[str, database.AsyncDatabase[_DocumentTypeArg]], @@ -2733,9 +2741,10 @@ class _ClientConnectionRetryable(Generic[T]): ): self._last_error: Optional[Exception] = None self._retrying = False + self._always_retryable = False self._multiple_retries = _csot.get_timeout() is not None self._client = mongo_client - + self._retry_policy = mongo_client._retry_policy self._func = func self._bulk = bulk self._session = session @@ -2770,7 +2779,9 @@ class _ClientConnectionRetryable(Generic[T]): while True: self._check_last_error(check_csot=True) try: - return await self._read() if self._is_read else await self._write() + res = await self._read() if self._is_read else await self._write() + await self._retry_policy.record_success(self._attempt_number > 0) + return res except ServerSelectionTimeoutError: # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry @@ -2781,14 +2792,22 @@ class _ClientConnectionRetryable(Generic[T]): # most likely be a waste of time. raise except PyMongoError as exc: + always_retryable = False + overloaded = False + exc_to_check = exc # Execute specialized catch on read if self._is_read: if isinstance(exc, (ConnectionFailure, OperationFailure)): # ConnectionFailures do not supply a code property exc_code = getattr(exc, "code", None) - if self._is_not_eligible_for_retry() or ( - isinstance(exc, OperationFailure) - and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + always_retryable = exc.has_error_label("RetryableError") + overloaded = exc.has_error_label("SystemOverloadedError") + if not always_retryable and ( + self._is_not_eligible_for_retry() + or ( + isinstance(exc, OperationFailure) + and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + ) ): raise self._retrying = True @@ -2799,19 +2818,22 @@ class _ClientConnectionRetryable(Generic[T]): # Specialized catch on write operation if not self._is_read: - if not self._retryable: + if isinstance(exc, ClientBulkWriteException) and isinstance( + exc.error, PyMongoError + ): + exc_to_check = exc.error + retryable_write_label = exc_to_check.has_error_label("RetryableWriteError") + always_retryable = exc_to_check.has_error_label("RetryableError") + overloaded = exc_to_check.has_error_label("SystemOverloadedError") + if not self._retryable and not always_retryable: raise - if isinstance(exc, ClientBulkWriteException) and exc.error: - retryable_write_error_exc = isinstance( - exc.error, PyMongoError - ) and exc.error.has_error_label("RetryableWriteError") - else: - retryable_write_error_exc = exc.has_error_label("RetryableWriteError") - if retryable_write_error_exc: + if retryable_write_label or always_retryable: assert self._session await self._session._unpin() - if not retryable_write_error_exc or self._is_not_eligible_for_retry(): - if exc.has_error_label("NoWritesPerformed") and self._last_error: + if not always_retryable and ( + not retryable_write_label or self._is_not_eligible_for_retry() + ): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise @@ -2820,7 +2842,7 @@ class _ClientConnectionRetryable(Generic[T]): self._bulk.retrying = True else: self._retrying = True - if not exc.has_error_label("NoWritesPerformed"): + if not exc_to_check.has_error_label("NoWritesPerformed"): self._last_error = exc if self._last_error is None: self._last_error = exc @@ -2828,6 +2850,17 @@ class _ClientConnectionRetryable(Generic[T]): if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded: self._deprioritized_servers.append(self._server) + self._always_retryable = always_retryable + if always_retryable: + delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0 + if not await self._retry_policy.should_retry(self._attempt_number, delay): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: + raise self._last_error from exc + else: + raise + if overloaded: + await asyncio.sleep(delay) + def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" return not self._retryable or (self._is_retrying() and not self._multiple_retries) @@ -2889,7 +2922,7 @@ class _ClientConnectionRetryable(Generic[T]): and conn.supports_sessions ) is_mongos = conn.is_mongos - if not sessions_supported: + if not self._always_retryable and not sessions_supported: # A retry is not possible because this server does # not support sessions raise the last error. self._check_last_error() @@ -2921,7 +2954,7 @@ class _ClientConnectionRetryable(Generic[T]): conn, read_pref, ): - if self._retrying and not self._retryable: + if self._retrying and not self._retryable and not self._always_retryable: self._check_last_error() if self._retrying: _debug_log( diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 9b547dc94..d5a37eb10 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -136,6 +136,7 @@ Classes from __future__ import annotations import collections +import random import time import uuid from collections.abc import Mapping as _Mapping @@ -469,6 +470,8 @@ _UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( # # This limit is non-configurable and was chosen to be twice the 60 second # default value of MongoDB's `transactionLifetimeLimitSeconds` parameter. _WITH_TRANSACTION_RETRY_TIME_LIMIT = 120 +_BACKOFF_MAX = 1 +_BACKOFF_INITIAL = 0.050 # 50ms initial backoff def _within_time_limit(start_time: float) -> bool: @@ -702,7 +705,13 @@ class ClientSession: https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback """ start_time = time.monotonic() + retry = 0 while True: + if retry: # Implement exponential backoff on retry. + jitter = random.random() # noqa: S311 + backoff = jitter * min(_BACKOFF_INITIAL * (2**retry), _BACKOFF_MAX) + time.sleep(backoff) + retry += 1 self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms) try: ret = callback(self) diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 4e5f7d08f..9ee8e6394 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -89,6 +89,7 @@ from pymongo.synchronous.cursor import ( Cursor, RawBatchCursor, ) +from pymongo.synchronous.helpers import _retry_overload from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean @@ -255,6 +256,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): unicode_decode_error_handler="replace", document_class=dict ) self._timeout = database.client.options.timeout + self._retry_policy = database.client._retry_policy if create or kwargs: if _IS_SYNC: @@ -2224,6 +2226,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): return self._create_indexes(indexes, session, **kwargs) @_csot.apply + @_retry_overload def _create_indexes( self, indexes: Sequence[IndexModel], session: Optional[ClientSession], **kwargs: Any ) -> list[str]: @@ -2419,7 +2422,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]): kwargs["comment"] = comment self._drop_index("*", session=session, **kwargs) - @_csot.apply def drop_index( self, index_or_name: _IndexKeyHint, @@ -2469,6 +2471,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): self._drop_index(index_or_name, session, comment, **kwargs) @_csot.apply + @_retry_overload def _drop_index( self, index_or_name: _IndexKeyHint, @@ -3065,6 +3068,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): ) @_csot.apply + @_retry_overload def rename( self, new_name: str, diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index 0d129ba97..6877854f4 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -43,6 +43,7 @@ from pymongo.synchronous.aggregation import _DatabaseAggregationCommand from pymongo.synchronous.change_stream import DatabaseChangeStream from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.helpers import _retry_overload from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline if TYPE_CHECKING: @@ -135,6 +136,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): self._name = name self._client: MongoClient[_DocumentType] = client self._timeout = client.options.timeout + self._retry_policy = client._retry_policy @property def client(self) -> MongoClient[_DocumentType]: @@ -477,6 +479,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): return change_stream @_csot.apply + @_retry_overload def create_collection( self, name: str, @@ -819,6 +822,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): ... @_csot.apply + @_retry_overload def command( self, command: Union[str, MutableMapping[str, Any]], @@ -948,6 +952,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): ) @_csot.apply + @_retry_overload def cursor_command( self, command: Union[str, MutableMapping[str, Any]], @@ -1258,6 +1263,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): ) @_csot.apply + @_retry_overload def drop_collection( self, name_or_collection: Union[str, Collection[_DocumentTypeArg]], diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index fea2d6dae..30b8c4fc6 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -35,6 +35,7 @@ from __future__ import annotations import asyncio import contextlib import os +import time as time # noqa: PLC0414 # needed in sync version import warnings import weakref from collections import defaultdict @@ -110,6 +111,11 @@ from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.synchronous.client_session import _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.helpers import ( + _retry_overload, + _RetryPolicy, + _TokenBucket, +) from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription @@ -773,6 +779,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): self._timeout: float | None = None self._topology_settings: TopologySettings = None # type: ignore[assignment] self._event_listeners: _EventListeners | None = None + self._retry_policy = _RetryPolicy(_TokenBucket()) # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -2386,6 +2393,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): return [doc["name"] for doc in res] @_csot.apply + @_retry_overload def drop_database( self, name_or_database: Union[str, database.Database[_DocumentTypeArg]], @@ -2723,9 +2731,10 @@ class _ClientConnectionRetryable(Generic[T]): ): self._last_error: Optional[Exception] = None self._retrying = False + self._always_retryable = False self._multiple_retries = _csot.get_timeout() is not None self._client = mongo_client - + self._retry_policy = mongo_client._retry_policy self._func = func self._bulk = bulk self._session = session @@ -2760,7 +2769,9 @@ class _ClientConnectionRetryable(Generic[T]): while True: self._check_last_error(check_csot=True) try: - return self._read() if self._is_read else self._write() + res = self._read() if self._is_read else self._write() + self._retry_policy.record_success(self._attempt_number > 0) + return res except ServerSelectionTimeoutError: # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry @@ -2771,14 +2782,22 @@ class _ClientConnectionRetryable(Generic[T]): # most likely be a waste of time. raise except PyMongoError as exc: + always_retryable = False + overloaded = False + exc_to_check = exc # Execute specialized catch on read if self._is_read: if isinstance(exc, (ConnectionFailure, OperationFailure)): # ConnectionFailures do not supply a code property exc_code = getattr(exc, "code", None) - if self._is_not_eligible_for_retry() or ( - isinstance(exc, OperationFailure) - and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + always_retryable = exc.has_error_label("RetryableError") + overloaded = exc.has_error_label("SystemOverloadedError") + if not always_retryable and ( + self._is_not_eligible_for_retry() + or ( + isinstance(exc, OperationFailure) + and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + ) ): raise self._retrying = True @@ -2789,19 +2808,22 @@ class _ClientConnectionRetryable(Generic[T]): # Specialized catch on write operation if not self._is_read: - if not self._retryable: + if isinstance(exc, ClientBulkWriteException) and isinstance( + exc.error, PyMongoError + ): + exc_to_check = exc.error + retryable_write_label = exc_to_check.has_error_label("RetryableWriteError") + always_retryable = exc_to_check.has_error_label("RetryableError") + overloaded = exc_to_check.has_error_label("SystemOverloadedError") + if not self._retryable and not always_retryable: raise - if isinstance(exc, ClientBulkWriteException) and exc.error: - retryable_write_error_exc = isinstance( - exc.error, PyMongoError - ) and exc.error.has_error_label("RetryableWriteError") - else: - retryable_write_error_exc = exc.has_error_label("RetryableWriteError") - if retryable_write_error_exc: + if retryable_write_label or always_retryable: assert self._session self._session._unpin() - if not retryable_write_error_exc or self._is_not_eligible_for_retry(): - if exc.has_error_label("NoWritesPerformed") and self._last_error: + if not always_retryable and ( + not retryable_write_label or self._is_not_eligible_for_retry() + ): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise @@ -2810,7 +2832,7 @@ class _ClientConnectionRetryable(Generic[T]): self._bulk.retrying = True else: self._retrying = True - if not exc.has_error_label("NoWritesPerformed"): + if not exc_to_check.has_error_label("NoWritesPerformed"): self._last_error = exc if self._last_error is None: self._last_error = exc @@ -2818,6 +2840,17 @@ class _ClientConnectionRetryable(Generic[T]): if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded: self._deprioritized_servers.append(self._server) + self._always_retryable = always_retryable + if always_retryable: + delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0 + if not self._retry_policy.should_retry(self._attempt_number, delay): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: + raise self._last_error from exc + else: + raise + if overloaded: + time.sleep(delay) + def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" return not self._retryable or (self._is_retrying() and not self._multiple_retries) @@ -2879,7 +2912,7 @@ class _ClientConnectionRetryable(Generic[T]): and conn.supports_sessions ) is_mongos = conn.is_mongos - if not sessions_supported: + if not self._always_retryable and not sessions_supported: # A retry is not possible because this server does # not support sessions raise the last error. self._check_last_error() @@ -2911,7 +2944,7 @@ class _ClientConnectionRetryable(Generic[T]): conn, read_pref, ): - if self._retrying and not self._retryable: + if self._retrying and not self._retryable and not self._always_retryable: self._check_last_error() if self._retrying: _debug_log( diff --git a/test/asynchronous/test_backpressure.py b/test/asynchronous/test_backpressure.py new file mode 100644 index 000000000..11f8edde6 --- /dev/null +++ b/test/asynchronous/test_backpressure.py @@ -0,0 +1,230 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test Client Backpressure spec.""" +from __future__ import annotations + +import asyncio +import sys + +import pymongo + +sys.path[0:0] = [""] + +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + async_client_context, + unittest, +) + +from pymongo.asynchronous import helpers +from pymongo.asynchronous.helpers import _MAX_RETRIES, _RetryPolicy, _TokenBucket +from pymongo.errors import PyMongoError + +_IS_SYNC = False + +# Mock an system overload error. +mock_overload_error = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find", "insert", "update"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["RetryableError"], + }, +} + + +class TestBackpressure(AsyncIntegrationTest): + RUN_ON_LOAD_BALANCER = True + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_command(self): + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + async with self.fail_point(fail_many): + await self.db.command("find", "t") + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.command("find", "t") + + self.assertIn("RetryableError", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_find(self): + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + async with self.fail_point(fail_many): + await self.db.t.find_one() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.t.find_one() + + self.assertIn("RetryableError", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_insert_one(self): + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + async with self.fail_point(fail_many): + await self.db.t.find_one() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.t.find_one() + + self.assertIn("RetryableError", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_update_many(self): + # Even though update_many is not a retryable write operation, it will + # still be retried via the "RetryableError" error label. + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + async with self.fail_point(fail_many): + await self.db.t.update_many({}, {"$set": {"x": 2}}) + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.t.update_many({}, {"$set": {"x": 2}}) + + self.assertIn("RetryableError", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_getMore(self): + coll = self.db.t + await coll.insert_many([{"x": 1} for _ in range(10)]) + + # Ensure command is retried on overload error. + fail_many = { + "configureFailPoint": "failCommand", + "mode": {"times": _MAX_RETRIES}, + "data": { + "failCommands": ["getMore"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["RetryableError"], + }, + } + cursor = coll.find(batch_size=2) + await cursor.next() + async with self.fail_point(fail_many): + await cursor.to_list() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = fail_many.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + cursor = coll.find(batch_size=2) + await cursor.next() + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await cursor.to_list() + + self.assertIn("RetryableError", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_limit_retry_command(self): + client = await self.async_rs_or_single_client() + client._retry_policy.token_bucket.tokens = 1 + db = client.pymongo_test + await db.t.insert_one({"x": 1}) + + # Ensure command is retried once overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": 1} + async with self.fail_point(fail_many): + await db.command("find", "t") + + # Ensure command stops retrying when there are no tokens left. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": 2} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await db.command("find", "t") + + self.assertIn("RetryableError", str(error.exception)) + + +class TestRetryPolicy(AsyncPyMongoTestCase): + async def test_retry_policy(self): + capacity = 10 + retry_policy = _RetryPolicy(_TokenBucket(capacity=capacity)) + self.assertEqual(retry_policy.attempts, helpers._MAX_RETRIES) + self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL) + self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX) + for i in range(1, helpers._MAX_RETRIES + 1): + self.assertTrue(await retry_policy.should_retry(i, 0)) + self.assertFalse(await retry_policy.should_retry(helpers._MAX_RETRIES + 1, 0)) + for i in range(capacity - helpers._MAX_RETRIES): + self.assertTrue(await retry_policy.should_retry(1, 0)) + # No tokens left, should not retry. + self.assertFalse(await retry_policy.should_retry(1, 0)) + self.assertEqual(retry_policy.token_bucket.tokens, 0) + + # record_success should generate tokens. + for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)): + await retry_policy.record_success(retry=False) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2) + for i in range(2): + self.assertTrue(await retry_policy.should_retry(1, 0)) + self.assertFalse(await retry_policy.should_retry(1, 0)) + + # Recording a successful retry should return 1 additional token. + await retry_policy.record_success(retry=True) + self.assertAlmostEqual( + retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN + ) + self.assertTrue(await retry_policy.should_retry(1, 0)) + self.assertFalse(await retry_policy.should_retry(1, 0)) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN) + + async def test_retry_policy_csot(self): + retry_policy = _RetryPolicy(_TokenBucket()) + self.assertTrue(await retry_policy.should_retry(1, 0.5)) + with pymongo.timeout(0.5): + self.assertTrue(await retry_policy.should_retry(1, 0)) + self.assertTrue(await retry_policy.should_retry(1, 0.1)) + # Would exceed the timeout, should not retry. + self.assertFalse(await retry_policy.should_retry(1, 1.0)) + self.assertTrue(await retry_policy.should_retry(1, 1.0)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_backpressure.py b/test/test_backpressure.py new file mode 100644 index 000000000..fac1d6236 --- /dev/null +++ b/test/test_backpressure.py @@ -0,0 +1,230 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test Client Backpressure spec.""" +from __future__ import annotations + +import asyncio +import sys + +import pymongo + +sys.path[0:0] = [""] + +from test import ( + IntegrationTest, + PyMongoTestCase, + client_context, + unittest, +) + +from pymongo.errors import PyMongoError +from pymongo.synchronous import helpers +from pymongo.synchronous.helpers import _MAX_RETRIES, _RetryPolicy, _TokenBucket + +_IS_SYNC = True + +# Mock an system overload error. +mock_overload_error = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find", "insert", "update"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["RetryableError"], + }, +} + + +class TestBackpressure(IntegrationTest): + RUN_ON_LOAD_BALANCER = True + + @client_context.require_failCommand_appName + def test_retry_overload_error_command(self): + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + with self.fail_point(fail_many): + self.db.command("find", "t") + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.command("find", "t") + + self.assertIn("RetryableError", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_find(self): + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + with self.fail_point(fail_many): + self.db.t.find_one() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.t.find_one() + + self.assertIn("RetryableError", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_insert_one(self): + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + with self.fail_point(fail_many): + self.db.t.find_one() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.t.find_one() + + self.assertIn("RetryableError", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_update_many(self): + # Even though update_many is not a retryable write operation, it will + # still be retried via the "RetryableError" error label. + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + with self.fail_point(fail_many): + self.db.t.update_many({}, {"$set": {"x": 2}}) + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.t.update_many({}, {"$set": {"x": 2}}) + + self.assertIn("RetryableError", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_getMore(self): + coll = self.db.t + coll.insert_many([{"x": 1} for _ in range(10)]) + + # Ensure command is retried on overload error. + fail_many = { + "configureFailPoint": "failCommand", + "mode": {"times": _MAX_RETRIES}, + "data": { + "failCommands": ["getMore"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["RetryableError"], + }, + } + cursor = coll.find(batch_size=2) + cursor.next() + with self.fail_point(fail_many): + cursor.to_list() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = fail_many.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + cursor = coll.find(batch_size=2) + cursor.next() + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + cursor.to_list() + + self.assertIn("RetryableError", str(error.exception)) + + @client_context.require_failCommand_appName + def test_limit_retry_command(self): + client = self.rs_or_single_client() + client._retry_policy.token_bucket.tokens = 1 + db = client.pymongo_test + db.t.insert_one({"x": 1}) + + # Ensure command is retried once overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": 1} + with self.fail_point(fail_many): + db.command("find", "t") + + # Ensure command stops retrying when there are no tokens left. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": 2} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + db.command("find", "t") + + self.assertIn("RetryableError", str(error.exception)) + + +class TestRetryPolicy(PyMongoTestCase): + def test_retry_policy(self): + capacity = 10 + retry_policy = _RetryPolicy(_TokenBucket(capacity=capacity)) + self.assertEqual(retry_policy.attempts, helpers._MAX_RETRIES) + self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL) + self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX) + for i in range(1, helpers._MAX_RETRIES + 1): + self.assertTrue(retry_policy.should_retry(i, 0)) + self.assertFalse(retry_policy.should_retry(helpers._MAX_RETRIES + 1, 0)) + for i in range(capacity - helpers._MAX_RETRIES): + self.assertTrue(retry_policy.should_retry(1, 0)) + # No tokens left, should not retry. + self.assertFalse(retry_policy.should_retry(1, 0)) + self.assertEqual(retry_policy.token_bucket.tokens, 0) + + # record_success should generate tokens. + for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)): + retry_policy.record_success(retry=False) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2) + for i in range(2): + self.assertTrue(retry_policy.should_retry(1, 0)) + self.assertFalse(retry_policy.should_retry(1, 0)) + + # Recording a successful retry should return 1 additional token. + retry_policy.record_success(retry=True) + self.assertAlmostEqual( + retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN + ) + self.assertTrue(retry_policy.should_retry(1, 0)) + self.assertFalse(retry_policy.should_retry(1, 0)) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN) + + def test_retry_policy_csot(self): + retry_policy = _RetryPolicy(_TokenBucket()) + self.assertTrue(retry_policy.should_retry(1, 0.5)) + with pymongo.timeout(0.5): + self.assertTrue(retry_policy.should_retry(1, 0)) + self.assertTrue(retry_policy.should_retry(1, 0.1)) + # Would exceed the timeout, should not retry. + self.assertFalse(retry_policy.should_retry(1, 1.0)) + self.assertTrue(retry_policy.should_retry(1, 1.0)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/synchro.py b/tools/synchro.py index 1444b2299..492fbf428 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -209,6 +209,7 @@ converted_tests = [ "test_auth_oidc.py", "test_auth_spec.py", "test_bulk.py", + "test_backpressure.py", "test_change_stream.py", "test_client.py", "test_client_bulk_write.py",