PYTHON-3289 Apply client timeoutMS to every operation (#1011)

This commit is contained in:
Shane Harvey 2022-07-19 01:22:43 -05:00 committed by GitHub
parent 5c38676d53
commit 667046129a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 44 additions and 20 deletions

View File

@ -14,9 +14,10 @@
"""Internal helpers for CSOT."""
import functools
import time
from contextvars import ContextVar, Token
from typing import Optional, Tuple
from typing import Any, Callable, Optional, Tuple, TypeVar, cast
TIMEOUT: ContextVar[Optional[float]] = ContextVar("TIMEOUT", default=None)
RTT: ContextVar[float] = ContextVar("RTT", default=0.0)
@ -83,3 +84,22 @@ class _TimeoutContext(object):
TIMEOUT.reset(timeout_token)
DEADLINE.reset(deadline_token)
RTT.reset(rtt_token)
# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
F = TypeVar("F", bound=Callable[..., Any])
def apply(func: F) -> F:
"""Apply the client's timeoutMS to this operation."""
@functools.wraps(func)
def csot_wrapper(self, *args, **kwargs):
if get_timeout() is None:
timeout = self._timeout
if timeout is not None:
with _TimeoutContext(timeout):
return func(self, *args, **kwargs)
return func(self, *args, **kwargs)
return cast(F, csot_wrapper)

View File

@ -35,7 +35,7 @@ from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from bson.timestamp import Timestamp
from pymongo import ASCENDING, common, helpers, message
from pymongo import ASCENDING, _csot, common, helpers, message
from pymongo.aggregation import (
_CollectionAggregationCommand,
_CollectionRawAggregationCommand,
@ -217,6 +217,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
self.__database: Database[_DocumentType] = database
self.__name = name
self.__full_name = "%s.%s" % (self.__database.name, self.__name)
self.__write_response_codec_options = self.codec_options._replace(
unicode_decode_error_handler="replace", document_class=dict
)
self._timeout = database.client.options.timeout
encrypted_fields = kwargs.pop("encryptedFields", None)
if create or kwargs or collation:
if encrypted_fields:
@ -230,10 +234,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
else:
self.__create(name, kwargs, collation, session)
self.__write_response_codec_options = self.codec_options._replace(
unicode_decode_error_handler="replace", document_class=dict
)
def _socket_for_reads(self, session):
return self.__database.client._socket_for_reads(self._read_preference_for(session), session)
@ -433,6 +433,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
read_concern or self.read_concern,
)
@_csot.apply
def bulk_write(
self,
requests: Sequence[_WriteOp],
@ -631,6 +632,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
write_concern.acknowledged,
)
@_csot.apply
def insert_many(
self,
documents: Iterable[_DocumentIn],
@ -1892,6 +1894,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
kwargs["comment"] = comment
return self.__create_indexes(indexes, session, **kwargs)
@_csot.apply
def __create_indexes(self, indexes, session, **kwargs):
"""Internal createIndexes helper.
@ -2088,6 +2091,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
kwargs["comment"] = comment
self.drop_index("*", session=session, **kwargs)
@_csot.apply
def drop_index(
self,
index_or_name: _IndexKeyHint,
@ -2311,6 +2315,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
return options
@_csot.apply
def _aggregate(
self,
aggregation_command,
@ -2618,6 +2623,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
full_document_before_change,
)
@_csot.apply
def rename(
self,
new_name: str,

View File

@ -33,7 +33,7 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
from bson.dbref import DBRef
from bson.son import SON
from bson.timestamp import Timestamp
from pymongo import common
from pymongo import _csot, common
from pymongo.aggregation import _DatabaseAggregationCommand
from pymongo.change_stream import DatabaseChangeStream
from pymongo.collection import Collection
@ -138,6 +138,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
self.__name = name
self.__client: MongoClient[_DocumentType] = client
self._timeout = client.options.timeout
@property
def client(self) -> "MongoClient[_DocumentType]":
@ -290,6 +291,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
read_concern,
)
@_csot.apply
def create_collection(
self,
name: str,
@ -690,6 +692,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
client=self.__client,
)
@_csot.apply
def command(
self,
command: Union[str, MutableMapping[str, Any]],
@ -964,6 +967,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
session=session,
)
@_csot.apply
def drop_collection(
self,
name_or_collection: Union[str, Collection],

View File

@ -838,6 +838,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
from pymongo.encryption import _Encrypter
self._encrypter = _Encrypter(self, self.__options.auto_encryption_opts)
self._timeout = options.timeout
def _duplicate(self, **kwargs):
args = self.__init_kwargs.copy()
@ -1270,6 +1271,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
def _should_pin_cursor(self, session):
return self.__options.load_balanced and not (session and session.in_transaction)
@_csot.apply
def _run_operation(self, operation, unpack_res, address=None):
"""Run a _Query/_GetMore operation and return a Response.
@ -1318,6 +1320,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
)
return self._retry_internal(retryable, func, session, bulk)
@_csot.apply
def _retry_internal(self, retryable, func, session, bulk):
"""Internal retryable write helper."""
max_wire_version = 0
@ -1384,6 +1387,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
retrying = True
last_error = exc
@_csot.apply
def _retryable_read(self, func, read_pref, session, address=None, retryable=True):
"""Execute an operation with at most one consecutive retries
@ -1834,6 +1838,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
"""
return [doc["name"] for doc in self.list_databases(session, nameOnly=True, comment=comment)]
@_csot.apply
def drop_database(
self,
name_or_database: Union[str, database.Database],

View File

@ -1140,27 +1140,20 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
if isinstance(target, MongoClient):
method_name = "_clientOperation_%s" % (opname,)
client = target
elif isinstance(target, Database):
method_name = "_databaseOperation_%s" % (opname,)
client = target.client
elif isinstance(target, Collection):
method_name = "_collectionOperation_%s" % (opname,)
client = target.database.client
elif isinstance(target, ChangeStream):
method_name = "_changeStreamOperation_%s" % (opname,)
client = target._client
elif isinstance(target, NonLazyCursor):
method_name = "_cursor_%s" % (opname,)
client = target.client
elif isinstance(target, ClientSession):
method_name = "_sessionOperation_%s" % (opname,)
client = target._client
elif isinstance(target, GridFSBucket):
raise NotImplementedError
elif isinstance(target, ClientEncryption):
method_name = "_clientEncryptionOperation_%s" % (opname,)
client = target._key_vault_client
else:
method_name = "doesNotExist"
@ -1175,13 +1168,9 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
cmd = functools.partial(method, target)
try:
# TODO: PYTHON-3289 apply inherited timeout by default.
inherit_timeout = client.options.timeout
# CSOT: Translate the spec test "timeout" arg into pymongo's context timeout API.
if "timeout" in arguments or inherit_timeout is not None:
timeout = arguments.pop("timeout", None)
if timeout is None:
timeout = inherit_timeout
if "timeout" in arguments:
timeout = arguments.pop("timeout")
with pymongo.timeout(timeout):
result = cmd(**dict(arguments))
else: