PYTHON-3289 Apply client timeoutMS to every operation (#1011)
This commit is contained in:
parent
5c38676d53
commit
667046129a
@ -14,9 +14,10 @@
|
||||
|
||||
"""Internal helpers for CSOT."""
|
||||
|
||||
import functools
|
||||
import time
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Callable, Optional, Tuple, TypeVar, cast
|
||||
|
||||
TIMEOUT: ContextVar[Optional[float]] = ContextVar("TIMEOUT", default=None)
|
||||
RTT: ContextVar[float] = ContextVar("RTT", default=0.0)
|
||||
@ -83,3 +84,22 @@ class _TimeoutContext(object):
|
||||
TIMEOUT.reset(timeout_token)
|
||||
DEADLINE.reset(deadline_token)
|
||||
RTT.reset(rtt_token)
|
||||
|
||||
|
||||
# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def apply(func: F) -> F:
|
||||
"""Apply the client's timeoutMS to this operation."""
|
||||
|
||||
@functools.wraps(func)
|
||||
def csot_wrapper(self, *args, **kwargs):
|
||||
if get_timeout() is None:
|
||||
timeout = self._timeout
|
||||
if timeout is not None:
|
||||
with _TimeoutContext(timeout):
|
||||
return func(self, *args, **kwargs)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return cast(F, csot_wrapper)
|
||||
|
||||
@ -35,7 +35,7 @@ from bson.objectid import ObjectId
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.son import SON
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import ASCENDING, common, helpers, message
|
||||
from pymongo import ASCENDING, _csot, common, helpers, message
|
||||
from pymongo.aggregation import (
|
||||
_CollectionAggregationCommand,
|
||||
_CollectionRawAggregationCommand,
|
||||
@ -217,6 +217,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
self.__database: Database[_DocumentType] = database
|
||||
self.__name = name
|
||||
self.__full_name = "%s.%s" % (self.__database.name, self.__name)
|
||||
self.__write_response_codec_options = self.codec_options._replace(
|
||||
unicode_decode_error_handler="replace", document_class=dict
|
||||
)
|
||||
self._timeout = database.client.options.timeout
|
||||
encrypted_fields = kwargs.pop("encryptedFields", None)
|
||||
if create or kwargs or collation:
|
||||
if encrypted_fields:
|
||||
@ -230,10 +234,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
else:
|
||||
self.__create(name, kwargs, collation, session)
|
||||
|
||||
self.__write_response_codec_options = self.codec_options._replace(
|
||||
unicode_decode_error_handler="replace", document_class=dict
|
||||
)
|
||||
|
||||
def _socket_for_reads(self, session):
|
||||
return self.__database.client._socket_for_reads(self._read_preference_for(session), session)
|
||||
|
||||
@ -433,6 +433,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
read_concern or self.read_concern,
|
||||
)
|
||||
|
||||
@_csot.apply
|
||||
def bulk_write(
|
||||
self,
|
||||
requests: Sequence[_WriteOp],
|
||||
@ -631,6 +632,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
write_concern.acknowledged,
|
||||
)
|
||||
|
||||
@_csot.apply
|
||||
def insert_many(
|
||||
self,
|
||||
documents: Iterable[_DocumentIn],
|
||||
@ -1892,6 +1894,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
kwargs["comment"] = comment
|
||||
return self.__create_indexes(indexes, session, **kwargs)
|
||||
|
||||
@_csot.apply
|
||||
def __create_indexes(self, indexes, session, **kwargs):
|
||||
"""Internal createIndexes helper.
|
||||
|
||||
@ -2088,6 +2091,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
kwargs["comment"] = comment
|
||||
self.drop_index("*", session=session, **kwargs)
|
||||
|
||||
@_csot.apply
|
||||
def drop_index(
|
||||
self,
|
||||
index_or_name: _IndexKeyHint,
|
||||
@ -2311,6 +2315,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
return options
|
||||
|
||||
@_csot.apply
|
||||
def _aggregate(
|
||||
self,
|
||||
aggregation_command,
|
||||
@ -2618,6 +2623,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
full_document_before_change,
|
||||
)
|
||||
|
||||
@_csot.apply
|
||||
def rename(
|
||||
self,
|
||||
new_name: str,
|
||||
|
||||
@ -33,7 +33,7 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
|
||||
from bson.dbref import DBRef
|
||||
from bson.son import SON
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import common
|
||||
from pymongo import _csot, common
|
||||
from pymongo.aggregation import _DatabaseAggregationCommand
|
||||
from pymongo.change_stream import DatabaseChangeStream
|
||||
from pymongo.collection import Collection
|
||||
@ -138,6 +138,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
self.__name = name
|
||||
self.__client: MongoClient[_DocumentType] = client
|
||||
self._timeout = client.options.timeout
|
||||
|
||||
@property
|
||||
def client(self) -> "MongoClient[_DocumentType]":
|
||||
@ -290,6 +291,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
read_concern,
|
||||
)
|
||||
|
||||
@_csot.apply
|
||||
def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
@ -690,6 +692,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
client=self.__client,
|
||||
)
|
||||
|
||||
@_csot.apply
|
||||
def command(
|
||||
self,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
@ -964,6 +967,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
session=session,
|
||||
)
|
||||
|
||||
@_csot.apply
|
||||
def drop_collection(
|
||||
self,
|
||||
name_or_collection: Union[str, Collection],
|
||||
|
||||
@ -838,6 +838,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
from pymongo.encryption import _Encrypter
|
||||
|
||||
self._encrypter = _Encrypter(self, self.__options.auto_encryption_opts)
|
||||
self._timeout = options.timeout
|
||||
|
||||
def _duplicate(self, **kwargs):
|
||||
args = self.__init_kwargs.copy()
|
||||
@ -1270,6 +1271,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
def _should_pin_cursor(self, session):
|
||||
return self.__options.load_balanced and not (session and session.in_transaction)
|
||||
|
||||
@_csot.apply
|
||||
def _run_operation(self, operation, unpack_res, address=None):
|
||||
"""Run a _Query/_GetMore operation and return a Response.
|
||||
|
||||
@ -1318,6 +1320,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
return self._retry_internal(retryable, func, session, bulk)
|
||||
|
||||
@_csot.apply
|
||||
def _retry_internal(self, retryable, func, session, bulk):
|
||||
"""Internal retryable write helper."""
|
||||
max_wire_version = 0
|
||||
@ -1384,6 +1387,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
retrying = True
|
||||
last_error = exc
|
||||
|
||||
@_csot.apply
|
||||
def _retryable_read(self, func, read_pref, session, address=None, retryable=True):
|
||||
"""Execute an operation with at most one consecutive retries
|
||||
|
||||
@ -1834,6 +1838,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
"""
|
||||
return [doc["name"] for doc in self.list_databases(session, nameOnly=True, comment=comment)]
|
||||
|
||||
@_csot.apply
|
||||
def drop_database(
|
||||
self,
|
||||
name_or_database: Union[str, database.Database],
|
||||
|
||||
@ -1140,27 +1140,20 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
|
||||
if isinstance(target, MongoClient):
|
||||
method_name = "_clientOperation_%s" % (opname,)
|
||||
client = target
|
||||
elif isinstance(target, Database):
|
||||
method_name = "_databaseOperation_%s" % (opname,)
|
||||
client = target.client
|
||||
elif isinstance(target, Collection):
|
||||
method_name = "_collectionOperation_%s" % (opname,)
|
||||
client = target.database.client
|
||||
elif isinstance(target, ChangeStream):
|
||||
method_name = "_changeStreamOperation_%s" % (opname,)
|
||||
client = target._client
|
||||
elif isinstance(target, NonLazyCursor):
|
||||
method_name = "_cursor_%s" % (opname,)
|
||||
client = target.client
|
||||
elif isinstance(target, ClientSession):
|
||||
method_name = "_sessionOperation_%s" % (opname,)
|
||||
client = target._client
|
||||
elif isinstance(target, GridFSBucket):
|
||||
raise NotImplementedError
|
||||
elif isinstance(target, ClientEncryption):
|
||||
method_name = "_clientEncryptionOperation_%s" % (opname,)
|
||||
client = target._key_vault_client
|
||||
else:
|
||||
method_name = "doesNotExist"
|
||||
|
||||
@ -1175,13 +1168,9 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
cmd = functools.partial(method, target)
|
||||
|
||||
try:
|
||||
# TODO: PYTHON-3289 apply inherited timeout by default.
|
||||
inherit_timeout = client.options.timeout
|
||||
# CSOT: Translate the spec test "timeout" arg into pymongo's context timeout API.
|
||||
if "timeout" in arguments or inherit_timeout is not None:
|
||||
timeout = arguments.pop("timeout", None)
|
||||
if timeout is None:
|
||||
timeout = inherit_timeout
|
||||
if "timeout" in arguments:
|
||||
timeout = arguments.pop("timeout")
|
||||
with pymongo.timeout(timeout):
|
||||
result = cmd(**dict(arguments))
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user