From 667046129a54e5f9a4e29334470bdef235e891dd Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Tue, 19 Jul 2022 01:22:43 -0500 Subject: [PATCH] PYTHON-3289 Apply client timeoutMS to every operation (#1011) --- pymongo/_csot.py | 22 +++++++++++++++++++++- pymongo/collection.py | 16 +++++++++++----- pymongo/database.py | 6 +++++- pymongo/mongo_client.py | 5 +++++ test/unified_format.py | 15 ++------------- 5 files changed, 44 insertions(+), 20 deletions(-) diff --git a/pymongo/_csot.py b/pymongo/_csot.py index 6d3cd3c0f..e25bba108 100644 --- a/pymongo/_csot.py +++ b/pymongo/_csot.py @@ -14,9 +14,10 @@ """Internal helpers for CSOT.""" +import functools import time from contextvars import ContextVar, Token -from typing import Optional, Tuple +from typing import Any, Callable, Optional, Tuple, TypeVar, cast TIMEOUT: ContextVar[Optional[float]] = ContextVar("TIMEOUT", default=None) RTT: ContextVar[float] = ContextVar("RTT", default=0.0) @@ -83,3 +84,22 @@ class _TimeoutContext(object): TIMEOUT.reset(timeout_token) DEADLINE.reset(deadline_token) RTT.reset(rtt_token) + + +# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories +F = TypeVar("F", bound=Callable[..., Any]) + + +def apply(func: F) -> F: + """Apply the client's timeoutMS to this operation.""" + + @functools.wraps(func) + def csot_wrapper(self, *args, **kwargs): + if get_timeout() is None: + timeout = self._timeout + if timeout is not None: + with _TimeoutContext(timeout): + return func(self, *args, **kwargs) + return func(self, *args, **kwargs) + + return cast(F, csot_wrapper) diff --git a/pymongo/collection.py b/pymongo/collection.py index 4aff5c178..22af5a642 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -35,7 +35,7 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from bson.son import SON from bson.timestamp import Timestamp -from pymongo import ASCENDING, common, helpers, message +from pymongo import ASCENDING, _csot, common, helpers, message from pymongo.aggregation import ( _CollectionAggregationCommand, _CollectionRawAggregationCommand, @@ -217,6 +217,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]): self.__database: Database[_DocumentType] = database self.__name = name self.__full_name = "%s.%s" % (self.__database.name, self.__name) + self.__write_response_codec_options = self.codec_options._replace( + unicode_decode_error_handler="replace", document_class=dict + ) + self._timeout = database.client.options.timeout encrypted_fields = kwargs.pop("encryptedFields", None) if create or kwargs or collation: if encrypted_fields: @@ -230,10 +234,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]): else: self.__create(name, kwargs, collation, session) - self.__write_response_codec_options = self.codec_options._replace( - unicode_decode_error_handler="replace", document_class=dict - ) - def _socket_for_reads(self, session): return self.__database.client._socket_for_reads(self._read_preference_for(session), session) @@ -433,6 +433,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): read_concern or self.read_concern, ) + @_csot.apply def bulk_write( self, requests: Sequence[_WriteOp], @@ -631,6 +632,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): write_concern.acknowledged, ) + @_csot.apply def insert_many( self, documents: Iterable[_DocumentIn], @@ -1892,6 +1894,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): kwargs["comment"] = comment return self.__create_indexes(indexes, session, **kwargs) + @_csot.apply def __create_indexes(self, indexes, session, **kwargs): """Internal createIndexes helper. @@ -2088,6 +2091,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): kwargs["comment"] = comment self.drop_index("*", session=session, **kwargs) + @_csot.apply def drop_index( self, index_or_name: _IndexKeyHint, @@ -2311,6 +2315,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): return options + @_csot.apply def _aggregate( self, aggregation_command, @@ -2618,6 +2623,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): full_document_before_change, ) + @_csot.apply def rename( self, new_name: str, diff --git a/pymongo/database.py b/pymongo/database.py index 9b9d51201..4f87a58dd 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -33,7 +33,7 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions from bson.dbref import DBRef from bson.son import SON from bson.timestamp import Timestamp -from pymongo import common +from pymongo import _csot, common from pymongo.aggregation import _DatabaseAggregationCommand from pymongo.change_stream import DatabaseChangeStream from pymongo.collection import Collection @@ -138,6 +138,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): self.__name = name self.__client: MongoClient[_DocumentType] = client + self._timeout = client.options.timeout @property def client(self) -> "MongoClient[_DocumentType]": @@ -290,6 +291,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): read_concern, ) + @_csot.apply def create_collection( self, name: str, @@ -690,6 +692,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): client=self.__client, ) + @_csot.apply def command( self, command: Union[str, MutableMapping[str, Any]], @@ -964,6 +967,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): session=session, ) + @_csot.apply def drop_collection( self, name_or_collection: Union[str, Collection], diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index e949ba5cd..080ae8757 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -838,6 +838,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): from pymongo.encryption import _Encrypter self._encrypter = _Encrypter(self, self.__options.auto_encryption_opts) + self._timeout = options.timeout def _duplicate(self, **kwargs): args = self.__init_kwargs.copy() @@ -1270,6 +1271,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): def _should_pin_cursor(self, session): return self.__options.load_balanced and not (session and session.in_transaction) + @_csot.apply def _run_operation(self, operation, unpack_res, address=None): """Run a _Query/_GetMore operation and return a Response. @@ -1318,6 +1320,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): ) return self._retry_internal(retryable, func, session, bulk) + @_csot.apply def _retry_internal(self, retryable, func, session, bulk): """Internal retryable write helper.""" max_wire_version = 0 @@ -1384,6 +1387,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): retrying = True last_error = exc + @_csot.apply def _retryable_read(self, func, read_pref, session, address=None, retryable=True): """Execute an operation with at most one consecutive retries @@ -1834,6 +1838,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): """ return [doc["name"] for doc in self.list_databases(session, nameOnly=True, comment=comment)] + @_csot.apply def drop_database( self, name_or_database: Union[str, database.Database], diff --git a/test/unified_format.py b/test/unified_format.py index 7e6c09023..e37bc1bb6 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1140,27 +1140,20 @@ class UnifiedSpecTestMixinV1(IntegrationTest): if isinstance(target, MongoClient): method_name = "_clientOperation_%s" % (opname,) - client = target elif isinstance(target, Database): method_name = "_databaseOperation_%s" % (opname,) - client = target.client elif isinstance(target, Collection): method_name = "_collectionOperation_%s" % (opname,) - client = target.database.client elif isinstance(target, ChangeStream): method_name = "_changeStreamOperation_%s" % (opname,) - client = target._client elif isinstance(target, NonLazyCursor): method_name = "_cursor_%s" % (opname,) - client = target.client elif isinstance(target, ClientSession): method_name = "_sessionOperation_%s" % (opname,) - client = target._client elif isinstance(target, GridFSBucket): raise NotImplementedError elif isinstance(target, ClientEncryption): method_name = "_clientEncryptionOperation_%s" % (opname,) - client = target._key_vault_client else: method_name = "doesNotExist" @@ -1175,13 +1168,9 @@ class UnifiedSpecTestMixinV1(IntegrationTest): cmd = functools.partial(method, target) try: - # TODO: PYTHON-3289 apply inherited timeout by default. - inherit_timeout = client.options.timeout # CSOT: Translate the spec test "timeout" arg into pymongo's context timeout API. - if "timeout" in arguments or inherit_timeout is not None: - timeout = arguments.pop("timeout", None) - if timeout is None: - timeout = inherit_timeout + if "timeout" in arguments: + timeout = arguments.pop("timeout") with pymongo.timeout(timeout): result = cmd(**dict(arguments)) else: