Compare commits
7 Commits
master
...
revert-250
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aeb6bcfe76 | ||
|
|
27785aede6 | ||
|
|
d267eb4833 | ||
|
|
c458379522 | ||
|
|
875c5640d7 | ||
|
|
75eee91818 | ||
|
|
cf7a1aaa2f |
@ -135,7 +135,9 @@ Classes
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Mapping as _Mapping
|
||||
@ -471,6 +473,8 @@ _UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( #
|
||||
# This limit is non-configurable and was chosen to be twice the 60 second
|
||||
# default value of MongoDB's `transactionLifetimeLimitSeconds` parameter.
|
||||
_WITH_TRANSACTION_RETRY_TIME_LIMIT = 120
|
||||
_BACKOFF_MAX = 1
|
||||
_BACKOFF_INITIAL = 0.050 # 50ms initial backoff
|
||||
|
||||
|
||||
def _within_time_limit(start_time: float) -> bool:
|
||||
@ -700,7 +704,13 @@ class AsyncClientSession:
|
||||
https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
retry = 0
|
||||
while True:
|
||||
if retry: # Implement exponential backoff on retry.
|
||||
jitter = random.random() # noqa: S311
|
||||
backoff = jitter * min(_BACKOFF_INITIAL * (2**retry), _BACKOFF_MAX)
|
||||
await asyncio.sleep(backoff)
|
||||
retry += 1
|
||||
await self.start_transaction(
|
||||
read_concern, write_concern, read_preference, max_commit_time_ms
|
||||
)
|
||||
|
||||
@ -58,6 +58,7 @@ from pymongo.asynchronous.cursor import (
|
||||
AsyncCursor,
|
||||
AsyncRawBatchCursor,
|
||||
)
|
||||
from pymongo.asynchronous.helpers import _retry_overload
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.common import _ecoc_coll_name, _esc_coll_name
|
||||
from pymongo.errors import (
|
||||
@ -252,6 +253,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
unicode_decode_error_handler="replace", document_class=dict
|
||||
)
|
||||
self._timeout = database.client.options.timeout
|
||||
self._retry_policy = database.client._retry_policy
|
||||
|
||||
if create or kwargs:
|
||||
if _IS_SYNC:
|
||||
@ -2227,6 +2229,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
return await self._create_indexes(indexes, session, **kwargs)
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
async def _create_indexes(
|
||||
self, indexes: Sequence[IndexModel], session: Optional[AsyncClientSession], **kwargs: Any
|
||||
) -> list[str]:
|
||||
@ -2422,7 +2425,6 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
kwargs["comment"] = comment
|
||||
await self._drop_index("*", session=session, **kwargs)
|
||||
|
||||
@_csot.apply
|
||||
async def drop_index(
|
||||
self,
|
||||
index_or_name: _IndexKeyHint,
|
||||
@ -2472,6 +2474,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
await self._drop_index(index_or_name, session, comment, **kwargs)
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
async def _drop_index(
|
||||
self,
|
||||
index_or_name: _IndexKeyHint,
|
||||
@ -3079,6 +3082,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
async def rename(
|
||||
self,
|
||||
new_name: str,
|
||||
|
||||
@ -38,6 +38,7 @@ from pymongo.asynchronous.aggregation import _DatabaseAggregationCommand
|
||||
from pymongo.asynchronous.change_stream import AsyncDatabaseChangeStream
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.helpers import _retry_overload
|
||||
from pymongo.common import _ecoc_coll_name, _esc_coll_name
|
||||
from pymongo.database_shared import _check_name, _CodecDocumentType
|
||||
from pymongo.errors import CollectionInvalid, InvalidOperation
|
||||
@ -135,6 +136,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
self._name = name
|
||||
self._client: AsyncMongoClient[_DocumentType] = client
|
||||
self._timeout = client.options.timeout
|
||||
self._retry_policy = client._retry_policy
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncMongoClient[_DocumentType]:
|
||||
@ -477,6 +479,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
return change_stream
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
async def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
@ -816,6 +819,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
...
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
async def command(
|
||||
self,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
@ -947,6 +951,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
async def cursor_command(
|
||||
self,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
@ -1264,6 +1269,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
async def drop_collection(
|
||||
self,
|
||||
name_or_collection: Union[str, AsyncCollection[_DocumentTypeArg]],
|
||||
|
||||
@ -17,8 +17,11 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
import functools
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
import time as time # noqa: PLC0414 # needed in sync version
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -26,10 +29,13 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pymongo import _csot
|
||||
from pymongo.errors import (
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
)
|
||||
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
|
||||
from pymongo.lock import _async_create_lock
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
@ -38,6 +44,7 @@ F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def _handle_reauth(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
async def inner(*args: Any, **kwargs: Any) -> Any:
|
||||
no_reauth = kwargs.pop("no_reauth", False)
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
@ -70,6 +77,123 @@ def _handle_reauth(func: F) -> F:
|
||||
return cast(F, inner)
|
||||
|
||||
|
||||
# Defaults for overload retries: maximum retry attempts and the backoff window
# (in seconds, since callers pass the computed delay to ``asyncio.sleep``).
_MAX_RETRIES = 5
_BACKOFF_INITIAL = 0.1
_BACKOFF_MAX = 10
# DRIVERS-3240 will determine these defaults.
DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0
DEFAULT_RETRY_TOKEN_RETURN = 0.1


def _backoff(
    attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
) -> float:
    """Return a jittered exponential backoff delay.

    The un-jittered delay is ``initial_delay * 2**attempt`` capped at ``max_delay``;
    the result is that value scaled by a uniform random factor in ``[0, 1)``.

    :param attempt: Exponent of the backoff (0 yields at most ``initial_delay``).
    :param initial_delay: Base delay that is doubled for each attempt.
    :param max_delay: Upper bound on the un-jittered delay.
    :return: A delay in the range ``[0, min(initial_delay * 2**attempt, max_delay))``.
    """
    # 2**1023 is the largest power of two an int can hand to float arithmetic;
    # 2**1024 and above raise OverflowError ("int too large to convert to float").
    # Clamping the exponent keeps huge attempt counts from crashing: the product
    # then saturates (possibly to inf) and min() selects max_delay.
    exponent = min(attempt, 1023)
    jitter = random.random()  # noqa: S311
    return jitter * min(initial_delay * (2**exponent), max_delay)
|
||||
|
||||
|
||||
class _TokenBucket:
    """Lock-protected counter of retry tokens, used to rate limit retries.

    The bucket starts full. ``consume`` removes one whole token when one is
    available; ``deposit`` adds ``return_rate`` (plus one extra token when the
    caller reports a retry) without ever exceeding ``capacity``.
    """

    def __init__(
        self,
        capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY,
        return_rate: float = DEFAULT_RETRY_TOKEN_RETURN,
    ):
        self.lock = _async_create_lock()
        self.capacity = capacity
        self.return_rate = return_rate
        # DRIVERS-3240 will determine how full the bucket should start.
        self.tokens = capacity

    async def consume(self) -> bool:
        """Take one token if at least one is available; report whether it was taken."""
        async with self.lock:
            available = self.tokens >= 1
            if available:
                self.tokens -= 1
        return available

    async def deposit(self, retry: bool = False) -> None:
        """Return tokens to the bucket, clamped to the bucket's capacity.

        :param retry: When true, one whole token is added on top of ``return_rate``.
        """
        bonus = 1 if retry else 0
        async with self.lock:
            refilled = self.tokens + bonus + self.return_rate
            self.tokens = min(self.capacity, refilled)
|
||||
|
||||
|
||||
class _RetryPolicy:
    """Decide whether a failed operation may be retried, and compute its backoff.

    Backoff delays grow exponentially with random jitter. Each retry must also be
    paid for with a token from the token bucket, which prevents overwhelming the
    server during a prolonged outage or high load.
    """

    def __init__(
        self,
        token_bucket: _TokenBucket,
        attempts: int = _MAX_RETRIES,
        backoff_initial: float = _BACKOFF_INITIAL,
        backoff_max: float = _BACKOFF_MAX,
    ):
        self.token_bucket = token_bucket
        self.attempts = attempts
        self.backoff_initial = backoff_initial
        self.backoff_max = backoff_max

    async def record_success(self, retry: bool) -> None:
        """Record a successful operation by depositing back into the token bucket.

        :param retry: Whether the success came after at least one retry.
        """
        await self.token_bucket.deposit(retry)

    def backoff(self, attempt: int) -> float:
        """Return the backoff duration for the given attempt number.

        The attempt number is shifted down by one (never below zero) before it is
        used as the exponent of the backoff.
        """
        exponent = max(0, attempt - 1)
        return _backoff(exponent, self.backoff_initial, self.backoff_max)

    async def should_retry(self, attempt: int, delay: float) -> bool:
        """Return whether there is budget left to retry after waiting ``delay``.

        :param attempt: Number of the retry being considered.
        :param delay: Backoff the caller intends to wait before retrying.
        :return: False when the attempt limit is exceeded, when the delay would run
            past the active timeout deadline, or when the token bucket is empty.
        """
        if attempt > self.attempts:
            return False

        # If the delay would exceed the deadline, bail early before consuming a token.
        if _csot.get_timeout() and time.monotonic() + delay > _csot.get_deadline():
            return False

        # Check token bucket last since we only want to consume a token if we actually retry.
        # DRIVERS-3246 Improve diagnostics when the bucket is empty.
        # We could add info to the exception and log.
        return bool(await self.token_bucket.consume())
|
||||
|
||||
|
||||
def _retry_overload(func: F) -> F:
    """Decorate an async method so errors labelled "RetryableError" are retried.

    The decorated object's ``_retry_policy`` decides whether each retry is allowed.
    Errors that also carry the "SystemOverloadedError" label are retried only after
    sleeping for the policy's backoff; other retryable errors are retried at once.
    Errors without the "RetryableError" label, and errors the policy refuses to
    retry, propagate to the caller unchanged.
    """

    @functools.wraps(func)
    async def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
        policy = self._retry_policy
        failures = 0
        while True:
            try:
                result = await func(self, *args, **kwargs)
                await policy.record_success(retry=failures > 0)
                return result
            except PyMongoError as exc:
                if not exc.has_error_label("RetryableError"):
                    raise
                failures += 1
                # Only overload errors back off; everything else retries immediately.
                overloaded = exc.has_error_label("SystemOverloadedError")
                delay = policy.backoff(failures) if overloaded else 0
                if not await policy.should_retry(failures, delay):
                    raise

                # Implement exponential backoff on retry.
                if delay:
                    await asyncio.sleep(delay)

    return cast(F, inner)
|
||||
|
||||
|
||||
async def _getaddrinfo(
|
||||
host: Any, port: Any, **kwargs: Any
|
||||
) -> list[
|
||||
|
||||
@ -35,6 +35,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import time as time # noqa: PLC0414 # needed in sync version
|
||||
import warnings
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
@ -67,6 +68,11 @@ from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterCh
|
||||
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
|
||||
from pymongo.asynchronous.client_session import _EmptyServerSession
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.helpers import (
|
||||
_retry_overload,
|
||||
_RetryPolicy,
|
||||
_TokenBucket,
|
||||
)
|
||||
from pymongo.asynchronous.settings import TopologySettings
|
||||
from pymongo.asynchronous.topology import Topology, _ErrorContext
|
||||
from pymongo.client_options import ClientOptions
|
||||
@ -773,6 +779,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self._timeout: float | None = None
|
||||
self._topology_settings: TopologySettings = None # type: ignore[assignment]
|
||||
self._event_listeners: _EventListeners | None = None
|
||||
self._retry_policy = _RetryPolicy(_TokenBucket())
|
||||
|
||||
# _pool_class, _monitor_class, and _condition_class are for deep
|
||||
# customization of PyMongo, e.g. Motor.
|
||||
@ -2398,6 +2405,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
return [doc["name"] async for doc in res]
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
async def drop_database(
|
||||
self,
|
||||
name_or_database: Union[str, database.AsyncDatabase[_DocumentTypeArg]],
|
||||
@ -2735,9 +2743,10 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
):
|
||||
self._last_error: Optional[Exception] = None
|
||||
self._retrying = False
|
||||
self._always_retryable = False
|
||||
self._multiple_retries = _csot.get_timeout() is not None
|
||||
self._client = mongo_client
|
||||
|
||||
self._retry_policy = mongo_client._retry_policy
|
||||
self._func = func
|
||||
self._bulk = bulk
|
||||
self._session = session
|
||||
@ -2772,7 +2781,9 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
while True:
|
||||
self._check_last_error(check_csot=True)
|
||||
try:
|
||||
return await self._read() if self._is_read else await self._write()
|
||||
res = await self._read() if self._is_read else await self._write()
|
||||
await self._retry_policy.record_success(self._attempt_number > 0)
|
||||
return res
|
||||
except ServerSelectionTimeoutError:
|
||||
# The application may think the write was never attempted
|
||||
# if we raise ServerSelectionTimeoutError on the retry
|
||||
@ -2783,14 +2794,22 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
# most likely be a waste of time.
|
||||
raise
|
||||
except PyMongoError as exc:
|
||||
always_retryable = False
|
||||
overloaded = False
|
||||
exc_to_check = exc
|
||||
# Execute specialized catch on read
|
||||
if self._is_read:
|
||||
if isinstance(exc, (ConnectionFailure, OperationFailure)):
|
||||
# ConnectionFailures do not supply a code property
|
||||
exc_code = getattr(exc, "code", None)
|
||||
if self._is_not_eligible_for_retry() or (
|
||||
isinstance(exc, OperationFailure)
|
||||
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
|
||||
always_retryable = exc.has_error_label("RetryableError")
|
||||
overloaded = exc.has_error_label("SystemOverloadedError")
|
||||
if not always_retryable and (
|
||||
self._is_not_eligible_for_retry()
|
||||
or (
|
||||
isinstance(exc, OperationFailure)
|
||||
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
|
||||
)
|
||||
):
|
||||
raise
|
||||
self._retrying = True
|
||||
@ -2801,19 +2820,22 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
|
||||
# Specialized catch on write operation
|
||||
if not self._is_read:
|
||||
if not self._retryable:
|
||||
if isinstance(exc, ClientBulkWriteException) and isinstance(
|
||||
exc.error, PyMongoError
|
||||
):
|
||||
exc_to_check = exc.error
|
||||
retryable_write_label = exc_to_check.has_error_label("RetryableWriteError")
|
||||
always_retryable = exc_to_check.has_error_label("RetryableError")
|
||||
overloaded = exc_to_check.has_error_label("SystemOverloadedError")
|
||||
if not self._retryable and not always_retryable:
|
||||
raise
|
||||
if isinstance(exc, ClientBulkWriteException) and exc.error:
|
||||
retryable_write_error_exc = isinstance(
|
||||
exc.error, PyMongoError
|
||||
) and exc.error.has_error_label("RetryableWriteError")
|
||||
else:
|
||||
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
|
||||
if retryable_write_error_exc:
|
||||
if retryable_write_label or always_retryable:
|
||||
assert self._session
|
||||
await self._session._unpin()
|
||||
if not retryable_write_error_exc or self._is_not_eligible_for_retry():
|
||||
if exc.has_error_label("NoWritesPerformed") and self._last_error:
|
||||
if not always_retryable and (
|
||||
not retryable_write_label or self._is_not_eligible_for_retry()
|
||||
):
|
||||
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
|
||||
raise self._last_error from exc
|
||||
else:
|
||||
raise
|
||||
@ -2822,7 +2844,7 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
self._bulk.retrying = True
|
||||
else:
|
||||
self._retrying = True
|
||||
if not exc.has_error_label("NoWritesPerformed"):
|
||||
if not exc_to_check.has_error_label("NoWritesPerformed"):
|
||||
self._last_error = exc
|
||||
if self._last_error is None:
|
||||
self._last_error = exc
|
||||
@ -2830,6 +2852,17 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
|
||||
self._deprioritized_servers.append(self._server)
|
||||
|
||||
self._always_retryable = always_retryable
|
||||
if always_retryable:
|
||||
delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0
|
||||
if not await self._retry_policy.should_retry(self._attempt_number, delay):
|
||||
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
|
||||
raise self._last_error from exc
|
||||
else:
|
||||
raise
|
||||
if overloaded:
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
def _is_not_eligible_for_retry(self) -> bool:
|
||||
"""Checks if the exchange is not eligible for retry"""
|
||||
return not self._retryable or (self._is_retrying() and not self._multiple_retries)
|
||||
@ -2891,7 +2924,7 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
and conn.supports_sessions
|
||||
)
|
||||
is_mongos = conn.is_mongos
|
||||
if not sessions_supported:
|
||||
if not self._always_retryable and not sessions_supported:
|
||||
# A retry is not possible because this server does
|
||||
# not support sessions raise the last error.
|
||||
self._check_last_error()
|
||||
@ -2923,7 +2956,7 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
conn,
|
||||
read_pref,
|
||||
):
|
||||
if self._retrying and not self._retryable:
|
||||
if self._retrying and not self._retryable and not self._always_retryable:
|
||||
self._check_last_error()
|
||||
if self._retrying:
|
||||
_debug_log(
|
||||
|
||||
@ -136,6 +136,7 @@ Classes
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Mapping as _Mapping
|
||||
@ -470,6 +471,8 @@ _UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( #
|
||||
# This limit is non-configurable and was chosen to be twice the 60 second
|
||||
# default value of MongoDB's `transactionLifetimeLimitSeconds` parameter.
|
||||
_WITH_TRANSACTION_RETRY_TIME_LIMIT = 120
|
||||
_BACKOFF_MAX = 1
|
||||
_BACKOFF_INITIAL = 0.050 # 50ms initial backoff
|
||||
|
||||
|
||||
def _within_time_limit(start_time: float) -> bool:
|
||||
@ -699,7 +702,13 @@ class ClientSession:
|
||||
https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
retry = 0
|
||||
while True:
|
||||
if retry: # Implement exponential backoff on retry.
|
||||
jitter = random.random() # noqa: S311
|
||||
backoff = jitter * min(_BACKOFF_INITIAL * (2**retry), _BACKOFF_MAX)
|
||||
time.sleep(backoff)
|
||||
retry += 1
|
||||
self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms)
|
||||
try:
|
||||
ret = callback(self)
|
||||
|
||||
@ -89,6 +89,7 @@ from pymongo.synchronous.cursor import (
|
||||
Cursor,
|
||||
RawBatchCursor,
|
||||
)
|
||||
from pymongo.synchronous.helpers import _retry_overload
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
|
||||
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean
|
||||
|
||||
@ -255,6 +256,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
unicode_decode_error_handler="replace", document_class=dict
|
||||
)
|
||||
self._timeout = database.client.options.timeout
|
||||
self._retry_policy = database.client._retry_policy
|
||||
|
||||
if create or kwargs:
|
||||
if _IS_SYNC:
|
||||
@ -2224,6 +2226,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
return self._create_indexes(indexes, session, **kwargs)
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
def _create_indexes(
|
||||
self, indexes: Sequence[IndexModel], session: Optional[ClientSession], **kwargs: Any
|
||||
) -> list[str]:
|
||||
@ -2419,7 +2422,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
kwargs["comment"] = comment
|
||||
self._drop_index("*", session=session, **kwargs)
|
||||
|
||||
@_csot.apply
|
||||
def drop_index(
|
||||
self,
|
||||
index_or_name: _IndexKeyHint,
|
||||
@ -2469,6 +2471,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
self._drop_index(index_or_name, session, comment, **kwargs)
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
def _drop_index(
|
||||
self,
|
||||
index_or_name: _IndexKeyHint,
|
||||
@ -3072,6 +3075,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
def rename(
|
||||
self,
|
||||
new_name: str,
|
||||
|
||||
@ -43,6 +43,7 @@ from pymongo.synchronous.aggregation import _DatabaseAggregationCommand
|
||||
from pymongo.synchronous.change_stream import DatabaseChangeStream
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.command_cursor import CommandCursor
|
||||
from pymongo.synchronous.helpers import _retry_overload
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -135,6 +136,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
self._name = name
|
||||
self._client: MongoClient[_DocumentType] = client
|
||||
self._timeout = client.options.timeout
|
||||
self._retry_policy = client._retry_policy
|
||||
|
||||
@property
|
||||
def client(self) -> MongoClient[_DocumentType]:
|
||||
@ -477,6 +479,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
return change_stream
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
@ -816,6 +819,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
...
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
def command(
|
||||
self,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
@ -945,6 +949,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
def cursor_command(
|
||||
self,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
@ -1257,6 +1262,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
def drop_collection(
|
||||
self,
|
||||
name_or_collection: Union[str, Collection[_DocumentTypeArg]],
|
||||
|
||||
@ -17,8 +17,11 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
import functools
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
import time as time # noqa: PLC0414 # needed in sync version
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -26,10 +29,13 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pymongo import _csot
|
||||
from pymongo.errors import (
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
)
|
||||
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
|
||||
from pymongo.lock import _create_lock
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
@ -38,6 +44,7 @@ F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def _handle_reauth(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
def inner(*args: Any, **kwargs: Any) -> Any:
|
||||
no_reauth = kwargs.pop("no_reauth", False)
|
||||
from pymongo.message import _BulkWriteContext
|
||||
@ -70,6 +77,123 @@ def _handle_reauth(func: F) -> F:
|
||||
return cast(F, inner)
|
||||
|
||||
|
||||
# Defaults for overload retries: maximum retry attempts and the backoff window
# (in seconds, since callers pass the computed delay to ``time.sleep``).
_MAX_RETRIES = 5
_BACKOFF_INITIAL = 0.1
_BACKOFF_MAX = 10
# DRIVERS-3240 will determine these defaults.
DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0
DEFAULT_RETRY_TOKEN_RETURN = 0.1


def _backoff(
    attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
) -> float:
    """Return a jittered exponential backoff delay.

    The un-jittered delay is ``initial_delay * 2**attempt`` capped at ``max_delay``;
    the result is that value scaled by a uniform random factor in ``[0, 1)``.

    :param attempt: Exponent of the backoff (0 yields at most ``initial_delay``).
    :param initial_delay: Base delay that is doubled for each attempt.
    :param max_delay: Upper bound on the un-jittered delay.
    :return: A delay in the range ``[0, min(initial_delay * 2**attempt, max_delay))``.
    """
    # 2**1023 is the largest power of two an int can hand to float arithmetic;
    # 2**1024 and above raise OverflowError ("int too large to convert to float").
    # Clamping the exponent keeps huge attempt counts from crashing: the product
    # then saturates (possibly to inf) and min() selects max_delay.
    exponent = min(attempt, 1023)
    jitter = random.random()  # noqa: S311
    return jitter * min(initial_delay * (2**exponent), max_delay)
|
||||
|
||||
|
||||
class _TokenBucket:
    """Lock-protected counter of retry tokens, used to rate limit retries.

    The bucket starts full. ``consume`` removes one whole token when one is
    available; ``deposit`` adds ``return_rate`` (plus one extra token when the
    caller reports a retry) without ever exceeding ``capacity``.
    """

    def __init__(
        self,
        capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY,
        return_rate: float = DEFAULT_RETRY_TOKEN_RETURN,
    ):
        self.lock = _create_lock()
        self.capacity = capacity
        self.return_rate = return_rate
        # DRIVERS-3240 will determine how full the bucket should start.
        self.tokens = capacity

    def consume(self) -> bool:
        """Take one token if at least one is available; report whether it was taken."""
        with self.lock:
            available = self.tokens >= 1
            if available:
                self.tokens -= 1
        return available

    def deposit(self, retry: bool = False) -> None:
        """Return tokens to the bucket, clamped to the bucket's capacity.

        :param retry: When true, one whole token is added on top of ``return_rate``.
        """
        bonus = 1 if retry else 0
        with self.lock:
            refilled = self.tokens + bonus + self.return_rate
            self.tokens = min(self.capacity, refilled)
|
||||
|
||||
|
||||
class _RetryPolicy:
    """Decide whether a failed operation may be retried, and compute its backoff.

    Backoff delays grow exponentially with random jitter. Each retry must also be
    paid for with a token from the token bucket, which prevents overwhelming the
    server during a prolonged outage or high load.
    """

    def __init__(
        self,
        token_bucket: _TokenBucket,
        attempts: int = _MAX_RETRIES,
        backoff_initial: float = _BACKOFF_INITIAL,
        backoff_max: float = _BACKOFF_MAX,
    ):
        self.token_bucket = token_bucket
        self.attempts = attempts
        self.backoff_initial = backoff_initial
        self.backoff_max = backoff_max

    def record_success(self, retry: bool) -> None:
        """Record a successful operation by depositing back into the token bucket.

        :param retry: Whether the success came after at least one retry.
        """
        self.token_bucket.deposit(retry)

    def backoff(self, attempt: int) -> float:
        """Return the backoff duration for the given attempt number.

        The attempt number is shifted down by one (never below zero) before it is
        used as the exponent of the backoff.
        """
        exponent = max(0, attempt - 1)
        return _backoff(exponent, self.backoff_initial, self.backoff_max)

    def should_retry(self, attempt: int, delay: float) -> bool:
        """Return whether there is budget left to retry after waiting ``delay``.

        :param attempt: Number of the retry being considered.
        :param delay: Backoff the caller intends to wait before retrying.
        :return: False when the attempt limit is exceeded, when the delay would run
            past the active timeout deadline, or when the token bucket is empty.
        """
        if attempt > self.attempts:
            return False

        # If the delay would exceed the deadline, bail early before consuming a token.
        if _csot.get_timeout() and time.monotonic() + delay > _csot.get_deadline():
            return False

        # Check token bucket last since we only want to consume a token if we actually retry.
        # DRIVERS-3246 Improve diagnostics when the bucket is empty.
        # We could add info to the exception and log.
        return bool(self.token_bucket.consume())
|
||||
|
||||
|
||||
def _retry_overload(func: F) -> F:
    """Decorate a method so errors labelled "RetryableError" are retried.

    The decorated object's ``_retry_policy`` decides whether each retry is allowed.
    Errors that also carry the "SystemOverloadedError" label are retried only after
    sleeping for the policy's backoff; other retryable errors are retried at once.
    Errors without the "RetryableError" label, and errors the policy refuses to
    retry, propagate to the caller unchanged.
    """

    @functools.wraps(func)
    def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
        policy = self._retry_policy
        failures = 0
        while True:
            try:
                result = func(self, *args, **kwargs)
                policy.record_success(retry=failures > 0)
                return result
            except PyMongoError as exc:
                if not exc.has_error_label("RetryableError"):
                    raise
                failures += 1
                # Only overload errors back off; everything else retries immediately.
                overloaded = exc.has_error_label("SystemOverloadedError")
                delay = policy.backoff(failures) if overloaded else 0
                if not policy.should_retry(failures, delay):
                    raise

                # Implement exponential backoff on retry.
                if delay:
                    time.sleep(delay)

    return cast(F, inner)
|
||||
|
||||
|
||||
def _getaddrinfo(
|
||||
host: Any, port: Any, **kwargs: Any
|
||||
) -> list[
|
||||
|
||||
@ -35,6 +35,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import time as time # noqa: PLC0414 # needed in sync version
|
||||
import warnings
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
@ -110,6 +111,11 @@ from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
|
||||
from pymongo.synchronous.client_bulk import _ClientBulk
|
||||
from pymongo.synchronous.client_session import _EmptyServerSession
|
||||
from pymongo.synchronous.command_cursor import CommandCursor
|
||||
from pymongo.synchronous.helpers import (
|
||||
_retry_overload,
|
||||
_RetryPolicy,
|
||||
_TokenBucket,
|
||||
)
|
||||
from pymongo.synchronous.settings import TopologySettings
|
||||
from pymongo.synchronous.topology import Topology, _ErrorContext
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
|
||||
@ -773,6 +779,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self._timeout: float | None = None
|
||||
self._topology_settings: TopologySettings = None # type: ignore[assignment]
|
||||
self._event_listeners: _EventListeners | None = None
|
||||
self._retry_policy = _RetryPolicy(_TokenBucket())
|
||||
|
||||
# _pool_class, _monitor_class, and _condition_class are for deep
|
||||
# customization of PyMongo, e.g. Motor.
|
||||
@ -2388,6 +2395,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
return [doc["name"] for doc in res]
|
||||
|
||||
@_csot.apply
|
||||
@_retry_overload
|
||||
def drop_database(
|
||||
self,
|
||||
name_or_database: Union[str, database.Database[_DocumentTypeArg]],
|
||||
@ -2725,9 +2733,10 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
):
|
||||
self._last_error: Optional[Exception] = None
|
||||
self._retrying = False
|
||||
self._always_retryable = False
|
||||
self._multiple_retries = _csot.get_timeout() is not None
|
||||
self._client = mongo_client
|
||||
|
||||
self._retry_policy = mongo_client._retry_policy
|
||||
self._func = func
|
||||
self._bulk = bulk
|
||||
self._session = session
|
||||
@ -2762,7 +2771,9 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
while True:
|
||||
self._check_last_error(check_csot=True)
|
||||
try:
|
||||
return self._read() if self._is_read else self._write()
|
||||
res = self._read() if self._is_read else self._write()
|
||||
self._retry_policy.record_success(self._attempt_number > 0)
|
||||
return res
|
||||
except ServerSelectionTimeoutError:
|
||||
# The application may think the write was never attempted
|
||||
# if we raise ServerSelectionTimeoutError on the retry
|
||||
@ -2773,14 +2784,22 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
# most likely be a waste of time.
|
||||
raise
|
||||
except PyMongoError as exc:
|
||||
always_retryable = False
|
||||
overloaded = False
|
||||
exc_to_check = exc
|
||||
# Execute specialized catch on read
|
||||
if self._is_read:
|
||||
if isinstance(exc, (ConnectionFailure, OperationFailure)):
|
||||
# ConnectionFailures do not supply a code property
|
||||
exc_code = getattr(exc, "code", None)
|
||||
if self._is_not_eligible_for_retry() or (
|
||||
isinstance(exc, OperationFailure)
|
||||
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
|
||||
always_retryable = exc.has_error_label("RetryableError")
|
||||
overloaded = exc.has_error_label("SystemOverloadedError")
|
||||
if not always_retryable and (
|
||||
self._is_not_eligible_for_retry()
|
||||
or (
|
||||
isinstance(exc, OperationFailure)
|
||||
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
|
||||
)
|
||||
):
|
||||
raise
|
||||
self._retrying = True
|
||||
@ -2791,19 +2810,22 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
|
||||
# Specialized catch on write operation
|
||||
if not self._is_read:
|
||||
if not self._retryable:
|
||||
if isinstance(exc, ClientBulkWriteException) and isinstance(
|
||||
exc.error, PyMongoError
|
||||
):
|
||||
exc_to_check = exc.error
|
||||
retryable_write_label = exc_to_check.has_error_label("RetryableWriteError")
|
||||
always_retryable = exc_to_check.has_error_label("RetryableError")
|
||||
overloaded = exc_to_check.has_error_label("SystemOverloadedError")
|
||||
if not self._retryable and not always_retryable:
|
||||
raise
|
||||
if isinstance(exc, ClientBulkWriteException) and exc.error:
|
||||
retryable_write_error_exc = isinstance(
|
||||
exc.error, PyMongoError
|
||||
) and exc.error.has_error_label("RetryableWriteError")
|
||||
else:
|
||||
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
|
||||
if retryable_write_error_exc:
|
||||
if retryable_write_label or always_retryable:
|
||||
assert self._session
|
||||
self._session._unpin()
|
||||
if not retryable_write_error_exc or self._is_not_eligible_for_retry():
|
||||
if exc.has_error_label("NoWritesPerformed") and self._last_error:
|
||||
if not always_retryable and (
|
||||
not retryable_write_label or self._is_not_eligible_for_retry()
|
||||
):
|
||||
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
|
||||
raise self._last_error from exc
|
||||
else:
|
||||
raise
|
||||
@ -2812,7 +2834,7 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
self._bulk.retrying = True
|
||||
else:
|
||||
self._retrying = True
|
||||
if not exc.has_error_label("NoWritesPerformed"):
|
||||
if not exc_to_check.has_error_label("NoWritesPerformed"):
|
||||
self._last_error = exc
|
||||
if self._last_error is None:
|
||||
self._last_error = exc
|
||||
@ -2820,6 +2842,17 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
|
||||
self._deprioritized_servers.append(self._server)
|
||||
|
||||
self._always_retryable = always_retryable
|
||||
if always_retryable:
|
||||
delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0
|
||||
if not self._retry_policy.should_retry(self._attempt_number, delay):
|
||||
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
|
||||
raise self._last_error from exc
|
||||
else:
|
||||
raise
|
||||
if overloaded:
|
||||
time.sleep(delay)
|
||||
|
||||
def _is_not_eligible_for_retry(self) -> bool:
|
||||
"""Checks if the exchange is not eligible for retry"""
|
||||
return not self._retryable or (self._is_retrying() and not self._multiple_retries)
|
||||
@ -2881,7 +2914,7 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
and conn.supports_sessions
|
||||
)
|
||||
is_mongos = conn.is_mongos
|
||||
if not sessions_supported:
|
||||
if not self._always_retryable and not sessions_supported:
|
||||
# A retry is not possible because this server does
|
||||
# not support sessions; raise the last error.
|
||||
self._check_last_error()
|
||||
@ -2913,7 +2946,7 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
conn,
|
||||
read_pref,
|
||||
):
|
||||
if self._retrying and not self._retryable:
|
||||
if self._retrying and not self._retryable and not self._always_retryable:
|
||||
self._check_last_error()
|
||||
if self._retrying:
|
||||
_debug_log(
|
||||
|
||||
230
test/asynchronous/test_backpressure.py
Normal file
230
test/asynchronous/test_backpressure.py
Normal file
@ -0,0 +1,230 @@
|
||||
# Copyright 2025-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test Client Backpressure spec."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import pymongo
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import (
|
||||
AsyncIntegrationTest,
|
||||
AsyncPyMongoTestCase,
|
||||
async_client_context,
|
||||
unittest,
|
||||
)
|
||||
|
||||
from pymongo.asynchronous import helpers
|
||||
from pymongo.asynchronous.helpers import _MAX_RETRIES, _RetryPolicy, _TokenBucket
|
||||
from pymongo.errors import PyMongoError
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Mock a system overload error.
|
||||
mock_overload_error = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 1},
|
||||
"data": {
|
||||
"failCommands": ["find", "insert", "update"],
|
||||
"errorCode": 462, # IngressRequestRateLimitExceeded
|
||||
"errorLabels": ["RetryableError"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestBackpressure(AsyncIntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
|
||||
@async_client_context.require_failCommand_appName
|
||||
async def test_retry_overload_error_command(self):
|
||||
await self.db.t.insert_one({"x": 1})
|
||||
|
||||
# Ensure command is retried on overload error.
|
||||
fail_many = mock_overload_error.copy()
|
||||
fail_many["mode"] = {"times": _MAX_RETRIES}
|
||||
async with self.fail_point(fail_many):
|
||||
await self.db.command("find", "t")
|
||||
|
||||
# Ensure command stops retrying after _MAX_RETRIES.
|
||||
fail_too_many = mock_overload_error.copy()
|
||||
fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
|
||||
async with self.fail_point(fail_too_many):
|
||||
with self.assertRaises(PyMongoError) as error:
|
||||
await self.db.command("find", "t")
|
||||
|
||||
self.assertIn("RetryableError", str(error.exception))
|
||||
|
||||
@async_client_context.require_failCommand_appName
|
||||
async def test_retry_overload_error_find(self):
|
||||
await self.db.t.insert_one({"x": 1})
|
||||
|
||||
# Ensure command is retried on overload error.
|
||||
fail_many = mock_overload_error.copy()
|
||||
fail_many["mode"] = {"times": _MAX_RETRIES}
|
||||
async with self.fail_point(fail_many):
|
||||
await self.db.t.find_one()
|
||||
|
||||
# Ensure command stops retrying after _MAX_RETRIES.
|
||||
fail_too_many = mock_overload_error.copy()
|
||||
fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
|
||||
async with self.fail_point(fail_too_many):
|
||||
with self.assertRaises(PyMongoError) as error:
|
||||
await self.db.t.find_one()
|
||||
|
||||
self.assertIn("RetryableError", str(error.exception))
|
||||
|
||||
@async_client_context.require_failCommand_appName
async def test_retry_overload_error_insert_one(self):
    """Check that insert_one is retried on a "RetryableError"-labelled failure.

    The fail point makes the server reject "insert" commands; the driver is
    expected to retry up to _MAX_RETRIES times and then surface the error.
    """
    await self.db.t.insert_one({"x": 1})

    # Ensure command is retried on overload error.
    fail_many = mock_overload_error.copy()
    fail_many["mode"] = {"times": _MAX_RETRIES}
    async with self.fail_point(fail_many):
        # Exercise insert_one itself (the original called find_one here,
        # which left the insert retry path untested).
        await self.db.t.insert_one({"x": 1})

    # Ensure command stops retrying after _MAX_RETRIES.
    fail_too_many = mock_overload_error.copy()
    fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
    async with self.fail_point(fail_too_many):
        with self.assertRaises(PyMongoError) as error:
            await self.db.t.insert_one({"x": 1})

    self.assertIn("RetryableError", str(error.exception))
|
||||
|
||||
@async_client_context.require_failCommand_appName
|
||||
async def test_retry_overload_error_update_many(self):
|
||||
# Even though update_many is not a retryable write operation, it will
|
||||
# still be retried via the "RetryableError" error label.
|
||||
await self.db.t.insert_one({"x": 1})
|
||||
|
||||
# Ensure command is retried on overload error.
|
||||
fail_many = mock_overload_error.copy()
|
||||
fail_many["mode"] = {"times": _MAX_RETRIES}
|
||||
async with self.fail_point(fail_many):
|
||||
await self.db.t.update_many({}, {"$set": {"x": 2}})
|
||||
|
||||
# Ensure command stops retrying after _MAX_RETRIES.
|
||||
fail_too_many = mock_overload_error.copy()
|
||||
fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
|
||||
async with self.fail_point(fail_too_many):
|
||||
with self.assertRaises(PyMongoError) as error:
|
||||
await self.db.t.update_many({}, {"$set": {"x": 2}})
|
||||
|
||||
self.assertIn("RetryableError", str(error.exception))
|
||||
|
||||
@async_client_context.require_failCommand_appName
|
||||
async def test_retry_overload_error_getMore(self):
|
||||
coll = self.db.t
|
||||
await coll.insert_many([{"x": 1} for _ in range(10)])
|
||||
|
||||
# Ensure command is retried on overload error.
|
||||
fail_many = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": _MAX_RETRIES},
|
||||
"data": {
|
||||
"failCommands": ["getMore"],
|
||||
"errorCode": 462, # IngressRequestRateLimitExceeded
|
||||
"errorLabels": ["RetryableError"],
|
||||
},
|
||||
}
|
||||
cursor = coll.find(batch_size=2)
|
||||
await cursor.next()
|
||||
async with self.fail_point(fail_many):
|
||||
await cursor.to_list()
|
||||
|
||||
# Ensure command stops retrying after _MAX_RETRIES.
|
||||
fail_too_many = fail_many.copy()
|
||||
fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
|
||||
cursor = coll.find(batch_size=2)
|
||||
await cursor.next()
|
||||
async with self.fail_point(fail_too_many):
|
||||
with self.assertRaises(PyMongoError) as error:
|
||||
await cursor.to_list()
|
||||
|
||||
self.assertIn("RetryableError", str(error.exception))
|
||||
|
||||
@async_client_context.require_failCommand_appName
|
||||
async def test_limit_retry_command(self):
|
||||
client = await self.async_rs_or_single_client()
|
||||
client._retry_policy.token_bucket.tokens = 1
|
||||
db = client.pymongo_test
|
||||
await db.t.insert_one({"x": 1})
|
||||
|
||||
# Ensure command is retried once on overload error.
|
||||
fail_many = mock_overload_error.copy()
|
||||
fail_many["mode"] = {"times": 1}
|
||||
async with self.fail_point(fail_many):
|
||||
await db.command("find", "t")
|
||||
|
||||
# Ensure command stops retrying when there are no tokens left.
|
||||
fail_too_many = mock_overload_error.copy()
|
||||
fail_too_many["mode"] = {"times": 2}
|
||||
async with self.fail_point(fail_too_many):
|
||||
with self.assertRaises(PyMongoError) as error:
|
||||
await db.command("find", "t")
|
||||
|
||||
self.assertIn("RetryableError", str(error.exception))
|
||||
|
||||
|
||||
class TestRetryPolicy(AsyncPyMongoTestCase):
|
||||
async def test_retry_policy(self):
|
||||
capacity = 10
|
||||
retry_policy = _RetryPolicy(_TokenBucket(capacity=capacity))
|
||||
self.assertEqual(retry_policy.attempts, helpers._MAX_RETRIES)
|
||||
self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL)
|
||||
self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX)
|
||||
for i in range(1, helpers._MAX_RETRIES + 1):
|
||||
self.assertTrue(await retry_policy.should_retry(i, 0))
|
||||
self.assertFalse(await retry_policy.should_retry(helpers._MAX_RETRIES + 1, 0))
|
||||
for i in range(capacity - helpers._MAX_RETRIES):
|
||||
self.assertTrue(await retry_policy.should_retry(1, 0))
|
||||
# No tokens left, should not retry.
|
||||
self.assertFalse(await retry_policy.should_retry(1, 0))
|
||||
self.assertEqual(retry_policy.token_bucket.tokens, 0)
|
||||
|
||||
# record_success should generate tokens.
|
||||
for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)):
|
||||
await retry_policy.record_success(retry=False)
|
||||
self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2)
|
||||
for i in range(2):
|
||||
self.assertTrue(await retry_policy.should_retry(1, 0))
|
||||
self.assertFalse(await retry_policy.should_retry(1, 0))
|
||||
|
||||
# Recording a successful retry should return 1 additional token.
|
||||
await retry_policy.record_success(retry=True)
|
||||
self.assertAlmostEqual(
|
||||
retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN
|
||||
)
|
||||
self.assertTrue(await retry_policy.should_retry(1, 0))
|
||||
self.assertFalse(await retry_policy.should_retry(1, 0))
|
||||
self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN)
|
||||
|
||||
async def test_retry_policy_csot(self):
|
||||
retry_policy = _RetryPolicy(_TokenBucket())
|
||||
self.assertTrue(await retry_policy.should_retry(1, 0.5))
|
||||
with pymongo.timeout(0.5):
|
||||
self.assertTrue(await retry_policy.should_retry(1, 0))
|
||||
self.assertTrue(await retry_policy.should_retry(1, 0.1))
|
||||
# Would exceed the timeout, should not retry.
|
||||
self.assertFalse(await retry_policy.should_retry(1, 1.0))
|
||||
self.assertTrue(await retry_policy.should_retry(1, 1.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
230
test/test_backpressure.py
Normal file
230
test/test_backpressure.py
Normal file
@ -0,0 +1,230 @@
|
||||
# Copyright 2025-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test Client Backpressure spec."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import pymongo
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import (
|
||||
IntegrationTest,
|
||||
PyMongoTestCase,
|
||||
client_context,
|
||||
unittest,
|
||||
)
|
||||
|
||||
from pymongo.errors import PyMongoError
|
||||
from pymongo.synchronous import helpers
|
||||
from pymongo.synchronous.helpers import _MAX_RETRIES, _RetryPolicy, _TokenBucket
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Mock a system overload error.
|
||||
mock_overload_error = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 1},
|
||||
"data": {
|
||||
"failCommands": ["find", "insert", "update"],
|
||||
"errorCode": 462, # IngressRequestRateLimitExceeded
|
||||
"errorLabels": ["RetryableError"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestBackpressure(IntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
|
||||
@client_context.require_failCommand_appName
|
||||
def test_retry_overload_error_command(self):
|
||||
self.db.t.insert_one({"x": 1})
|
||||
|
||||
# Ensure command is retried on overload error.
|
||||
fail_many = mock_overload_error.copy()
|
||||
fail_many["mode"] = {"times": _MAX_RETRIES}
|
||||
with self.fail_point(fail_many):
|
||||
self.db.command("find", "t")
|
||||
|
||||
# Ensure command stops retrying after _MAX_RETRIES.
|
||||
fail_too_many = mock_overload_error.copy()
|
||||
fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
|
||||
with self.fail_point(fail_too_many):
|
||||
with self.assertRaises(PyMongoError) as error:
|
||||
self.db.command("find", "t")
|
||||
|
||||
self.assertIn("RetryableError", str(error.exception))
|
||||
|
||||
@client_context.require_failCommand_appName
|
||||
def test_retry_overload_error_find(self):
|
||||
self.db.t.insert_one({"x": 1})
|
||||
|
||||
# Ensure command is retried on overload error.
|
||||
fail_many = mock_overload_error.copy()
|
||||
fail_many["mode"] = {"times": _MAX_RETRIES}
|
||||
with self.fail_point(fail_many):
|
||||
self.db.t.find_one()
|
||||
|
||||
# Ensure command stops retrying after _MAX_RETRIES.
|
||||
fail_too_many = mock_overload_error.copy()
|
||||
fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
|
||||
with self.fail_point(fail_too_many):
|
||||
with self.assertRaises(PyMongoError) as error:
|
||||
self.db.t.find_one()
|
||||
|
||||
self.assertIn("RetryableError", str(error.exception))
|
||||
|
||||
@client_context.require_failCommand_appName
def test_retry_overload_error_insert_one(self):
    """Check that insert_one is retried on a "RetryableError"-labelled failure.

    The fail point makes the server reject "insert" commands; the driver is
    expected to retry up to _MAX_RETRIES times and then surface the error.
    """
    self.db.t.insert_one({"x": 1})

    # Ensure command is retried on overload error.
    fail_many = mock_overload_error.copy()
    fail_many["mode"] = {"times": _MAX_RETRIES}
    with self.fail_point(fail_many):
        # Exercise insert_one itself (the original called find_one here,
        # which left the insert retry path untested).
        self.db.t.insert_one({"x": 1})

    # Ensure command stops retrying after _MAX_RETRIES.
    fail_too_many = mock_overload_error.copy()
    fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
    with self.fail_point(fail_too_many):
        with self.assertRaises(PyMongoError) as error:
            self.db.t.insert_one({"x": 1})

    self.assertIn("RetryableError", str(error.exception))
|
||||
|
||||
@client_context.require_failCommand_appName
|
||||
def test_retry_overload_error_update_many(self):
|
||||
# Even though update_many is not a retryable write operation, it will
|
||||
# still be retried via the "RetryableError" error label.
|
||||
self.db.t.insert_one({"x": 1})
|
||||
|
||||
# Ensure command is retried on overload error.
|
||||
fail_many = mock_overload_error.copy()
|
||||
fail_many["mode"] = {"times": _MAX_RETRIES}
|
||||
with self.fail_point(fail_many):
|
||||
self.db.t.update_many({}, {"$set": {"x": 2}})
|
||||
|
||||
# Ensure command stops retrying after _MAX_RETRIES.
|
||||
fail_too_many = mock_overload_error.copy()
|
||||
fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
|
||||
with self.fail_point(fail_too_many):
|
||||
with self.assertRaises(PyMongoError) as error:
|
||||
self.db.t.update_many({}, {"$set": {"x": 2}})
|
||||
|
||||
self.assertIn("RetryableError", str(error.exception))
|
||||
|
||||
@client_context.require_failCommand_appName
|
||||
def test_retry_overload_error_getMore(self):
|
||||
coll = self.db.t
|
||||
coll.insert_many([{"x": 1} for _ in range(10)])
|
||||
|
||||
# Ensure command is retried on overload error.
|
||||
fail_many = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": _MAX_RETRIES},
|
||||
"data": {
|
||||
"failCommands": ["getMore"],
|
||||
"errorCode": 462, # IngressRequestRateLimitExceeded
|
||||
"errorLabels": ["RetryableError"],
|
||||
},
|
||||
}
|
||||
cursor = coll.find(batch_size=2)
|
||||
cursor.next()
|
||||
with self.fail_point(fail_many):
|
||||
cursor.to_list()
|
||||
|
||||
# Ensure command stops retrying after _MAX_RETRIES.
|
||||
fail_too_many = fail_many.copy()
|
||||
fail_too_many["mode"] = {"times": _MAX_RETRIES + 1}
|
||||
cursor = coll.find(batch_size=2)
|
||||
cursor.next()
|
||||
with self.fail_point(fail_too_many):
|
||||
with self.assertRaises(PyMongoError) as error:
|
||||
cursor.to_list()
|
||||
|
||||
self.assertIn("RetryableError", str(error.exception))
|
||||
|
||||
@client_context.require_failCommand_appName
|
||||
def test_limit_retry_command(self):
|
||||
client = self.rs_or_single_client()
|
||||
client._retry_policy.token_bucket.tokens = 1
|
||||
db = client.pymongo_test
|
||||
db.t.insert_one({"x": 1})
|
||||
|
||||
# Ensure command is retried once on overload error.
|
||||
fail_many = mock_overload_error.copy()
|
||||
fail_many["mode"] = {"times": 1}
|
||||
with self.fail_point(fail_many):
|
||||
db.command("find", "t")
|
||||
|
||||
# Ensure command stops retrying when there are no tokens left.
|
||||
fail_too_many = mock_overload_error.copy()
|
||||
fail_too_many["mode"] = {"times": 2}
|
||||
with self.fail_point(fail_too_many):
|
||||
with self.assertRaises(PyMongoError) as error:
|
||||
db.command("find", "t")
|
||||
|
||||
self.assertIn("RetryableError", str(error.exception))
|
||||
|
||||
|
||||
class TestRetryPolicy(PyMongoTestCase):
|
||||
def test_retry_policy(self):
|
||||
capacity = 10
|
||||
retry_policy = _RetryPolicy(_TokenBucket(capacity=capacity))
|
||||
self.assertEqual(retry_policy.attempts, helpers._MAX_RETRIES)
|
||||
self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL)
|
||||
self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX)
|
||||
for i in range(1, helpers._MAX_RETRIES + 1):
|
||||
self.assertTrue(retry_policy.should_retry(i, 0))
|
||||
self.assertFalse(retry_policy.should_retry(helpers._MAX_RETRIES + 1, 0))
|
||||
for i in range(capacity - helpers._MAX_RETRIES):
|
||||
self.assertTrue(retry_policy.should_retry(1, 0))
|
||||
# No tokens left, should not retry.
|
||||
self.assertFalse(retry_policy.should_retry(1, 0))
|
||||
self.assertEqual(retry_policy.token_bucket.tokens, 0)
|
||||
|
||||
# record_success should generate tokens.
|
||||
for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)):
|
||||
retry_policy.record_success(retry=False)
|
||||
self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2)
|
||||
for i in range(2):
|
||||
self.assertTrue(retry_policy.should_retry(1, 0))
|
||||
self.assertFalse(retry_policy.should_retry(1, 0))
|
||||
|
||||
# Recording a successful retry should return 1 additional token.
|
||||
retry_policy.record_success(retry=True)
|
||||
self.assertAlmostEqual(
|
||||
retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN
|
||||
)
|
||||
self.assertTrue(retry_policy.should_retry(1, 0))
|
||||
self.assertFalse(retry_policy.should_retry(1, 0))
|
||||
self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN)
|
||||
|
||||
def test_retry_policy_csot(self):
|
||||
retry_policy = _RetryPolicy(_TokenBucket())
|
||||
self.assertTrue(retry_policy.should_retry(1, 0.5))
|
||||
with pymongo.timeout(0.5):
|
||||
self.assertTrue(retry_policy.should_retry(1, 0))
|
||||
self.assertTrue(retry_policy.should_retry(1, 0.1))
|
||||
# Would exceed the timeout, should not retry.
|
||||
self.assertFalse(retry_policy.should_retry(1, 1.0))
|
||||
self.assertTrue(retry_policy.should_retry(1, 1.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -208,6 +208,7 @@ converted_tests = [
|
||||
"test_auth_oidc.py",
|
||||
"test_auth_spec.py",
|
||||
"test_bulk.py",
|
||||
"test_backpressure.py",
|
||||
"test_change_stream.py",
|
||||
"test_client.py",
|
||||
"test_client_bulk_write.py",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user