From dd6c140d438039e9f6df96cd3d4221f380a37e18 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 2 Feb 2022 21:12:36 -0600 Subject: [PATCH] PYTHON-3060 Add typings to pymongo package (#831) --- bson/__init__.py | 19 +- bson/binary.py | 9 +- gridfs/grid_file.py | 4 +- mypy.ini | 22 +++ pymongo/__init__.py | 28 +-- pymongo/aggregation.py | 6 +- pymongo/auth.py | 14 +- pymongo/auth_aws.py | 7 +- pymongo/bulk.py | 22 +-- pymongo/change_stream.py | 70 ++++--- pymongo/client_options.py | 2 +- pymongo/client_session.py | 106 ++++++----- pymongo/collation.py | 31 ++-- pymongo/collection.py | 311 +++++++++++++++++++++----------- pymongo/command_cursor.py | 71 +++++--- pymongo/common.py | 146 +++++++-------- pymongo/compression_support.py | 7 +- pymongo/cursor.py | 189 +++++++++++-------- pymongo/database.py | 157 +++++++++++----- pymongo/driver_info.py | 3 +- pymongo/encryption.py | 63 ++++--- pymongo/encryption_options.py | 30 +-- pymongo/errors.py | 51 ++++-- pymongo/event_loggers.py | 46 +++-- pymongo/hello.py | 66 ++++--- pymongo/helpers.py | 14 +- pymongo/message.py | 29 ++- pymongo/mongo_client.py | 181 +++++++++++-------- pymongo/monitor.py | 10 +- pymongo/monitoring.py | 203 +++++++++++++-------- pymongo/network.py | 9 +- pymongo/ocsp_cache.py | 2 +- pymongo/ocsp_support.py | 61 +++---- pymongo/operations.py | 47 ++--- pymongo/periodic_executor.py | 18 +- pymongo/pool.py | 60 +++--- pymongo/pyopenssl_context.py | 39 ++-- pymongo/read_concern.py | 12 +- pymongo/read_preferences.py | 74 ++++---- pymongo/results.py | 58 +++--- pymongo/saslprep.py | 9 +- pymongo/server.py | 7 +- pymongo/server_description.py | 88 ++++----- pymongo/server_type.py | 19 +- pymongo/socket_checker.py | 12 +- pymongo/srv_resolver.py | 3 +- pymongo/ssl_context.py | 1 + pymongo/ssl_support.py | 8 +- pymongo/topology.py | 41 ++--- pymongo/topology_description.py | 99 ++++++---- pymongo/typings.py | 29 +++ pymongo/uri_parser.py | 40 ++-- pymongo/write_concern.py | 16 +- test/performance/perf_test.py | 2 +- test/test_cursor.py | 2 +- test/test_grid_file.py | 2 +- tools/clean.py | 2 +- 57 files changed, 1578 insertions(+), 1099 deletions(-) create mode 100644 pymongo/typings.py diff --git a/bson/__init__.py b/bson/__init__.py index 5be673cfc..e518cd91c 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -61,10 +61,10 @@ import re import struct import sys import uuid -from codecs import utf_8_decode as _utf_8_decode # type: ignore -from codecs import utf_8_encode as _utf_8_encode # type: ignore +from codecs import utf_8_decode as _utf_8_decode # type: ignore[attr-defined] +from codecs import utf_8_encode as _utf_8_encode # type: ignore[attr-defined] from collections import abc as _abc -from typing import (TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Generator, +from typing import (IO, TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Generator, Iterator, List, Mapping, MutableMapping, NoReturn, Sequence, Tuple, Type, TypeVar, Union, cast) @@ -88,11 +88,13 @@ from bson.tz_util import utc # Import RawBSONDocument for type-checking only to avoid circular dependency. if TYPE_CHECKING: + from array import array + from mmap import mmap from bson.raw_bson import RawBSONDocument try: - from bson import _cbson # type: ignore + from bson import _cbson # type: ignore[attr-defined] _USE_C = True except ImportError: _USE_C = False @@ -851,6 +853,7 @@ _CODEC_OPTIONS_TYPE_ERROR = TypeError( _DocumentIn = Mapping[str, Any] _DocumentOut = Union[MutableMapping[str, Any], "RawBSONDocument"] +_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"] def encode(document: _DocumentIn, check_keys: bool = False, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> bytes: @@ -880,7 +883,7 @@ def encode(document: _DocumentIn, check_keys: bool = False, codec_options: Codec return _dict_to_bson(document, check_keys, codec_options) -def decode(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> _DocumentOut: +def decode(data: _ReadableBuffer, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Dict[str, Any]: """Decode BSON to a document. By default, returns a BSON document represented as a Python @@ -912,7 +915,7 @@ def decode(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> return _bson_to_dict(data, codec_options) -def decode_all(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> List[_DocumentOut]: +def decode_all(data: _ReadableBuffer, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> List[Dict[str, Any]]: """Decode BSON data to multiple documents. `data` must be a bytes-like object implementing the buffer protocol that @@ -1075,7 +1078,7 @@ def decode_iter(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS yield _bson_to_dict(elements, codec_options) -def decode_file_iter(file_obj: BinaryIO, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Iterator[_DocumentOut]: +def decode_file_iter(file_obj: Union[BinaryIO, IO], codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Iterator[_DocumentOut]: """Decode bson data from a file to multiple documents as a generator. Works similarly to the decode_all function, but reads from the file object @@ -1158,7 +1161,7 @@ class BSON(bytes): """ return cls(encode(document, check_keys, codec_options)) - def decode(self, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> _DocumentOut: # type: ignore[override] + def decode(self, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Dict[str, Any]: # type: ignore[override] """Decode this BSON data. By default, returns a BSON document represented as a Python diff --git a/bson/binary.py b/bson/binary.py index 53d5419b4..de44d4817 100644 --- a/bson/binary.py +++ b/bson/binary.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Tuple, Type +from typing import Any, Tuple, Type, Union, TYPE_CHECKING from uuid import UUID """Tools for representing BSON binary data. @@ -57,6 +57,11 @@ by :mod:`bson` using this subtype when using """ +if TYPE_CHECKING: + from array import array as _array + from mmap import mmap as _mmap + + class UuidRepresentation: UNSPECIFIED = 0 """An unspecified UUID representation. @@ -211,7 +216,7 @@ class Binary(bytes): _type_marker = 5 __subtype: int - def __new__(cls: Type["Binary"], data: bytes, subtype: int = BINARY_SUBTYPE) -> "Binary": + def __new__(cls: Type["Binary"], data: Union[memoryview, bytes, "_mmap", "_array"], subtype: int = BINARY_SUBTYPE) -> "Binary": if not isinstance(subtype, int): raise TypeError("subtype must be an instance of int") if subtype >= 256 or subtype < 0: diff --git a/gridfs/grid_file.py b/gridfs/grid_file.py index 9353a97a1..686d328a3 100644 --- a/gridfs/grid_file.py +++ b/gridfs/grid_file.py @@ -874,10 +874,10 @@ class GridOutCursor(Cursor): __next__ = next - def add_option(self, *args: Any, **kwargs: Any) -> None: + def add_option(self, *args: Any, **kwargs: Any) -> None: # type: ignore[override] raise NotImplementedError("Method does not exist for GridOutCursor") - def remove_option(self, *args: Any, **kwargs: Any) -> None: + def remove_option(self, *args: Any, **kwargs: Any) -> None: # type: ignore[override] raise NotImplementedError("Method does not exist for GridOutCursor") def _clone_base(self, session: ClientSession) -> "GridOutCursor": diff --git a/mypy.ini b/mypy.ini index 2646febb6..926bf9574 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,11 +1,33 @@ [mypy] +check_untyped_defs = true disallow_subclassing_any = true disallow_incomplete_defs = true no_implicit_optional = true +pretty = true +show_error_context = true +show_error_codes = true strict_equality = true warn_unused_configs = true warn_unused_ignores = true warn_redundant_casts = true +[mypy-kerberos.*] +ignore_missing_imports = True + [mypy-mockupdb] ignore_missing_imports = True + +[mypy-pymongo_auth_aws.*] +ignore_missing_imports = True + +[mypy-pymongocrypt.*] +ignore_missing_imports = True + +[mypy-service_identity.*] +ignore_missing_imports = True + +[mypy-snappy.*] +ignore_missing_imports = True + +[mypy-winkerberos.*] +ignore_missing_imports = True diff --git a/pymongo/__init__.py b/pymongo/__init__.py index 5db9363f9..54a962df5 100644 --- a/pymongo/__init__.py +++ b/pymongo/__init__.py @@ -14,6 +14,8 @@ """Python driver for MongoDB.""" +from typing import Tuple, Union + ASCENDING = 1 """Ascending sort order.""" DESCENDING = -1 @@ -53,35 +55,33 @@ TEXT = "text" .. _text index: http://docs.mongodb.org/manual/core/index-text/ """ -version_tuple = (4, 1, 0, '.dev0') +version_tuple: Tuple[Union[int, str], ...] = (4, 1, 0, '.dev0') -def get_version_string(): +def get_version_string() -> str: if isinstance(version_tuple[-1], str): return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1] return '.'.join(map(str, version_tuple)) -__version__ = version = get_version_string() +__version__: str = get_version_string() +version = __version__ + """Current version of PyMongo.""" from pymongo.collection import ReturnDocument -from pymongo.common import (MIN_SUPPORTED_WIRE_VERSION, - MAX_SUPPORTED_WIRE_VERSION) +from pymongo.common import (MAX_SUPPORTED_WIRE_VERSION, + MIN_SUPPORTED_WIRE_VERSION) from pymongo.cursor import CursorType from pymongo.mongo_client import MongoClient -from pymongo.operations import (IndexModel, - InsertOne, - DeleteOne, - DeleteMany, - UpdateOne, - UpdateMany, - ReplaceOne) +from pymongo.operations import (DeleteMany, DeleteOne, IndexModel, InsertOne, + ReplaceOne, UpdateMany, UpdateOne) from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern -def has_c(): + +def has_c() -> bool: """Is the C extension installed?""" try: - from pymongo import _cmessage + from pymongo import _cmessage # type: ignore[attr-defined] return True except ImportError: return False diff --git a/pymongo/aggregation.py b/pymongo/aggregation.py index 8fb0225eb..b2e20e9ca 100644 --- a/pymongo/aggregation.py +++ b/pymongo/aggregation.py @@ -15,11 +15,10 @@ """Perform aggregation operations on a collection or database.""" from bson.son import SON - from pymongo import common from pymongo.collation import validate_collation_or_none from pymongo.errors import ConfigurationError -from pymongo.read_preferences import _AggWritePref, ReadPreference +from pymongo.read_preferences import ReadPreference, _AggWritePref class _AggregationCommand(object): @@ -37,7 +36,7 @@ class _AggregationCommand(object): self._target = target - common.validate_list('pipeline', pipeline) + pipeline = common.validate_list('pipeline', pipeline) self._pipeline = pipeline self._performs_write = False if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]): @@ -82,7 +81,6 @@ class _AggregationCommand(object): """The namespace in which the aggregate command is run.""" raise NotImplementedError - @property def _cursor_collection(self, cursor_doc): """The Collection used for the aggregate command cursor.""" raise NotImplementedError diff --git a/pymongo/auth.py b/pymongo/auth.py index a2e206357..34f1c7fc9 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -19,9 +19,9 @@ import hashlib import hmac import os import socket - from base64 import standard_b64decode, standard_b64encode from collections import namedtuple +from typing import Callable, Mapping from urllib.parse import quote from bson.binary import Binary @@ -97,7 +97,7 @@ GSSAPIProperties = namedtuple('GSSAPIProperties', """Mechanism properties for GSSAPI authentication.""" -_AWSProperties = namedtuple('AWSProperties', ['aws_session_token']) +_AWSProperties = namedtuple('_AWSProperties', ['aws_session_token']) """Mechanism properties for MONGODB-AWS authentication.""" @@ -140,9 +140,9 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): properties = extra.get('authmechanismproperties', {}) aws_session_token = properties.get('AWS_SESSION_TOKEN') - props = _AWSProperties(aws_session_token=aws_session_token) + aws_props = _AWSProperties(aws_session_token=aws_session_token) # user can be None for temporary link-local EC2 credentials. - return MongoCredential(mech, '$external', user, passwd, props, None) + return MongoCredential(mech, '$external', user, passwd, aws_props, None) elif mech == 'PLAIN': source_database = source or database or '$external' return MongoCredential(mech, source_database, user, passwd, None, None) @@ -471,7 +471,7 @@ def _authenticate_default(credentials, sock_info): return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-1') -_AUTH_MAP = { +_AUTH_MAP: Mapping[str, Callable] = { 'GSSAPI': _authenticate_gssapi, 'MONGODB-CR': _authenticate_mongo_cr, 'MONGODB-X509': _authenticate_x509, @@ -532,7 +532,7 @@ class _X509Context(_AuthContext): return cmd -_SPECULATIVE_AUTH_MAP = { +_SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = { 'MONGODB-X509': _X509Context, 'SCRAM-SHA-1': functools.partial(_ScramContext, mechanism='SCRAM-SHA-1'), 'SCRAM-SHA-256': functools.partial(_ScramContext, @@ -544,6 +544,6 @@ _SPECULATIVE_AUTH_MAP = { def authenticate(credentials, sock_info): """Authenticate sock_info.""" mechanism = credentials.mechanism - auth_func = _AUTH_MAP.get(mechanism) + auth_func = _AUTH_MAP[mechanism] auth_func(credentials, sock_info) diff --git a/pymongo/auth_aws.py b/pymongo/auth_aws.py index ff07a12e7..0233d192d 100644 --- a/pymongo/auth_aws.py +++ b/pymongo/auth_aws.py @@ -16,12 +16,11 @@ try: import pymongo_auth_aws - from pymongo_auth_aws import (AwsCredential, - AwsSaslContext, + from pymongo_auth_aws import (AwsCredential, AwsSaslContext, PyMongoAuthAwsError) _HAVE_MONGODB_AWS = True except ImportError: - class AwsSaslContext(object): + class AwsSaslContext(object): # type: ignore def __init__(self, credentials): pass _HAVE_MONGODB_AWS = False @@ -32,7 +31,7 @@ from bson.son import SON from pymongo.errors import ConfigurationError, OperationFailure -class _AwsSaslContext(AwsSaslContext): +class _AwsSaslContext(AwsSaslContext): # type: ignore # Dependency injection: def binary_type(self): """Return the bson.binary.Binary type.""" diff --git a/pymongo/bulk.py b/pymongo/bulk.py index 1921108a1..8d343bb2c 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -17,31 +17,23 @@ .. versionadded:: 2.7 """ import copy - from itertools import islice from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from bson.son import SON from pymongo.client_session import _validate_session_write_concern -from pymongo.common import (validate_is_mapping, - validate_is_document_type, - validate_ok_for_replace, - validate_ok_for_update) -from pymongo.helpers import _RETRYABLE_ERROR_CODES, _get_wce_doc from pymongo.collation import validate_collation_or_none -from pymongo.errors import (BulkWriteError, - ConfigurationError, - InvalidOperation, - OperationFailure) -from pymongo.message import (_INSERT, _UPDATE, _DELETE, - _randint, - _BulkWriteContext, - _EncryptedBulkWriteContext) +from pymongo.common import (validate_is_document_type, validate_is_mapping, + validate_ok_for_replace, validate_ok_for_update) +from pymongo.errors import (BulkWriteError, ConfigurationError, + InvalidOperation, OperationFailure) +from pymongo.helpers import _RETRYABLE_ERROR_CODES, _get_wce_doc +from pymongo.message import (_DELETE, _INSERT, _UPDATE, _BulkWriteContext, + _EncryptedBulkWriteContext, _randint) from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern - _DELETE_ALL = 0 _DELETE_ONE = 1 diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index 54bf98d83..69446fdec 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -15,21 +15,20 @@ """Watch changes on a collection, a database, or the entire cluster.""" import copy +from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterator, Mapping, + Optional, Union) from bson import _bson_to_dict from bson.raw_bson import RawBSONDocument - +from bson.timestamp import Timestamp from pymongo import common from pymongo.aggregation import (_CollectionAggregationCommand, _DatabaseAggregationCommand) from pymongo.collation import validate_collation_or_none from pymongo.command_cursor import CommandCursor -from pymongo.errors import (ConnectionFailure, - CursorNotFound, - InvalidOperation, - OperationFailure, - PyMongoError) - +from pymongo.errors import (ConnectionFailure, CursorNotFound, + InvalidOperation, OperationFailure, PyMongoError) +from pymongo.typings import _CollationIn, _DocumentType, _Pipeline # The change streams spec considers the following server errors from the # getMore command non-resumable. All other getMore errors are resumable. @@ -55,7 +54,14 @@ _RESUMABLE_GETMORE_ERRORS = frozenset([ ]) -class ChangeStream(object): +if TYPE_CHECKING: + from pymongo.client_session import ClientSession + from pymongo.collection import Collection + from pymongo.database import Database + from pymongo.mongo_client import MongoClient + + +class ChangeStream(Generic[_DocumentType]): """The internal abstract base class for change stream cursors. Should not be called directly by application developers. Use @@ -66,14 +72,22 @@ class ChangeStream(object): .. versionadded:: 3.6 .. seealso:: The MongoDB documentation on `changeStreams `_. """ - def __init__(self, target, pipeline, full_document, resume_after, - max_await_time_ms, batch_size, collation, - start_at_operation_time, session, start_after): + def __init__( + self, + target: Union["MongoClient[_DocumentType]", "Database[_DocumentType]", "Collection[_DocumentType]"], + pipeline: Optional[_Pipeline], + full_document: Optional[str], + resume_after: Optional[Mapping[str, Any]], + max_await_time_ms: Optional[int], + batch_size: Optional[int], + collation: Optional[_CollationIn], + start_at_operation_time: Optional[Timestamp], + session: Optional["ClientSession"], + start_after: Optional[Mapping[str, Any]], + ) -> None: if pipeline is None: pipeline = [] - elif not isinstance(pipeline, list): - raise TypeError("pipeline must be a list") - + pipeline = common.validate_list('pipeline', pipeline) common.validate_string_or_none('full_document', full_document) validate_collation_or_none(collation) common.validate_non_negative_integer_or_none("batchSize", batch_size) @@ -84,7 +98,7 @@ class ChangeStream(object): self._decode_custom = True # Keep the type registry so that we support encoding custom types # in the pipeline. - self._target = target.with_options( + self._target = target.with_options( # type: ignore codec_options=target.codec_options.with_options( document_class=RawBSONDocument)) else: @@ -117,7 +131,7 @@ class ChangeStream(object): def _change_stream_options(self): """Return the options dict for the $changeStream pipeline stage.""" - options = {} + options: Dict[str, Any] = {} if self._full_document is not None: options['fullDocument'] = self._full_document @@ -144,7 +158,7 @@ class ChangeStream(object): def _aggregation_pipeline(self): """Return the full aggregation pipeline for this ChangeStream.""" options = self._change_stream_options() - full_pipeline = [{'$changeStream': options}] + full_pipeline: list = [{'$changeStream': options}] full_pipeline.extend(self._pipeline) return full_pipeline @@ -197,15 +211,15 @@ class ChangeStream(object): pass self._cursor = self._create_cursor() - def close(self): + def close(self) -> None: """Close this ChangeStream.""" self._cursor.close() - def __iter__(self): + def __iter__(self) -> "ChangeStream[_DocumentType]": return self @property - def resume_token(self): + def resume_token(self) -> Optional[Mapping[str, Any]]: """The cached resume token that will be used to resume after the most recently returned change. @@ -213,7 +227,7 @@ class ChangeStream(object): """ return copy.deepcopy(self._resume_token) - def next(self): + def next(self) -> _DocumentType: """Advance the cursor. This method blocks until the next change document is returned or an @@ -255,7 +269,7 @@ class ChangeStream(object): __next__ = next @property - def alive(self): + def alive(self) -> bool: """Does this cursor have the potential to return more data? .. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise @@ -265,7 +279,7 @@ class ChangeStream(object): """ return self._cursor.alive - def try_next(self): + def try_next(self) -> Optional[_DocumentType]: """Advance the cursor without blocking indefinitely. This method returns the next change document without waiting @@ -354,14 +368,14 @@ class ChangeStream(object): return _bson_to_dict(change.raw, self._orig_codec_options) return change - def __enter__(self): + def __enter__(self) -> "ChangeStream": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() -class CollectionChangeStream(ChangeStream): +class CollectionChangeStream(ChangeStream, Generic[_DocumentType]): """A change stream that watches changes on a single collection. Should not be called directly by application developers. Use @@ -378,7 +392,7 @@ class CollectionChangeStream(ChangeStream): return self._target.database.client -class DatabaseChangeStream(ChangeStream): +class DatabaseChangeStream(ChangeStream, Generic[_DocumentType]): """A change stream that watches changes on all collections in a database. Should not be called directly by application developers. Use @@ -395,7 +409,7 @@ class DatabaseChangeStream(ChangeStream): return self._target.client -class ClusterChangeStream(DatabaseChangeStream): +class ClusterChangeStream(DatabaseChangeStream, Generic[_DocumentType]): """A change stream that watches changes on all collections in the cluster. Should not be called directly by application developers. Use diff --git a/pymongo/client_options.py b/pymongo/client_options.py index c2f5ae01c..14ef0f781 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -15,9 +15,9 @@ """Tools to parse mongo client options.""" from bson.codec_options import _parse_codec_options +from pymongo import common from pymongo.auth import _build_credentials_tuple from pymongo.common import validate_boolean -from pymongo import common from pymongo.compression_support import CompressionSettings from pymongo.errors import ConfigurationError from pymongo.monitoring import _EventListeners diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 8c61623ae..3d4ad514e 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -134,25 +134,23 @@ Classes import collections import time import uuid - from collections.abc import Mapping as _Mapping +from typing import (TYPE_CHECKING, Any, Callable, ContextManager, Generic, + Mapping, Optional, TypeVar) from bson.binary import Binary from bson.int64 import Int64 from bson.son import SON from bson.timestamp import Timestamp - from pymongo.cursor import _SocketManager -from pymongo.errors import (ConfigurationError, - ConnectionFailure, - InvalidOperation, - OperationFailure, - PyMongoError, +from pymongo.errors import (ConfigurationError, ConnectionFailure, + InvalidOperation, OperationFailure, PyMongoError, WTimeoutError) from pymongo.helpers import _RETRYABLE_ERROR_CODES from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.server_type import SERVER_TYPE +from pymongo.typings import _DocumentType from pymongo.write_concern import WriteConcern @@ -172,10 +170,12 @@ class SessionOptions(object): .. versionchanged:: 3.12 Added the ``snapshot`` parameter. """ - def __init__(self, - causal_consistency=None, - default_transaction_options=None, - snapshot=False): + def __init__( + self, + causal_consistency: Optional[bool] = None, + default_transaction_options: Optional["TransactionOptions"] = None, + snapshot: Optional[bool] = False, + ) -> None: if snapshot: if causal_consistency: raise ConfigurationError('snapshot reads do not support ' @@ -194,12 +194,12 @@ class SessionOptions(object): self._snapshot = snapshot @property - def causal_consistency(self): + def causal_consistency(self) -> bool: """Whether causal consistency is configured.""" return self._causal_consistency @property - def default_transaction_options(self): + def default_transaction_options(self) -> Optional["TransactionOptions"]: """The default TransactionOptions to use for transactions started on this session. @@ -208,7 +208,7 @@ class SessionOptions(object): return self._default_transaction_options @property - def snapshot(self): + def snapshot(self) -> Optional[bool]: """Whether snapshot reads are configured. .. versionadded:: 3.12 @@ -243,8 +243,13 @@ class TransactionOptions(object): .. versionadded:: 3.7 """ - def __init__(self, read_concern=None, write_concern=None, - read_preference=None, max_commit_time_ms=None): + def __init__( + self, + read_concern: Optional[ReadConcern] = None, + write_concern: Optional[WriteConcern] = None, + read_preference: Optional[_ServerMode] = None, + max_commit_time_ms: Optional[int] = None + ) -> None: self._read_concern = read_concern self._write_concern = write_concern self._read_preference = read_preference @@ -274,23 +279,23 @@ class TransactionOptions(object): "max_commit_time_ms must be an integer or None") @property - def read_concern(self): + def read_concern(self) -> Optional[ReadConcern]: """This transaction's :class:`~pymongo.read_concern.ReadConcern`.""" return self._read_concern @property - def write_concern(self): + def write_concern(self) -> Optional[WriteConcern]: """This transaction's :class:`~pymongo.write_concern.WriteConcern`.""" return self._write_concern @property - def read_preference(self): + def read_preference(self) -> Optional[_ServerMode]: """This transaction's :class:`~pymongo.read_preferences.ReadPreference`. """ return self._read_preference @property - def max_commit_time_ms(self): + def max_commit_time_ms(self) -> Optional[int]: """The maxTimeMS to use when running a commitTransaction command. .. versionadded:: 3.9 @@ -427,7 +432,13 @@ def _within_time_limit(start_time): return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT -class ClientSession(object): +_T = TypeVar("_T") + +if TYPE_CHECKING: + from pymongo.mongo_client import MongoClient + + +class ClientSession(Generic[_DocumentType]): """A session for ordering sequential operations. :class:`ClientSession` instances are **not thread-safe or fork-safe**. @@ -439,9 +450,11 @@ class ClientSession(object): :class:`ClientSession`, call :meth:`~pymongo.mongo_client.MongoClient.start_session`. """ - def __init__(self, client, server_session, options, implicit): + def __init__( + self, client: "MongoClient[_DocumentType]", server_session: Any, options: SessionOptions, implicit: bool + ) -> None: # A MongoClient, a _ServerSession, a SessionOptions, and a set. - self._client = client + self._client: MongoClient[_DocumentType] = client self._server_session = server_session self._options = options self._cluster_time = None @@ -451,7 +464,7 @@ class ClientSession(object): self._implicit = implicit self._transaction = _Transaction(None, client) - def end_session(self): + def end_session(self) -> None: """Finish this session. If a transaction has started, abort it. It is an error to use the session after the session has ended. @@ -474,39 +487,39 @@ class ClientSession(object): if self._server_session is None: raise InvalidOperation("Cannot use ended session") - def __enter__(self): + def __enter__(self) -> "ClientSession[_DocumentType]": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self._end_session(lock=True) @property - def client(self): + def client(self) -> "MongoClient[_DocumentType]": """The :class:`~pymongo.mongo_client.MongoClient` this session was created from. """ return self._client @property - def options(self): + def options(self) -> SessionOptions: """The :class:`SessionOptions` this session was created with.""" return self._options @property - def session_id(self): + def session_id(self) -> Mapping[str, Any]: """A BSON document, the opaque server session identifier.""" self._check_ended() return self._server_session.session_id @property - def cluster_time(self): + def cluster_time(self) -> Optional[Mapping[str, Any]]: """The cluster time returned by the last operation executed in this session. """ return self._cluster_time @property - def operation_time(self): + def operation_time(self) -> Optional[Timestamp]: """The operation time returned by the last operation executed in this session. """ @@ -522,8 +535,14 @@ class ClientSession(object): return val return getattr(self.client, name) - def with_transaction(self, callback, read_concern=None, write_concern=None, - read_preference=None, max_commit_time_ms=None): + def with_transaction( + self, + callback: Callable[["ClientSession"], _T], + read_concern: Optional[ReadConcern] = None, + write_concern: Optional[WriteConcern] = None, + read_preference: Optional[_ServerMode] = None, + max_commit_time_ms: Optional[int] = None, + ) -> _T: """Execute a callback in a transaction. This method starts a transaction on this session, executes ``callback`` @@ -649,8 +668,13 @@ class ClientSession(object): # Commit succeeded. return ret - def start_transaction(self, read_concern=None, write_concern=None, - read_preference=None, max_commit_time_ms=None): + def start_transaction( + self, + read_concern: Optional[ReadConcern] = None, + write_concern: Optional[WriteConcern] = None, + read_preference: Optional[_ServerMode] = None, + max_commit_time_ms: Optional[int] = None, + ) -> ContextManager: """Start a multi-statement transaction. Takes the same arguments as :class:`TransactionOptions`. @@ -685,7 +709,7 @@ class ClientSession(object): self._start_retryable_write() return _TransactionContext(self) - def commit_transaction(self): + def commit_transaction(self) -> None: """Commit a multi-statement transaction. .. versionadded:: 3.7 @@ -729,7 +753,7 @@ class ClientSession(object): finally: self._transaction.state = _TxnState.COMMITTED - def abort_transaction(self): + def abort_transaction(self) -> None: """Abort a multi-statement transaction. .. versionadded:: 3.7 @@ -804,7 +828,7 @@ class ClientSession(object): if cluster_time["clusterTime"] > self._cluster_time["clusterTime"]: self._cluster_time = cluster_time - def advance_cluster_time(self, cluster_time): + def advance_cluster_time(self, cluster_time: Mapping[str, Any]) -> None: """Update the cluster time for this session. :Parameters: @@ -827,7 +851,7 @@ class ClientSession(object): if operation_time > self._operation_time: self._operation_time = operation_time - def advance_operation_time(self, operation_time): + def advance_operation_time(self, operation_time: Timestamp) -> None: """Update the operation time for this session. :Parameters: @@ -856,12 +880,12 @@ class ClientSession(object): self._transaction.recovery_token = recovery_token @property - def has_ended(self): + def has_ended(self) -> bool: """True if this session is finished.""" return self._server_session is None @property - def in_transaction(self): + def in_transaction(self) -> bool: """True if this session has an active multi-statement transaction. .. versionadded:: 3.10 diff --git a/pymongo/collation.py b/pymongo/collation.py index 873d60333..e398264ac 100644 --- a/pymongo/collation.py +++ b/pymongo/collation.py @@ -16,6 +16,7 @@ .. _collations: http://userguide.icu-project.org/collation/concepts """ +from typing import Any, Dict, Mapping, Optional, Union from pymongo import common @@ -151,18 +152,18 @@ class Collation(object): __slots__ = ("__document",) - def __init__(self, locale, - caseLevel=None, - caseFirst=None, - strength=None, - numericOrdering=None, - alternate=None, - maxVariable=None, - normalization=None, - backwards=None, - **kwargs): + def __init__(self, locale: str, + caseLevel: Optional[bool] = None, + caseFirst: Optional[str] = None, + strength: Optional[int] = None, + numericOrdering: Optional[bool] = None, + alternate: Optional[str] = None, + maxVariable: Optional[str] = None, + normalization: Optional[bool] = None, + backwards: Optional[bool] = None, + **kwargs: Any) -> None: locale = common.validate_string('locale', locale) - self.__document = {'locale': locale} + self.__document: Dict[str, Any] = {'locale': locale} if caseLevel is not None: self.__document['caseLevel'] = common.validate_boolean( 'caseLevel', caseLevel) @@ -190,7 +191,7 @@ class Collation(object): self.__document.update(kwargs) @property - def document(self): + def document(self) -> Dict[str, Any]: """The document representation of this collation. .. note:: @@ -204,16 +205,16 @@ class Collation(object): return 'Collation(%s)' % ( ', '.join('%s=%r' % (key, document[key]) for key in document),) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Collation): return self.document == other.document return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other -def validate_collation_or_none(value): +def validate_collation_or_none(value: Optional[Union[Mapping[str, Any], Collation]]) -> Optional[Dict[str, Any]]: if value is None: return None if isinstance(value, Collation): diff --git a/pymongo/collection.py b/pymongo/collection.py index ecb82a2ca..aa2d148fb 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -14,44 +14,45 @@ """Collection level utilities for Mongo.""" -import datetime -import warnings - from collections import abc +from typing import (TYPE_CHECKING, Any, Generic, Iterable, List, Mapping, + MutableMapping, Optional, Sequence, Tuple, Union) from bson.code import Code +from bson.codec_options import CodecOptions from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument -from bson.codec_options import CodecOptions from bson.son import SON -from pymongo import (common, - helpers, - message) +from bson.timestamp import Timestamp +from pymongo import common, helpers, message from pymongo.aggregation import (_CollectionAggregationCommand, _CollectionRawAggregationCommand) from pymongo.bulk import _Bulk -from pymongo.command_cursor import CommandCursor, RawBatchCommandCursor -from pymongo.collation import validate_collation_or_none from pymongo.change_stream import CollectionChangeStream +from pymongo.collation import validate_collation_or_none +from pymongo.command_cursor import CommandCursor, RawBatchCommandCursor from pymongo.cursor import Cursor, RawBatchCursor -from pymongo.errors import (ConfigurationError, - InvalidName, - InvalidOperation, +from pymongo.errors import (ConfigurationError, InvalidName, InvalidOperation, OperationFailure) from pymongo.helpers import _check_write_command_response from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS -from pymongo.operations import IndexModel -from pymongo.read_preferences import ReadPreference -from pymongo.results import (BulkWriteResult, - DeleteResult, - InsertOneResult, - InsertManyResult, - UpdateResult) +from pymongo.operations import (DeleteMany, DeleteOne, IndexModel, InsertOne, + ReplaceOne, UpdateMany, UpdateOne) +from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.results import (BulkWriteResult, DeleteResult, InsertManyResult, + InsertOneResult, UpdateResult) +from pymongo.typings import _CollationIn, _DocumentIn, _DocumentType, _Pipeline from pymongo.write_concern import WriteConcern _FIND_AND_MODIFY_DOC_FIELDS = {'value': 1} +_WriteOp = Union[InsertOne, DeleteOne, DeleteMany, ReplaceOne, UpdateOne, UpdateMany] +# Hint supports index name, "myIndex", or list of index pairs: [('x', 1), ('y', -1)] +_IndexList = Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]] +_IndexKeyHint = Union[str, _IndexList] + + class ReturnDocument(object): """An enum used with :meth:`~pymongo.collection.Collection.find_one_and_replace` and @@ -65,13 +66,28 @@ class ReturnDocument(object): """Return the updated/replaced or inserted document.""" -class Collection(common.BaseObject): +if TYPE_CHECKING: + from pymongo.client_session import ClientSession + from pymongo.database import Database + from pymongo.read_concern import ReadConcern + + +class Collection(common.BaseObject, Generic[_DocumentType]): """A Mongo collection. """ - def __init__(self, database, name, create=False, codec_options=None, - read_preference=None, write_concern=None, read_concern=None, - session=None, **kwargs): + def __init__( + self, + database: "Database[_DocumentType]", + name: str, + create: Optional[bool] = False, + codec_options: Optional[CodecOptions] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional["ReadConcern"] = None, + session: Optional["ClientSession"] = None, + **kwargs: Any, + ) -> None: """Get / create a Mongo collection. Raises :class:`TypeError` if `name` is not an instance of @@ -169,7 +185,7 @@ class Collection(common.BaseObject): "null character") collation = validate_collation_or_none(kwargs.pop('collation', None)) - self.__database = database + self.__database: Database[_DocumentType] = database self.__name = name self.__full_name = "%s.%s" % (self.__database.name, self.__name) if create or kwargs or collation: @@ -252,7 +268,7 @@ class Collection(common.BaseObject): write_concern=self._write_concern_for(session), collation=collation, session=session) - def __getattr__(self, name): + def __getattr__(self, name: str) -> "Collection[_DocumentType]": """Get a sub-collection of this collection by name. Raises InvalidName if an invalid collection name is used. @@ -268,7 +284,7 @@ class Collection(common.BaseObject): name, full_name, full_name)) return self.__getitem__(name) - def __getitem__(self, name): + def __getitem__(self, name: str) -> "Collection[_DocumentType]": return Collection(self.__database, "%s.%s" % (self.__name, name), False, @@ -280,25 +296,25 @@ class Collection(common.BaseObject): def __repr__(self): return "Collection(%r, %r)" % (self.__database, self.__name) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Collection): return (self.__database == other.database and self.__name == other.name) return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def __hash__(self): + def __hash__(self) -> int: return hash((self.__database, self.__name)) - def __bool__(self): + def __bool__(self) -> bool: raise NotImplementedError("Collection objects do not implement truth " "value testing or bool(). Please compare " "with None instead: collection is not None") @property - def full_name(self): + def full_name(self) -> str: """The full name of this :class:`Collection`. The full name is of the form `database_name.collection_name`. @@ -306,19 +322,24 @@ class Collection(common.BaseObject): return self.__full_name @property - def name(self): + def name(self) -> str: """The name of this :class:`Collection`.""" return self.__name @property - def database(self): + def database(self) -> "Database[_DocumentType]": """The :class:`~pymongo.database.Database` that this :class:`Collection` is a part of. """ return self.__database - def with_options(self, codec_options=None, read_preference=None, - write_concern=None, read_concern=None): + def with_options( + self, + codec_options: Optional[CodecOptions] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional["ReadConcern"] = None, + ) -> "Collection[_DocumentType]": """Get a clone of this collection changing the specified settings. >>> coll1.read_preference @@ -356,8 +377,13 @@ class Collection(common.BaseObject): write_concern or self.write_concern, read_concern or self.read_concern) - def bulk_write(self, requests, ordered=True, - bypass_document_validation=False, session=None): + def bulk_write( + self, + requests: Sequence[_WriteOp], + ordered: bool = True, + bypass_document_validation: bool = False, + session: Optional["ClientSession"] = None + ) -> BulkWriteResult: """Send a batch of write operations to the server. Requests are passed as a list of write operation instances ( @@ -470,8 +496,10 @@ class Collection(common.BaseObject): if not isinstance(doc, RawBSONDocument): return doc.get('_id') - def insert_one(self, document, bypass_document_validation=False, - session=None): + def insert_one(self, document: _DocumentIn, + bypass_document_validation: bool = False, + session: Optional["ClientSession"] = None + ) -> InsertOneResult: """Insert a single document. >>> db.test.count_documents({'x': 1}) @@ -520,8 +548,12 @@ class Collection(common.BaseObject): bypass_doc_val=bypass_document_validation, session=session), write_concern.acknowledged) - def insert_many(self, documents, ordered=True, - bypass_document_validation=False, session=None): + def insert_many(self, + documents: Iterable[_DocumentIn], + ordered: bool = True, + bypass_document_validation: bool = False, + session: Optional["ClientSession"] = None + ) -> InsertManyResult: """Insert an iterable of documents. >>> db.test.count_documents({}) @@ -565,7 +597,7 @@ class Collection(common.BaseObject): or isinstance(documents, abc.Mapping) or not documents): raise TypeError("documents must be a non-empty list") - inserted_ids = [] + inserted_ids: List[ObjectId] = [] def gen(): """A generator that validates documents and handles _ids.""" for document in documents: @@ -671,9 +703,16 @@ class Collection(common.BaseObject): (write_concern or self.write_concern).acknowledged and not multi, _update, session) - def replace_one(self, filter, replacement, upsert=False, - bypass_document_validation=False, collation=None, - hint=None, session=None, let=None): + def replace_one(self, + filter: Mapping[str, Any], + replacement: Mapping[str, Any], + upsert: bool = False, + bypass_document_validation: bool = False, + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + session: Optional["ClientSession"] = None, + let: Optional[Mapping[str, Any]] = None + ) -> UpdateResult: """Replace a single document matching the filter. >>> for doc in db.test.find({}): @@ -755,10 +794,17 @@ class Collection(common.BaseObject): collation=collation, hint=hint, session=session, let=let), write_concern.acknowledged) - def update_one(self, filter, update, upsert=False, - bypass_document_validation=False, - collation=None, array_filters=None, hint=None, - session=None, let=None): + def update_one(self, + filter: Mapping[str, Any], + update: Union[Mapping[str, Any], _Pipeline], + upsert: bool = False, + bypass_document_validation: bool = False, + collation: Optional[_CollationIn] = None, + array_filters: Optional[Sequence[Mapping[str, Any]]] = None, + hint: Optional[_IndexKeyHint] = None, + session: Optional["ClientSession"] = None, + let: Optional[Mapping[str, Any]] = None + ) -> UpdateResult: """Update a single document matching the filter. >>> for doc in db.test.find(): @@ -800,8 +846,8 @@ class Collection(common.BaseObject): - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `let` (optional): Map of parameter names and values. Values must be - constant or closed expressions that do not reference document - fields. Parameters can then be accessed as variables in an + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). :Returns: @@ -836,9 +882,17 @@ class Collection(common.BaseObject): hint=hint, session=session, let=let), write_concern.acknowledged) - def update_many(self, filter, update, upsert=False, array_filters=None, - bypass_document_validation=False, collation=None, - hint=None, session=None, let=None): + def update_many(self, + filter: Mapping[str, Any], + update: Union[Mapping[str, Any], _Pipeline], + upsert: bool = False, + array_filters: Optional[Sequence[Mapping[str, Any]]] = None, + bypass_document_validation: Optional[bool] = None, + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + session: Optional["ClientSession"] = None, + let: Optional[Mapping[str, Any]] = None + ) -> UpdateResult: """Update one or more documents that match the filter. >>> for doc in db.test.find(): @@ -880,8 +934,8 @@ class Collection(common.BaseObject): - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `let` (optional): Map of parameter names and values. Values must be - constant or closed expressions that do not reference document - fields. Parameters can then be accessed as variables in an + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). :Returns: @@ -916,7 +970,7 @@ class Collection(common.BaseObject): hint=hint, session=session, let=let), write_concern.acknowledged) - def drop(self, session=None): + def drop(self, session: Optional["ClientSession"] = None) -> None: """Alias for :meth:`~pymongo.database.Database.drop_collection`. :Parameters: @@ -1005,8 +1059,13 @@ class Collection(common.BaseObject): (write_concern or self.write_concern).acknowledged and not multi, _delete, session) - def delete_one(self, filter, collation=None, hint=None, session=None, - let=None): + def delete_one(self, + filter: Mapping[str, Any], + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + session: Optional["ClientSession"] = None, + let: Optional[Mapping[str, Any]] = None + ) -> DeleteResult: """Delete a single document matching the filter. >>> db.test.count_documents({'x': 1}) @@ -1030,8 +1089,8 @@ class Collection(common.BaseObject): - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `let` (optional): Map of parameter names and values. Values must be - constant or closed expressions that do not reference document - fields. Parameters can then be accessed as variables in an + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). :Returns: @@ -1055,8 +1114,13 @@ class Collection(common.BaseObject): collation=collation, hint=hint, session=session, let=let), write_concern.acknowledged) - def delete_many(self, filter, collation=None, hint=None, session=None, - let=None): + def delete_many(self, + filter: Mapping[str, Any], + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + session: Optional["ClientSession"] = None, + let: Optional[Mapping[str, Any]] = None + ) -> DeleteResult: """Delete one or more documents matching the filter. >>> db.test.count_documents({'x': 1}) @@ -1080,8 +1144,8 @@ class Collection(common.BaseObject): - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `let` (optional): Map of parameter names and values. Values must be - constant or closed expressions that do not reference document - fields. Parameters can then be accessed as variables in an + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). :Returns: @@ -1105,7 +1169,7 @@ class Collection(common.BaseObject): collation=collation, hint=hint, session=session, let=let), write_concern.acknowledged) - def find_one(self, filter=None, *args, **kwargs): + def find_one(self, filter: Optional[Any] = None, *args: Any, **kwargs: Any) -> Optional[_DocumentType]: """Get a single document from the database. All arguments to :meth:`find` are also valid arguments for @@ -1139,7 +1203,7 @@ class Collection(common.BaseObject): return result return None - def find(self, *args, **kwargs): + def find(self, *args: Any, **kwargs: Any) -> Cursor[_DocumentType]: """Query the database. The `filter` argument is a prototype document that all results @@ -1328,7 +1392,7 @@ class Collection(common.BaseObject): """ return Cursor(self, *args, **kwargs) - def find_raw_batches(self, *args, **kwargs): + def find_raw_batches(self, *args: Any, **kwargs: Any) -> RawBatchCursor[_DocumentType]: """Query the database and retrieve batches of raw BSON. Similar to the :meth:`find` method but returns a @@ -1396,7 +1460,7 @@ class Collection(common.BaseObject): batch = result['cursor']['firstBatch'] return batch[0] if batch else None - def estimated_document_count(self, **kwargs): + def estimated_document_count(self, **kwargs: Any) -> int: """Get an estimate of the number of documents in this collection using collection metadata. @@ -1445,7 +1509,7 @@ class Collection(common.BaseObject): return self.__database.client._retryable_read( _cmd, self.read_preference, None) - def count_documents(self, filter, session=None, **kwargs): + def count_documents(self, filter: Mapping[str, Any], session: Optional["ClientSession"] = None, **kwargs: Any) -> int: """Count the number of documents in this collection. .. note:: For a fast count of the total documents in a collection see @@ -1523,7 +1587,7 @@ class Collection(common.BaseObject): return self.__database.client._retryable_read( _cmd, self._read_preference_for(session), session) - def create_indexes(self, indexes, session=None, **kwargs): + def create_indexes(self, indexes: Sequence[IndexModel], session: Optional["ClientSession"] = None, **kwargs: Any) -> List[str]: """Create one or more indexes on this collection. >>> from pymongo import IndexModel, ASCENDING, DESCENDING @@ -1598,7 +1662,7 @@ class Collection(common.BaseObject): session=session) return names - def create_index(self, keys, session=None, **kwargs): + def create_index(self, keys: _IndexKeyHint, session: Optional["ClientSession"] = None, **kwargs: Any) -> str: """Creates an index on this collection. Takes either a single key or a list of (key, direction) pairs. @@ -1701,7 +1765,7 @@ class Collection(common.BaseObject): index = IndexModel(keys, **kwargs) return self.__create_indexes([index], session, **cmd_options)[0] - def drop_indexes(self, session=None, **kwargs): + def drop_indexes(self, session: Optional["ClientSession"] = None, **kwargs: Any) -> None: """Drops all indexes on this collection. Can be used on non-existant collections or collections with no indexes. @@ -1727,7 +1791,7 @@ class Collection(common.BaseObject): """ self.drop_index("*", session=session, **kwargs) - def drop_index(self, index_or_name, session=None, **kwargs): + def drop_index(self, index_or_name: _IndexKeyHint, session: Optional["ClientSession"] = None, **kwargs: Any) -> None: """Drops the specified index on this collection. Can be used on non-existant collections or collections with no @@ -1780,7 +1844,7 @@ class Collection(common.BaseObject): write_concern=self._write_concern_for(session), session=session) - def list_indexes(self, session=None): + def list_indexes(self, session: Optional["ClientSession"] = None) -> CommandCursor[MutableMapping[str, Any]]: """Get a cursor over the index documents for this collection. >>> for index in db.test.list_indexes(): @@ -1829,7 +1893,7 @@ class Collection(common.BaseObject): return self.__database.client._retryable_read( _cmd, read_pref, session) - def index_information(self, session=None): + def index_information(self, session: Optional["ClientSession"] = None) -> MutableMapping[str, Any]: """Get information on this collection's indexes. Returns a dictionary where the keys are index names (as @@ -1863,7 +1927,7 @@ class Collection(common.BaseObject): info[index.pop("name")] = index return info - def options(self, session=None): + def options(self, session: Optional["ClientSession"] = None) -> MutableMapping[str, Any]: """Get the options set on this collection. Returns a dictionary of options and their values - see @@ -1896,6 +1960,7 @@ class Collection(common.BaseObject): return {} options = result.get("options", {}) + assert options is not None if "create" in options: del options["create"] @@ -1911,7 +1976,7 @@ class Collection(common.BaseObject): cmd.get_cursor, cmd.get_read_preference(session), session, retryable=not cmd._performs_write) - def aggregate(self, pipeline, session=None, let=None, **kwargs): + def aggregate(self, pipeline: _Pipeline, session: Optional["ClientSession"] = None, let: Optional[Mapping[str, Any]] = None, **kwargs: Any) -> CommandCursor[_DocumentType]: """Perform an aggregation using the aggregation framework on this collection. @@ -1993,7 +2058,9 @@ class Collection(common.BaseObject): let=let, **kwargs) - def aggregate_raw_batches(self, pipeline, session=None, **kwargs): + def aggregate_raw_batches( + self, pipeline: _Pipeline, session: Optional["ClientSession"] = None, **kwargs: Any + ) -> RawBatchCursor[_DocumentType]: """Perform an aggregation and retrieve batches of raw BSON. Similar to the :meth:`aggregate` method but returns a @@ -2030,9 +2097,17 @@ class Collection(common.BaseObject): explicit_session=session is not None, **kwargs) - def watch(self, pipeline=None, full_document=None, resume_after=None, - max_await_time_ms=None, batch_size=None, collation=None, - start_at_operation_time=None, session=None, start_after=None): + def watch(self, + pipeline: Optional[_Pipeline] = None, + full_document: Optional[str] = None, + resume_after: Optional[Mapping[str, Any]] = None, + max_await_time_ms: Optional[int] = None, + batch_size: Optional[int] = None, + collation: Optional[_CollationIn] = None, + start_at_operation_time: Optional[Timestamp] = None, + session: Optional["ClientSession"] = None, + start_after: Optional[Mapping[str, Any]] = None, + ) -> CollectionChangeStream[_DocumentType]: """Watch changes on this collection. Performs an aggregation with an implicit initial ``$changeStream`` @@ -2132,7 +2207,7 @@ class Collection(common.BaseObject): batch_size, collation, start_at_operation_time, session, start_after) - def rename(self, new_name, session=None, **kwargs): + def rename(self, new_name: str, session: Optional["ClientSession"] = None, **kwargs: Any) -> MutableMapping[str, Any]: """Rename this collection. If operating in auth mode, client must be authorized as an @@ -2183,7 +2258,9 @@ class Collection(common.BaseObject): parse_write_concern_error=True, session=s, client=self.__database.client) - def distinct(self, key, filter=None, session=None, **kwargs): + def distinct( + self, key: str, filter: Optional[Mapping[str, Any]] = None, session: Optional["ClientSession"] = None, **kwargs: Any + ) -> List: """Get a list of distinct values for `key` among all documents in this collection. @@ -2283,7 +2360,7 @@ class Collection(common.BaseObject): raise ConfigurationError( 'arrayFilters is unsupported for unacknowledged ' 'writes.') - cmd["arrayFilters"] = array_filters + cmd["arrayFilters"] = list(array_filters) if hint is not None: if sock_info.max_wire_version < 8: raise ConfigurationError( @@ -2307,9 +2384,15 @@ class Collection(common.BaseObject): return self.__database.client._retryable_write( write_concern.acknowledged, _find_and_modify, session) - def find_one_and_delete(self, filter, - projection=None, sort=None, hint=None, - session=None, let=None, **kwargs): + def find_one_and_delete(self, + filter: Mapping[str, Any], + projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None, + sort: Optional[_IndexList] = None, + hint: Optional[_IndexKeyHint] = None, + session: Optional["ClientSession"] = None, + let: Optional[Mapping[str, Any]] = None, + **kwargs: Any, + ) -> _DocumentType: """Finds a single document and deletes it, returning the document. >>> db.test.count_documents({'x': 1}) @@ -2357,8 +2440,8 @@ class Collection(common.BaseObject): as keyword arguments (for example maxTimeMS can be used with recent server versions). - `let` (optional): Map of parameter names and values. Values must be - constant or closed expressions that do not reference document - fields. Parameters can then be accessed as variables in an + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). .. versionchanged:: 4.1 @@ -2384,10 +2467,18 @@ class Collection(common.BaseObject): return self.__find_and_modify(filter, projection, sort, let=let, hint=hint, session=session, **kwargs) - def find_one_and_replace(self, filter, replacement, - projection=None, sort=None, upsert=False, - return_document=ReturnDocument.BEFORE, - hint=None, session=None, let=None, **kwargs): + def find_one_and_replace(self, + filter: Mapping[str, Any], + replacement: Mapping[str, Any], + projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None, + sort: Optional[_IndexList] = None, + upsert: bool = False, + return_document: bool = ReturnDocument.BEFORE, + hint: Optional[_IndexKeyHint] = None, + session: Optional["ClientSession"] = None, + let: Optional[Mapping[str, Any]] = None, + **kwargs: Any, + ) -> _DocumentType: """Finds a single document and replaces it, returning either the original or the replaced document. @@ -2438,8 +2529,8 @@ class Collection(common.BaseObject): - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `let` (optional): Map of parameter names and values. Values must be - constant or closed expressions that do not reference document - fields. Parameters can then be accessed as variables in an + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). - `**kwargs` (optional): additional command arguments can be passed as keyword arguments (for example maxTimeMS can be used with @@ -2470,11 +2561,19 @@ class Collection(common.BaseObject): sort, upsert, return_document, let=let, hint=hint, session=session, **kwargs) - def find_one_and_update(self, filter, update, - projection=None, sort=None, upsert=False, - return_document=ReturnDocument.BEFORE, - array_filters=None, hint=None, session=None, - let=None, **kwargs): + def find_one_and_update(self, + filter: Mapping[str, Any], + update: Union[Mapping[str, Any], _Pipeline], + projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None, + sort: Optional[_IndexList] = None, + upsert: bool = False, + return_document: bool = ReturnDocument.BEFORE, + array_filters: Optional[Sequence[Mapping[str, Any]]] = None, + hint: Optional[_IndexKeyHint] = None, + session: Optional["ClientSession"] = None, + let: Optional[Mapping[str, Any]] = None, + **kwargs: Any, + ) -> _DocumentType: """Finds a single document and updates it, returning either the original or the updated document. @@ -2564,8 +2663,8 @@ class Collection(common.BaseObject): - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `let` (optional): Map of parameter names and values. Values must be - constant or closed expressions that do not reference document - fields. Parameters can then be accessed as variables in an + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). - `**kwargs` (optional): additional command arguments can be passed as keyword arguments (for example maxTimeMS can be used with @@ -2600,15 +2699,15 @@ class Collection(common.BaseObject): array_filters, hint=hint, let=let, session=session, **kwargs) - def __iter__(self): + def __iter__(self) -> "Collection[_DocumentType]": return self - def __next__(self): + def __next__(self) -> None: raise TypeError("'Collection' object is not iterable") next = __next__ - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> None: """This is only here so that some API misusages are easier to debug. """ if "." not in self.__name: diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index 21822ac61..b7dbf7a8e 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -15,28 +15,38 @@ """CommandCursor class to iterate over command results.""" from collections import deque +from typing import (TYPE_CHECKING, Any, Generic, Iterator, Mapping, Optional, + Tuple) from bson import _convert_raw_document_lists_to_streams -from pymongo.cursor import _SocketManager, _CURSOR_CLOSED_ERRORS -from pymongo.errors import (ConnectionFailure, - InvalidOperation, +from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _SocketManager +from pymongo.errors import (ConnectionFailure, InvalidOperation, OperationFailure) -from pymongo.message import (_CursorAddress, - _GetMore, - _RawBatchGetMore) +from pymongo.message import _CursorAddress, _GetMore, _RawBatchGetMore from pymongo.response import PinnedResponse +from pymongo.typings import _DocumentType + +if TYPE_CHECKING: + from pymongo.client_session import ClientSession + from pymongo.collection import Collection -class CommandCursor(object): +class CommandCursor(Generic[_DocumentType]): """A cursor / iterator over command cursors.""" _getmore_class = _GetMore - def __init__(self, collection, cursor_info, address, - batch_size=0, max_await_time_ms=None, session=None, - explicit_session=False): + def __init__(self, + collection: "Collection[_DocumentType]", + cursor_info: Mapping[str, Any], + address: Optional[Tuple[str, Optional[int]]], + batch_size: int = 0, + max_await_time_ms: Optional[int] = None, + session: Optional["ClientSession"] = None, + explicit_session: bool = False, + ) -> None: """Create a new command cursor.""" - self.__sock_mgr = None - self.__collection = collection + self.__sock_mgr: Any = None + self.__collection: Collection[_DocumentType] = collection self.__id = cursor_info['id'] self.__data = deque(cursor_info['firstBatch']) self.__postbatchresumetoken = cursor_info.get('postBatchResumeToken') @@ -60,7 +70,7 @@ class CommandCursor(object): and max_await_time_ms is not None): raise TypeError("max_await_time_ms must be an integer or None") - def __del__(self): + def __del__(self) -> None: self.__die() def __die(self, synchronous=False): @@ -92,12 +102,12 @@ class CommandCursor(object): self.__session._end_session(lock=synchronous) self.__session = None - def close(self): + def close(self) -> None: """Explicitly close / kill this cursor. """ self.__die(True) - def batch_size(self, batch_size): + def batch_size(self, batch_size: int) -> "CommandCursor[_DocumentType]": """Limits the number of documents returned in one batch. Each batch requires a round trip to the server. It can be adjusted to optimize performance and limit data transfer. @@ -222,7 +232,7 @@ class CommandCursor(object): return len(self.__data) @property - def alive(self): + def alive(self) -> bool: """Does this cursor have the potential to return more data? Even if :attr:`alive` is ``True``, :meth:`next` can raise @@ -239,12 +249,12 @@ class CommandCursor(object): return bool(len(self.__data) or (not self.__killed)) @property - def cursor_id(self): + def cursor_id(self) -> int: """Returns the id of the cursor.""" return self.__id @property - def address(self): + def address(self) -> Optional[Tuple[str, Optional[int]]]: """The (host, port) of the server used, or None. .. versionadded:: 3.0 @@ -252,18 +262,19 @@ class CommandCursor(object): return self.__address @property - def session(self): + def session(self) -> Optional["ClientSession"]: """The cursor's :class:`~pymongo.client_session.ClientSession`, or None. .. versionadded:: 3.6 """ if self.__explicit_session: return self.__session + return None - def __iter__(self): + def __iter__(self) -> Iterator[_DocumentType]: return self - def next(self): + def next(self) -> _DocumentType: """Advance the cursor.""" # Block until a document is returnable. while self.alive: @@ -284,19 +295,25 @@ class CommandCursor(object): else: return None - def __enter__(self): + def __enter__(self) -> "CommandCursor[_DocumentType]": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() -class RawBatchCommandCursor(CommandCursor): +class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]): _getmore_class = _RawBatchGetMore - def __init__(self, collection, cursor_info, address, - batch_size=0, max_await_time_ms=None, session=None, - explicit_session=False): + def __init__(self, + collection: "Collection[_DocumentType]", + cursor_info: Mapping[str, Any], + address: Optional[Tuple[str, Optional[int]]], + batch_size: int = 0, + max_await_time_ms: Optional[int] = None, + session: Optional["ClientSession"] = None, + explicit_session: bool = False, + ) -> None: """Create a new cursor / iterator over raw batches of BSON data. Should not be called directly by application developers - diff --git a/pymongo/common.py b/pymongo/common.py index 14789c810..fa2fe9bf1 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -17,8 +17,9 @@ import datetime import warnings - -from collections import abc, OrderedDict +from collections import OrderedDict, abc +from typing import (Any, Callable, Dict, List, Mapping, MutableMapping, + Optional, Sequence, Tuple, Type, Union, cast) from urllib.parse import unquote_plus from bson import SON @@ -29,18 +30,18 @@ from pymongo.auth import MECHANISMS from pymongo.compression_support import (validate_compressors, validate_zlib_compression_level) from pymongo.driver_info import DriverInfo -from pymongo.server_api import ServerApi from pymongo.errors import ConfigurationError from pymongo.monitoring import _validate_event_listeners from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _MONGOS_MODES, _ServerMode +from pymongo.server_api import ServerApi from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern -ORDERED_TYPES = (SON, OrderedDict) +ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict) # Defaults until we connect to a server and get updated limits. MAX_BSON_SIZE = 16 * (1024 ** 2) -MAX_MESSAGE_SIZE = 2 * MAX_BSON_SIZE +MAX_MESSAGE_SIZE: int = 2 * MAX_BSON_SIZE MIN_WIRE_VERSION = 0 MAX_WIRE_VERSION = 0 MAX_WRITE_BATCH_SIZE = 1000 @@ -85,13 +86,13 @@ MIN_POOL_SIZE = 0 MAX_CONNECTING = 2 # Default value for maxIdleTimeMS. -MAX_IDLE_TIME_MS = None +MAX_IDLE_TIME_MS: Optional[int] = None # Default value for maxIdleTimeMS in seconds. -MAX_IDLE_TIME_SEC = None +MAX_IDLE_TIME_SEC: Optional[int] = None # Default value for waitQueueTimeoutMS in seconds. -WAIT_QUEUE_TIMEOUT = None +WAIT_QUEUE_TIMEOUT: Optional[int] = None # Default value for localThresholdMS. LOCAL_THRESHOLD_MS = 15 @@ -103,10 +104,10 @@ RETRY_WRITES = True RETRY_READS = True # The error code returned when a command doesn't exist. -COMMAND_NOT_FOUND_CODES = (59,) +COMMAND_NOT_FOUND_CODES: Sequence[int] = (59,) # Error codes to ignore if GridFS calls createIndex on a secondary -UNAUTHORIZED_CODES = (13, 16547, 16548) +UNAUTHORIZED_CODES: Sequence[int] = (13, 16547, 16548) # Maximum number of sessions to send in a single endSessions command. # From the driver sessions spec. @@ -116,7 +117,7 @@ _MAX_END_SESSIONS = 10000 SRV_SERVICE_NAME = "mongodb" -def partition_node(node): +def partition_node(node: str) -> Tuple[str, int]: """Split a host:port string into (host, int(port)) pair.""" host = node port = 27017 @@ -128,7 +129,7 @@ def partition_node(node): return host, port -def clean_node(node): +def clean_node(node: str) -> Tuple[str, int]: """Split and normalize a node name from a hello response.""" host, port = partition_node(node) @@ -139,7 +140,7 @@ def clean_node(node): return host.lower(), port -def raise_config_error(key, dummy): +def raise_config_error(key: str, dummy: Any) -> None: """Raise ConfigurationError with the given key name.""" raise ConfigurationError("Unknown option %s" % (key,)) @@ -154,14 +155,14 @@ _UUID_REPRESENTATIONS = { } -def validate_boolean(option, value): +def validate_boolean(option: str, value: Any) -> bool: """Validates that 'value' is True or False.""" if isinstance(value, bool): return value raise TypeError("%s must be True or False" % (option,)) -def validate_boolean_or_string(option, value): +def validate_boolean_or_string(option: str, value: Any) -> bool: """Validates that value is True, False, 'true', or 'false'.""" if isinstance(value, str): if value not in ('true', 'false'): @@ -171,7 +172,7 @@ def validate_boolean_or_string(option, value): return validate_boolean(option, value) -def validate_integer(option, value): +def validate_integer(option: str, value: Any) -> int: """Validates that 'value' is an integer (or basestring representation). """ if isinstance(value, int): @@ -185,7 +186,7 @@ def validate_integer(option, value): raise TypeError("Wrong type for %s, value must be an integer" % (option,)) -def validate_positive_integer(option, value): +def validate_positive_integer(option: str, value: Any) -> int: """Validate that 'value' is a positive integer, which does not include 0. """ val = validate_integer(option, value) @@ -195,7 +196,7 @@ def validate_positive_integer(option, value): return val -def validate_non_negative_integer(option, value): +def validate_non_negative_integer(option: str, value: Any) -> int: """Validate that 'value' is a positive integer or 0. """ val = validate_integer(option, value) @@ -205,7 +206,7 @@ def validate_non_negative_integer(option, value): return val -def validate_readable(option, value): +def validate_readable(option: str, value: Any) -> Optional[str]: """Validates that 'value' is file-like and readable. """ if value is None: @@ -217,7 +218,7 @@ def validate_readable(option, value): return value -def validate_positive_integer_or_none(option, value): +def validate_positive_integer_or_none(option: str, value: Any) -> Optional[int]: """Validate that 'value' is a positive integer or None. """ if value is None: @@ -225,7 +226,7 @@ def validate_positive_integer_or_none(option, value): return validate_positive_integer(option, value) -def validate_non_negative_integer_or_none(option, value): +def validate_non_negative_integer_or_none(option: str, value: Any) -> Optional[int]: """Validate that 'value' is a positive integer or 0 or None. """ if value is None: @@ -233,7 +234,7 @@ def validate_non_negative_integer_or_none(option, value): return validate_non_negative_integer(option, value) -def validate_string(option, value): +def validate_string(option: str, value: Any) -> str: """Validates that 'value' is an instance of `str`. """ if isinstance(value, str): @@ -242,7 +243,7 @@ def validate_string(option, value): "str" % (option,)) -def validate_string_or_none(option, value): +def validate_string_or_none(option: str, value: Any) -> Optional[str]: """Validates that 'value' is an instance of `basestring` or `None`. """ if value is None: @@ -250,7 +251,7 @@ def validate_string_or_none(option, value): return validate_string(option, value) -def validate_int_or_basestring(option, value): +def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]: """Validates that 'value' is an integer or string. """ if isinstance(value, int): @@ -264,7 +265,7 @@ def validate_int_or_basestring(option, value): "integer or a string" % (option,)) -def validate_non_negative_int_or_basestring(option, value): +def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]: """Validates that 'value' is an integer or string. """ if isinstance(value, int): @@ -279,7 +280,7 @@ def validate_non_negative_int_or_basestring(option, value): "non negative integer or a string" % (option,)) -def validate_positive_float(option, value): +def validate_positive_float(option: str, value: Any) -> float: """Validates that 'value' is a float, or can be converted to one, and is positive. """ @@ -299,7 +300,7 @@ def validate_positive_float(option, value): return value -def validate_positive_float_or_zero(option, value): +def validate_positive_float_or_zero(option: str, value: Any) -> float: """Validates that 'value' is 0 or a positive float, or can be converted to 0 or a positive float. """ @@ -308,7 +309,7 @@ def validate_positive_float_or_zero(option, value): return validate_positive_float(option, value) -def validate_timeout_or_none(option, value): +def validate_timeout_or_none(option: str, value: Any) -> Optional[float]: """Validates a timeout specified in milliseconds returning a value in floating point seconds. """ @@ -317,7 +318,7 @@ def validate_timeout_or_none(option, value): return validate_positive_float(option, value) / 1000.0 -def validate_timeout_or_zero(option, value): +def validate_timeout_or_zero(option: str, value: Any) -> float: """Validates a timeout specified in milliseconds returning a value in floating point seconds for the case where None is an error and 0 is valid. Setting the timeout to nothing in the URI string is a @@ -330,7 +331,7 @@ def validate_timeout_or_zero(option, value): return validate_positive_float(option, value) / 1000.0 -def validate_timeout_or_none_or_zero(option, value): +def validate_timeout_or_none_or_zero(option: Any, value: Any) -> Optional[float]: """Validates a timeout specified in milliseconds returning a value in floating point seconds. value=0 and value="0" are treated the same as value=None which means unlimited timeout. @@ -340,7 +341,7 @@ def validate_timeout_or_none_or_zero(option, value): return validate_positive_float(option, value) / 1000.0 -def validate_max_staleness(option, value): +def validate_max_staleness(option: str, value: Any) -> int: """Validates maxStalenessSeconds according to the Max Staleness Spec.""" if value == -1 or value == "-1": # Default: No maximum staleness. @@ -348,7 +349,7 @@ def validate_max_staleness(option, value): return validate_positive_integer(option, value) -def validate_read_preference(dummy, value): +def validate_read_preference(dummy: Any, value: Any) -> _ServerMode: """Validate a read preference. """ if not isinstance(value, _ServerMode): @@ -356,7 +357,7 @@ def validate_read_preference(dummy, value): return value -def validate_read_preference_mode(dummy, value): +def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode: """Validate read preference mode for a MongoClient. .. versionchanged:: 3.5 @@ -368,7 +369,7 @@ def validate_read_preference_mode(dummy, value): return value -def validate_auth_mechanism(option, value): +def validate_auth_mechanism(option: str, value: Any) -> str: """Validate the authMechanism URI option. """ if value not in MECHANISMS: @@ -376,7 +377,7 @@ def validate_auth_mechanism(option, value): return value -def validate_uuid_representation(dummy, value): +def validate_uuid_representation(dummy: Any, value: Any) -> int: """Validate the uuid representation option selected in the URI. """ try: @@ -387,13 +388,13 @@ def validate_uuid_representation(dummy, value): "%s" % (value, tuple(_UUID_REPRESENTATIONS))) -def validate_read_preference_tags(name, value): +def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]]: """Parse readPreferenceTags if passed as a client kwarg. """ if not isinstance(value, list): value = [value] - tag_sets = [] + tag_sets: List = [] for tag_set in value: if tag_set == '': tag_sets.append({}) @@ -416,10 +417,10 @@ _MECHANISM_PROPS = frozenset(['SERVICE_NAME', 'AWS_SESSION_TOKEN']) -def validate_auth_mechanism_properties(option, value): +def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Union[bool, str]]: """Validate authMechanismProperties.""" value = validate_string(option, value) - props = {} + props: Dict[str, Any] = {} for opt in value.split(','): try: key, val = opt.split(':') @@ -443,7 +444,7 @@ def validate_auth_mechanism_properties(option, value): return props -def validate_document_class(option, value): +def validate_document_class(option: str, value: Any) -> Union[Type[MutableMapping], Type[RawBSONDocument]]: """Validate the document_class option.""" if not issubclass(value, (abc.MutableMapping, RawBSONDocument)): raise TypeError("%s must be dict, bson.son.SON, " @@ -452,7 +453,7 @@ def validate_document_class(option, value): return value -def validate_type_registry(option, value): +def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]: """Validate the type_registry option.""" if value is not None and not isinstance(value, TypeRegistry): raise TypeError("%s must be an instance of %s" % ( @@ -460,21 +461,21 @@ def validate_type_registry(option, value): return value -def validate_list(option, value): +def validate_list(option: str, value: Any) -> List: """Validates that 'value' is a list.""" if not isinstance(value, list): raise TypeError("%s must be a list" % (option,)) return value -def validate_list_or_none(option, value): +def validate_list_or_none(option: Any, value: Any) -> Optional[List]: """Validates that 'value' is a list or None.""" if value is None: return value return validate_list(option, value) -def validate_list_or_mapping(option, value): +def validate_list_or_mapping(option: Any, value: Any) -> None: """Validates that 'value' is a list or a document.""" if not isinstance(value, (abc.Mapping, list)): raise TypeError("%s must either be a list or an instance of dict, " @@ -482,7 +483,7 @@ def validate_list_or_mapping(option, value): "collections.Mapping" % (option,)) -def validate_is_mapping(option, value): +def validate_is_mapping(option: str, value: Any) -> None: """Validate the type of method arguments that expect a document.""" if not isinstance(value, abc.Mapping): raise TypeError("%s must be an instance of dict, bson.son.SON, or " @@ -490,7 +491,7 @@ def validate_is_mapping(option, value): "collections.Mapping" % (option,)) -def validate_is_document_type(option, value): +def validate_is_document_type(option: str, value: Any) -> None: """Validate the type of method arguments that expect a MongoDB document.""" if not isinstance(value, (abc.MutableMapping, RawBSONDocument)): raise TypeError("%s must be an instance of dict, bson.son.SON, " @@ -499,7 +500,7 @@ def validate_is_document_type(option, value): "collections.MutableMapping" % (option,)) -def validate_appname_or_none(option, value): +def validate_appname_or_none(option: str, value: Any) -> Optional[str]: """Validate the appname option.""" if value is None: return value @@ -510,7 +511,7 @@ def validate_appname_or_none(option, value): return value -def validate_driver_or_none(option, value): +def validate_driver_or_none(option: Any, value: Any) -> Optional[DriverInfo]: """Validate the driver keyword arg.""" if value is None: return value @@ -519,7 +520,7 @@ def validate_driver_or_none(option, value): return value -def validate_server_api_or_none(option, value): +def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]: """Validate the server_api keyword arg.""" if value is None: return value @@ -528,7 +529,7 @@ def validate_server_api_or_none(option, value): return value -def validate_is_callable_or_none(option, value): +def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]: """Validates that 'value' is a callable.""" if value is None: return value @@ -537,7 +538,7 @@ def validate_is_callable_or_none(option, value): return value -def validate_ok_for_replace(replacement): +def validate_ok_for_replace(replacement: Mapping[str, Any]) -> None: """Validate a replacement document.""" validate_is_mapping("replacement", replacement) # Replacement can be {} @@ -547,7 +548,7 @@ def validate_ok_for_replace(replacement): raise ValueError('replacement can not include $ operators') -def validate_ok_for_update(update): +def validate_ok_for_update(update: Any) -> None: """Validate an update document.""" validate_list_or_mapping("update", update) # Update cannot be {}. @@ -563,7 +564,7 @@ def validate_ok_for_update(update): _UNICODE_DECODE_ERROR_HANDLERS = frozenset(['strict', 'replace', 'ignore']) -def validate_unicode_decode_error_handler(dummy, value): +def validate_unicode_decode_error_handler(dummy: Any, value: str) -> str: """Validate the Unicode decode error handler option of CodecOptions. """ if value not in _UNICODE_DECODE_ERROR_HANDLERS: @@ -573,7 +574,7 @@ def validate_unicode_decode_error_handler(dummy, value): return value -def validate_tzinfo(dummy, value): +def validate_tzinfo(dummy: Any, value: Any) -> Optional[datetime.tzinfo]: """Validate the tzinfo option """ if value is not None and not isinstance(value, datetime.tzinfo): @@ -581,7 +582,7 @@ def validate_tzinfo(dummy, value): return value -def validate_auto_encryption_opts_or_none(option, value): +def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[Any]: """Validate the driver keyword arg.""" if value is None: return value @@ -595,7 +596,7 @@ def validate_auto_encryption_opts_or_none(option, value): # Dictionary where keys are the names of public URI options, and values # are lists of aliases for that option. -URI_OPTIONS_ALIAS_MAP = { +URI_OPTIONS_ALIAS_MAP: Dict[str, List[str]] = { 'tls': ['ssl'], } @@ -603,7 +604,7 @@ URI_OPTIONS_ALIAS_MAP = { # are functions that validate user-input values for that option. If an option # alias uses a different validator than its public counterpart, it should be # included here as a key, value pair. -URI_OPTIONS_VALIDATOR_MAP = { +URI_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = { 'appname': validate_appname_or_none, 'authmechanism': validate_auth_mechanism, 'authmechanismproperties': validate_auth_mechanism_properties, @@ -644,7 +645,7 @@ URI_OPTIONS_VALIDATOR_MAP = { # Dictionary where keys are the names of URI options specific to pymongo, # and values are functions that validate user-input values for those options. -NONSPEC_OPTIONS_VALIDATOR_MAP = { +NONSPEC_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = { 'connect': validate_boolean_or_string, 'driver': validate_driver_or_none, 'server_api': validate_server_api_or_none, @@ -661,7 +662,7 @@ NONSPEC_OPTIONS_VALIDATOR_MAP = { # Dictionary where keys are the names of keyword-only options for the # MongoClient constructor, and values are functions that validate user-input # values for those options. -KW_VALIDATORS = { +KW_VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = { 'document_class': validate_document_class, 'type_registry': validate_type_registry, 'read_preference': validate_read_preference, @@ -677,14 +678,14 @@ KW_VALIDATORS = { # internally-used names of that URI option. Options with only one name # variant need not be included here. Options whose public and internal # names are the same need not be included here. -INTERNAL_URI_OPTION_NAME_MAP = { +INTERNAL_URI_OPTION_NAME_MAP: Dict[str, str] = { 'ssl': 'tls', } # Map from deprecated URI option names to a tuple indicating the method of # their deprecation and any additional information that may be needed to # construct the warning message. -URI_OPTIONS_DEPRECATION_MAP = { +URI_OPTIONS_DEPRECATION_MAP: Dict[str, Tuple[str, str]] = { # format: : (, ), # Supported values: # - 'renamed': should be the new option name. Note that case is @@ -704,11 +705,11 @@ for optname, aliases in URI_OPTIONS_ALIAS_MAP.items(): URI_OPTIONS_VALIDATOR_MAP[optname]) # Map containing all URI option and keyword argument validators. -VALIDATORS = URI_OPTIONS_VALIDATOR_MAP.copy() +VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = URI_OPTIONS_VALIDATOR_MAP.copy() VALIDATORS.update(KW_VALIDATORS) # List of timeout-related options. -TIMEOUT_OPTIONS = [ +TIMEOUT_OPTIONS: List[str] = [ 'connecttimeoutms', 'heartbeatfrequencyms', 'maxidletimems', @@ -722,7 +723,7 @@ TIMEOUT_OPTIONS = [ _AUTH_OPTIONS = frozenset(['authmechanismproperties']) -def validate_auth_option(option, value): +def validate_auth_option(option: str, value: Any) -> Tuple[str, Any]: """Validate optional authentication parameters. """ lower, value = validate(option, value) @@ -732,7 +733,7 @@ def validate_auth_option(option, value): return option, value -def validate(option, value): +def validate(option: str, value: Any) -> Tuple[str, Any]: """Generic validation function. """ lower = option.lower() @@ -741,7 +742,7 @@ def validate(option, value): return option, value -def get_validated_options(options, warn=True): +def get_validated_options(options: Mapping[str, Any], warn: bool = True) -> MutableMapping[str, Any]: """Validate each entry in options and raise a warning if it is not valid. Returns a copy of options with invalid entries removed. @@ -751,6 +752,7 @@ def get_validated_options(options, warn=True): invalid options will be ignored. Otherwise, invalid options will cause errors. """ + validated_options: MutableMapping[str, Any] if isinstance(options, _CaseInsensitiveDictionary): validated_options = _CaseInsensitiveDictionary() get_normed_key = lambda x: x @@ -794,8 +796,8 @@ class BaseObject(object): SHOULD NOT BE USED BY DEVELOPERS EXTERNAL TO MONGODB. """ - def __init__(self, codec_options, read_preference, write_concern, - read_concern): + def __init__(self, codec_options: CodecOptions, read_preference: _ServerMode, write_concern: WriteConcern, + read_concern: ReadConcern) -> None: if not isinstance(codec_options, CodecOptions): raise TypeError("codec_options must be an instance of " @@ -819,14 +821,14 @@ class BaseObject(object): self.__read_concern = read_concern @property - def codec_options(self): + def codec_options(self) -> CodecOptions: """Read only access to the :class:`~bson.codec_options.CodecOptions` of this instance. """ return self.__codec_options @property - def write_concern(self): + def write_concern(self) -> WriteConcern: """Read only access to the :class:`~pymongo.write_concern.WriteConcern` of this instance. @@ -844,7 +846,7 @@ class BaseObject(object): return self.write_concern @property - def read_preference(self): + def read_preference(self) -> _ServerMode: """Read only access to the read preference of this instance. .. versionchanged:: 3.0 @@ -861,7 +863,7 @@ class BaseObject(object): return self.__read_preference @property - def read_concern(self): + def read_concern(self) -> ReadConcern: """Read only access to the :class:`~pymongo.read_concern.ReadConcern` of this instance. diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index d36759528..c9cc041af 100644 --- a/pymongo/compression_support.py +++ b/pymongo/compression_support.py @@ -13,6 +13,7 @@ # limitations under the License. import warnings +from typing import Callable try: import snappy @@ -99,7 +100,7 @@ class CompressionSettings(object): return ZstdContext() -def _zlib_no_compress(data): +def _zlib_no_compress(data, level=None): """Compress data with zlib level 0.""" cobj = zlib.compressobj(0) return b"".join([cobj.compress(data), cobj.flush()]) @@ -117,6 +118,8 @@ class ZlibContext(object): compressor_id = 2 def __init__(self, level): + self.compress: Callable[[bytes], bytes] + # Jython zlib.compress doesn't support -1 if level == -1: self.compress = zlib.compress @@ -124,7 +127,7 @@ class ZlibContext(object): elif level == 0: self.compress = _zlib_no_compress else: - self.compress = lambda data: zlib.compress(data, level) + self.compresss = lambda data, _: zlib.compress(data, level) class ZstdContext(object): diff --git a/pymongo/cursor.py b/pymongo/cursor.py index 3e78c2d97..152acaca6 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -13,29 +13,26 @@ # limitations under the License. """Cursor class to iterate over Mongo query results.""" - import copy import threading import warnings - from collections import deque +from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Mapping, + MutableMapping, Optional, Sequence, Tuple, Union, cast, overload) from bson import RE_TYPE, _convert_raw_document_lists_to_streams from bson.code import Code from bson.son import SON from pymongo import helpers -from pymongo.common import (validate_boolean, validate_is_mapping, - validate_is_document_type) from pymongo.collation import validate_collation_or_none -from pymongo.errors import (ConnectionFailure, - InvalidOperation, +from pymongo.common import (validate_boolean, validate_is_document_type, + validate_is_mapping) +from pymongo.errors import (ConnectionFailure, InvalidOperation, OperationFailure) -from pymongo.message import (_CursorAddress, - _GetMore, - _RawBatchGetMore, - _Query, - _RawBatchQuery) +from pymongo.message import (_CursorAddress, _GetMore, _Query, + _RawBatchGetMore, _RawBatchQuery) from pymongo.response import PinnedResponse +from pymongo.typings import _CollationIn, _DocumentType # These errors mean that the server has already killed the cursor so there is # no need to send killCursors. @@ -126,22 +123,47 @@ class _SocketManager(object): self.sock.unpin() self.sock = None +_Sort = Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]] +_Hint = Union[str, _Sort] -class Cursor(object): + +if TYPE_CHECKING: + from pymongo.client_session import ClientSession + from pymongo.collection import Collection + + +class Cursor(Generic[_DocumentType]): """A cursor / iterator over Mongo query results. """ _query_class = _Query _getmore_class = _GetMore - def __init__(self, collection, filter=None, projection=None, skip=0, - limit=0, no_cursor_timeout=False, - cursor_type=CursorType.NON_TAILABLE, - sort=None, allow_partial_results=False, oplog_replay=False, - batch_size=0, - collation=None, hint=None, max_scan=None, max_time_ms=None, - max=None, min=None, return_key=None, show_record_id=None, - snapshot=None, comment=None, session=None, - allow_disk_use=None, let=None): + def __init__(self, + collection: "Collection[_DocumentType]", + filter: Optional[Mapping[str, Any]] = None, + projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None, + skip: int = 0, + limit: int = 0, + no_cursor_timeout: bool = False, + cursor_type: int = CursorType.NON_TAILABLE, + sort: Optional[_Sort] = None, + allow_partial_results: bool = False, + oplog_replay: bool = False, + batch_size: int = 0, + collation: Optional[_CollationIn] = None, + hint: Optional[_Hint] = None, + max_scan: Optional[int] = None, + max_time_ms: Optional[int] = None, + max: Optional[_Sort] = None, + min: Optional[_Sort] = None, + return_key: Optional[bool] = None, + show_record_id: Optional[bool] = None, + snapshot: Optional[bool] = None, + comment: Any = None, + session: Optional["ClientSession"] = None, + allow_disk_use: Optional[bool] = None, + let: Optional[bool] = None + ) -> None: """Create a new cursor. Should not be called directly by application developers - see @@ -151,11 +173,12 @@ class Cursor(object): """ # Initialize all attributes used in __del__ before possibly raising # an error to avoid attribute errors during garbage collection. - self.__collection = collection - self.__id = None + self.__collection: Collection[_DocumentType] = collection + self.__id: Any = None self.__exhaust = False - self.__sock_mgr = None + self.__sock_mgr: Any = None self.__killed = False + self.__session: Optional["ClientSession"] if session: self.__session = session @@ -164,10 +187,7 @@ class Cursor(object): self.__session = None self.__explicit_session = False - spec = filter - if spec is None: - spec = {} - + spec: Mapping[str, Any] = filter or {} validate_is_mapping("filter", spec) if not isinstance(skip, int): raise TypeError("skip must be an instance of int") @@ -203,6 +223,7 @@ class Cursor(object): self.__let = let self.__spec = spec + self.__has_filter = filter is not None self.__projection = projection self.__skip = skip self.__limit = limit @@ -212,9 +233,9 @@ class Cursor(object): self.__explain = False self.__comment = comment self.__max_time_ms = max_time_ms - self.__max_await_time_ms = None - self.__max = max - self.__min = min + self.__max_await_time_ms: Optional[int] = None + self.__max: Optional[Union[SON[Any, Any], _Sort]] = max + self.__min: Optional[Union[SON[Any, Any], _Sort]] = min self.__collation = validate_collation_or_none(collation) self.__return_key = return_key self.__show_record_id = show_record_id @@ -239,7 +260,7 @@ class Cursor(object): # it anytime we change __limit. self.__empty = False - self.__data = deque() + self.__data: deque = deque() self.__address = None self.__retrieved = 0 @@ -261,22 +282,22 @@ class Cursor(object): self.__collname = collection.name @property - def collection(self): + def collection(self) -> "Collection[_DocumentType]": """The :class:`~pymongo.collection.Collection` that this :class:`Cursor` is iterating. """ return self.__collection @property - def retrieved(self): + def retrieved(self) -> int: """The number of documents retrieved so far. """ return self.__retrieved - def __del__(self): + def __del__(self) -> None: self.__die() - def rewind(self): + def rewind(self) -> "Cursor[_DocumentType]": """Rewind this cursor to its unevaluated state. Reset this cursor if it has been partially or completely evaluated. @@ -294,7 +315,7 @@ class Cursor(object): return self - def clone(self): + def clone(self) -> "Cursor[_DocumentType]": """Get a clone of this cursor. Returns a new Cursor instance with options matching those that have @@ -318,7 +339,7 @@ class Cursor(object): "batch_size", "max_scan", "query_flags", "collation", "empty", "show_record_id", "return_key", "allow_disk_use", - "snapshot", "exhaust") + "snapshot", "exhaust", "has_filter") data = dict((k, v) for k, v in self.__dict__.items() if k.startswith('_Cursor__') and k[9:] in values_to_clone) if deepcopy: @@ -360,7 +381,7 @@ class Cursor(object): self.__session = None self.__sock_mgr = None - def close(self): + def close(self) -> None: """Explicitly close / kill this cursor. """ self.__die(True) @@ -397,7 +418,7 @@ class Cursor(object): if operators: # Make a shallow copy so we can cleanly rewind or clone. - spec = self.__spec.copy() + spec = copy.copy(self.__spec) # Allow-listed commands must be wrapped in $query. if "$query" not in spec: @@ -429,7 +450,7 @@ class Cursor(object): if self.__retrieved or self.__id is not None: raise InvalidOperation("cannot set options after executing query") - def add_option(self, mask): + def add_option(self, mask: int) -> "Cursor[_DocumentType]": """Set arbitrary query flags using a bitmask. To set the tailable flag: @@ -450,7 +471,7 @@ class Cursor(object): self.__query_flags |= mask return self - def remove_option(self, mask): + def remove_option(self, mask: int) -> "Cursor[_DocumentType]": """Unset arbitrary query flags using a bitmask. To unset the tailable flag: @@ -466,7 +487,7 @@ class Cursor(object): self.__query_flags &= ~mask return self - def allow_disk_use(self, allow_disk_use): + def allow_disk_use(self, allow_disk_use: bool) -> "Cursor[_DocumentType]": """Specifies whether MongoDB can use temporary disk files while processing a blocking sort operation. @@ -488,7 +509,7 @@ class Cursor(object): self.__allow_disk_use = allow_disk_use return self - def limit(self, limit): + def limit(self, limit: int) -> "Cursor[_DocumentType]": """Limits the number of results to be returned by this cursor. Raises :exc:`TypeError` if `limit` is not an integer. Raises @@ -511,7 +532,7 @@ class Cursor(object): self.__limit = limit return self - def batch_size(self, batch_size): + def batch_size(self, batch_size: int) -> "Cursor[_DocumentType]": """Limits the number of documents returned in one batch. Each batch requires a round trip to the server. It can be adjusted to optimize performance and limit data transfer. @@ -539,7 +560,7 @@ class Cursor(object): self.__batch_size = batch_size return self - def skip(self, skip): + def skip(self, skip: int) -> "Cursor[_DocumentType]": """Skips the first `skip` results of this cursor. Raises :exc:`TypeError` if `skip` is not an integer. Raises @@ -560,7 +581,7 @@ class Cursor(object): self.__skip = skip return self - def max_time_ms(self, max_time_ms): + def max_time_ms(self, max_time_ms: Optional[int]) -> "Cursor[_DocumentType]": """Specifies a time limit for a query operation. If the specified time is exceeded, the operation will be aborted and :exc:`~pymongo.errors.ExecutionTimeout` is raised. If `max_time_ms` @@ -581,7 +602,7 @@ class Cursor(object): self.__max_time_ms = max_time_ms return self - def max_await_time_ms(self, max_await_time_ms): + def max_await_time_ms(self, max_await_time_ms: Optional[int]) -> "Cursor[_DocumentType]": """Specifies a time limit for a getMore operation on a :attr:`~pymongo.cursor.CursorType.TAILABLE_AWAIT` cursor. For all other types of cursor max_await_time_ms is ignored. @@ -609,6 +630,14 @@ class Cursor(object): return self + @overload + def __getitem__(self, index: int) -> _DocumentType: + ... + + @overload + def __getitem__(self, index: slice) -> "Cursor[_DocumentType]": + ... + def __getitem__(self, index): """Get a single document or a slice of documents from this cursor. @@ -691,7 +720,7 @@ class Cursor(object): raise TypeError("index %r cannot be applied to Cursor " "instances" % index) - def max_scan(self, max_scan): + def max_scan(self, max_scan: Optional[int]) -> "Cursor[_DocumentType]": """**DEPRECATED** - Limit the number of documents to scan when performing the query. @@ -711,7 +740,7 @@ class Cursor(object): self.__max_scan = max_scan return self - def max(self, spec): + def max(self, spec: _Sort) -> "Cursor[_DocumentType]": """Adds ``max`` operator that specifies upper bound for specific index. When using ``max``, :meth:`~hint` should also be configured to ensure @@ -734,7 +763,7 @@ class Cursor(object): self.__max = SON(spec) return self - def min(self, spec): + def min(self, spec: _Sort) -> "Cursor[_DocumentType]": """Adds ``min`` operator that specifies lower bound for specific index. When using ``min``, :meth:`~hint` should also be configured to ensure @@ -757,7 +786,7 @@ class Cursor(object): self.__min = SON(spec) return self - def sort(self, key_or_list, direction=None): + def sort(self, key_or_list: _Hint, direction: Optional[Union[int, str]] = None) -> "Cursor[_DocumentType]": """Sorts this cursor's results. Pass a field name and a direction, either @@ -803,7 +832,7 @@ class Cursor(object): self.__ordering = helpers._index_document(keys) return self - def distinct(self, key): + def distinct(self, key: str) -> List: """Get a list of distinct values for `key` among all documents in the result set of this query. @@ -820,7 +849,7 @@ class Cursor(object): .. seealso:: :meth:`pymongo.collection.Collection.distinct` """ - options = {} + options: Dict[str, Any] = {} if self.__spec: options["query"] = self.__spec if self.__max_time_ms is not None: @@ -833,7 +862,7 @@ class Cursor(object): return self.__collection.distinct( key, session=self.__session, **options) - def explain(self): + def explain(self) -> _DocumentType: """Returns an explain plan record for this cursor. .. note:: This method uses the default verbosity mode of the @@ -863,7 +892,7 @@ class Cursor(object): else: self.__hint = helpers._index_document(index) - def hint(self, index): + def hint(self, index: Optional[_Hint]) -> "Cursor[_DocumentType]": """Adds a 'hint', telling Mongo the proper index to use for the query. Judicious use of hints can greatly improve query @@ -888,7 +917,7 @@ class Cursor(object): self.__set_hint(index) return self - def comment(self, comment): + def comment(self, comment: Any) -> "Cursor[_DocumentType]": """Adds a 'comment' to the cursor. http://docs.mongodb.org/manual/reference/operator/comment/ @@ -903,7 +932,7 @@ class Cursor(object): self.__comment = comment return self - def where(self, code): + def where(self, code: Union[str, Code]) -> "Cursor[_DocumentType]": """Adds a `$where`_ clause to this query. The `code` argument must be an instance of :class:`basestring` @@ -937,10 +966,18 @@ class Cursor(object): if not isinstance(code, Code): code = Code(code) - self.__spec["$where"] = code + # Avoid overwriting a filter argument that was given by the user + # when updating the spec. + spec: Dict[str, Any] + if self.__has_filter: + spec = dict(self.__spec) + else: + spec = cast(Dict, self.__spec) + spec["$where"] = code + self.__spec = spec return self - def collation(self, collation): + def collation(self, collation: Optional[_CollationIn]) -> "Cursor[_DocumentType]": """Adds a :class:`~pymongo.collation.Collation` to this query. Raises :exc:`TypeError` if `collation` is not an instance of @@ -1106,7 +1143,7 @@ class Cursor(object): return len(self.__data) @property - def alive(self): + def alive(self) -> bool: """Does this cursor have the potential to return more data? This is mostly useful with `tailable cursors @@ -1128,7 +1165,7 @@ class Cursor(object): return bool(len(self.__data) or (not self.__killed)) @property - def cursor_id(self): + def cursor_id(self) -> Optional[int]: """Returns the id of the cursor .. versionadded:: 2.2 @@ -1136,7 +1173,7 @@ class Cursor(object): return self.__id @property - def address(self): + def address(self) -> Optional[Tuple[str, Any]]: """The (host, port) of the server used, or None. .. versionchanged:: 3.0 @@ -1145,18 +1182,19 @@ class Cursor(object): return self.__address @property - def session(self): + def session(self) -> Optional["ClientSession"]: """The cursor's :class:`~pymongo.client_session.ClientSession`, or None. .. versionadded:: 3.6 """ if self.__explicit_session: return self.__session + return None - def __iter__(self): + def __iter__(self) -> "Cursor[_DocumentType]": return self - def next(self): + def next(self) -> _DocumentType: """Advance the cursor.""" if self.__empty: raise StopIteration @@ -1167,20 +1205,20 @@ class Cursor(object): __next__ = next - def __enter__(self): + def __enter__(self) -> "Cursor[_DocumentType]": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() - def __copy__(self): + def __copy__(self) -> "Cursor[_DocumentType]": """Support function for `copy.copy()`. .. versionadded:: 2.4 """ return self._clone(deepcopy=False) - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Any) -> Any: """Support function for `copy.deepcopy()`. .. versionadded:: 2.4 @@ -1193,6 +1231,7 @@ class Cursor(object): Regular expressions cannot be deep copied but as they are immutable we don't have to copy them when cloning. """ + y: Any if not hasattr(x, 'items'): y, is_list, iterator = [], True, enumerate(x) else: @@ -1220,13 +1259,13 @@ class Cursor(object): return y -class RawBatchCursor(Cursor): +class RawBatchCursor(Cursor, Generic[_DocumentType]): """A cursor / iterator over raw batches of BSON data from a query result.""" _query_class = _RawBatchQuery _getmore_class = _RawBatchGetMore - def __init__(self, *args, **kwargs): + def __init__(self, collection: "Collection[_DocumentType]", *args: Any, **kwargs: Any) -> None: """Create a new cursor / iterator over raw batches of BSON data. Should not be called directly by application developers - @@ -1235,7 +1274,7 @@ class RawBatchCursor(Cursor): .. seealso:: The MongoDB documentation on `cursors `_. """ - super(RawBatchCursor, self).__init__(*args, **kwargs) + super(RawBatchCursor, self).__init__(collection, *args, **kwargs) def _unpack_response(self, response, cursor_id, codec_options, user_fields=None, legacy_response=False): @@ -1247,7 +1286,7 @@ class RawBatchCursor(Cursor): _convert_raw_document_lists_to_streams(raw_response[0]) return raw_response - def explain(self): + def explain(self) -> _DocumentType: """Returns an explain plan record for this cursor. .. seealso:: The MongoDB documentation on `explain `_. @@ -1255,5 +1294,5 @@ class RawBatchCursor(Cursor): clone = self._clone(deepcopy=True, base=Cursor(self.collection)) return clone.explain() - def __getitem__(self, index): + def __getitem__(self, index: Any) -> "Cursor[_DocumentType]": raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor") diff --git a/pymongo/database.py b/pymongo/database.py index a6c127512..4f5f93135 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -13,18 +13,21 @@ # limitations under the License. """Database level operations.""" +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Mapping, MutableMapping, Optional, + Sequence, Union) -from bson.codec_options import DEFAULT_CODEC_OPTIONS +from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions from bson.dbref import DBRef from bson.son import SON +from bson.timestamp import Timestamp from pymongo import common from pymongo.aggregation import _DatabaseAggregationCommand from pymongo.change_stream import DatabaseChangeStream from pymongo.collection import Collection from pymongo.command_cursor import CommandCursor -from pymongo.errors import (CollectionInvalid, - InvalidName) -from pymongo.read_preferences import ReadPreference +from pymongo.errors import CollectionInvalid, InvalidName +from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.typings import _CollationIn, _DocumentType, _Pipeline def _check_name(name): @@ -39,12 +42,24 @@ def _check_name(name): "character %r" % invalid_char) -class Database(common.BaseObject): +if TYPE_CHECKING: + from pymongo.client_session import ClientSession + from pymongo.mongo_client import MongoClient + from pymongo.read_concern import ReadConcern + from pymongo.write_concern import WriteConcern + + +class Database(common.BaseObject, Generic[_DocumentType]): """A Mongo database. """ - - def __init__(self, client, name, codec_options=None, read_preference=None, - write_concern=None, read_concern=None): + def __init__(self, + client: "MongoClient[_DocumentType]", + name: str, + codec_options: Optional[CodecOptions] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional["WriteConcern"] = None, + read_concern: Optional["ReadConcern"] = None, + ) -> None: """Get a database by client and name. Raises :class:`TypeError` if `name` is not an instance of @@ -104,20 +119,24 @@ class Database(common.BaseObject): _check_name(name) self.__name = name - self.__client = client + self.__client: MongoClient[_DocumentType] = client @property - def client(self): + def client(self) -> "MongoClient[_DocumentType]": """The client instance for this :class:`Database`.""" return self.__client @property - def name(self): + def name(self) -> str: """The name of this :class:`Database`.""" return self.__name - def with_options(self, codec_options=None, read_preference=None, - write_concern=None, read_concern=None): + def with_options(self, + codec_options: Optional[CodecOptions] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional["WriteConcern"] = None, + read_concern: Optional["ReadConcern"] = None, + ) -> "Database[_DocumentType]": """Get a clone of this database changing the specified settings. >>> db1.read_preference @@ -156,22 +175,22 @@ class Database(common.BaseObject): write_concern or self.write_concern, read_concern or self.read_concern) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Database): return (self.__client == other.client and self.__name == other.name) return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def __hash__(self): + def __hash__(self) -> int: return hash((self.__client, self.__name)) def __repr__(self): return "Database(%r, %r)" % (self.__client, self.__name) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Collection[_DocumentType]: """Get a collection of this database by name. Raises InvalidName if an invalid collection name is used. @@ -185,7 +204,7 @@ class Database(common.BaseObject): " collection, use database[%r]." % (name, name, name)) return self.__getitem__(name) - def __getitem__(self, name): + def __getitem__(self, name: str) -> "Collection[_DocumentType]": """Get a collection of this database by name. Raises InvalidName if an invalid collection name is used. @@ -195,8 +214,13 @@ class Database(common.BaseObject): """ return Collection(self, name) - def get_collection(self, name, codec_options=None, read_preference=None, - write_concern=None, read_concern=None): + def get_collection(self, + name: str, + codec_options: Optional[CodecOptions] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional["WriteConcern"] = None, + read_concern: Optional["ReadConcern"] = None, + ) -> Collection[_DocumentType]: """Get a :class:`~pymongo.collection.Collection` with the given name and options. @@ -238,9 +262,15 @@ class Database(common.BaseObject): self, name, False, codec_options, read_preference, write_concern, read_concern) - def create_collection(self, name, codec_options=None, - read_preference=None, write_concern=None, - read_concern=None, session=None, **kwargs): + def create_collection(self, + name: str, + codec_options: Optional[CodecOptions] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional["WriteConcern"] = None, + read_concern: Optional["ReadConcern"] = None, + session: Optional["ClientSession"] = None, + **kwargs: Any, + ) -> Collection[_DocumentType]: """Create a new :class:`~pymongo.collection.Collection` in this database. @@ -286,20 +316,20 @@ class Database(common.BaseObject): timeseries collections - ``expireAfterSeconds`` (int): the number of seconds after which a document in a timeseries collection expires - - ``validator`` (dict): a document specifying validation rules or expressions + - ``validator`` (dict): a document specifying validation rules or expressions for the collection - - ``validationLevel`` (str): how strictly to apply the + - ``validationLevel`` (str): how strictly to apply the validation rules to existing documents during an update. The default level is "strict" - ``validationAction`` (str): whether to "error" on invalid documents - (the default) or just "warn" about the violations but allow invalid + (the default) or just "warn" about the violations but allow invalid documents to be inserted - ``indexOptionDefaults`` (dict): a document specifying a default configuration for indexes when creating a collection - - ``viewOn`` (str): the name of the source collection or view from which + - ``viewOn`` (str): the name of the source collection or view from which to create the view - ``pipeline`` (list): a list of aggregation pipeline stages - - ``comment`` (str): a user-provided comment to attach to this command. + - ``comment`` (str): a user-provided comment to attach to this command. This option is only supported on MongoDB >= 4.4. .. versionchanged:: 3.11 @@ -330,7 +360,11 @@ class Database(common.BaseObject): read_preference, write_concern, read_concern, session=s, **kwargs) - def aggregate(self, pipeline, session=None, **kwargs): + def aggregate(self, + pipeline: _Pipeline, + session: Optional["ClientSession"] = None, + **kwargs: Any + ) -> CommandCursor[_DocumentType]: """Perform a database-level aggregation. See the `aggregation pipeline`_ documentation for a list of stages @@ -400,9 +434,17 @@ class Database(common.BaseObject): cmd.get_cursor, cmd.get_read_preference(s), s, retryable=not cmd._performs_write) - def watch(self, pipeline=None, full_document=None, resume_after=None, - max_await_time_ms=None, batch_size=None, collation=None, - start_at_operation_time=None, session=None, start_after=None): + def watch(self, + pipeline: Optional[_Pipeline] = None, + full_document: Optional[str] = None, + resume_after: Optional[Mapping[str, Any]] = None, + max_await_time_ms: Optional[int] = None, + batch_size: Optional[int] = None, + collation: Optional[_CollationIn] = None, + start_at_operation_time: Optional[Timestamp] = None, + session: Optional["ClientSession"] = None, + start_after: Optional[Mapping[str, Any]] = None, + ) -> DatabaseChangeStream[_DocumentType]: """Watch changes on this database. Performs an aggregation with an implicit initial ``$changeStream`` @@ -515,9 +557,16 @@ class Database(common.BaseObject): session=s, client=self.__client) - def command(self, command, value=1, check=True, - allowable_errors=None, read_preference=None, - codec_options=DEFAULT_CODEC_OPTIONS, session=None, **kwargs): + def command(self, + command: Union[str, MutableMapping[str, Any]], + value: Any = 1, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_preference: Optional[_ServerMode] = None, + codec_options: Optional[CodecOptions] = DEFAULT_CODEC_OPTIONS, + session: Optional["ClientSession"] = None, + **kwargs: Any, + ) -> Dict[str, Any]: """Issue a MongoDB command. Send command `command` to the database and return the @@ -648,7 +697,11 @@ class Database(common.BaseObject): cmd_cursor._maybe_pin_connection(sock_info) return cmd_cursor - def list_collections(self, session=None, filter=None, **kwargs): + def list_collections(self, + session: Optional["ClientSession"] = None, + filter: Optional[Mapping[str, Any]] = None, + **kwargs: Any + ) -> CommandCursor[Dict[str, Any]]: """Get a cursor over the collections of this database. :Parameters: @@ -680,7 +733,11 @@ class Database(common.BaseObject): return self.__client._retryable_read( _cmd, read_pref, session) - def list_collection_names(self, session=None, filter=None, **kwargs): + def list_collection_names(self, + session: Optional["ClientSession"] = None, + filter: Optional[Mapping[str, Any]] = None, + **kwargs: Any + ) -> List[str]: """Get a list of all the collection names in this database. For example, to list all non-system collections:: @@ -717,7 +774,10 @@ class Database(common.BaseObject): return [result["name"] for result in self.list_collections(session=session, **kwargs)] - def drop_collection(self, name_or_collection, session=None): + def drop_collection(self, + name_or_collection: Union[str, Collection], + session: Optional["ClientSession"] = None + ) -> Dict[str, Any]: """Drop a collection. :Parameters: @@ -752,9 +812,13 @@ class Database(common.BaseObject): parse_write_concern_error=True, session=session) - def validate_collection(self, name_or_collection, - scandata=False, full=False, session=None, - background=None): + def validate_collection(self, + name_or_collection: Union[str, Collection], + scandata: bool = False, + full: bool = False, + session: Optional["ClientSession"] = None, + background: Optional[bool] = None, + ) -> Dict[str, Any]: """Validate a collection. Returns a dict of validation info. Raises CollectionInvalid if @@ -827,20 +891,23 @@ class Database(common.BaseObject): return result - def __iter__(self): + def __iter__(self) -> "Database[_DocumentType]": return self - def __next__(self): + def __next__(self) -> "Database[_DocumentType]": raise TypeError("'Database' object is not iterable") next = __next__ - def __bool__(self): + def __bool__(self) -> bool: raise NotImplementedError("Database objects do not implement truth " "value testing or bool(). Please compare " "with None instead: database is not None") - def dereference(self, dbref, session=None, **kwargs): + def dereference(self, dbref: DBRef, + session: Optional["ClientSession"] = None, + **kwargs: Any + ) -> Optional[_DocumentType]: """Dereference a :class:`~bson.dbref.DBRef`, getting the document it points to. diff --git a/pymongo/driver_info.py b/pymongo/driver_info.py index 5e0843e4d..1bb599af3 100644 --- a/pymongo/driver_info.py +++ b/pymongo/driver_info.py @@ -15,6 +15,7 @@ """Advanced options for MongoDB drivers implemented on top of PyMongo.""" from collections import namedtuple +from typing import Optional class DriverInfo(namedtuple('DriverInfo', ['name', 'version', 'platform'])): @@ -26,7 +27,7 @@ class DriverInfo(namedtuple('DriverInfo', ['name', 'version', 'platform'])): like 'MyDriver', '1.2.3', 'some platform info'. Any of these strings may be None to accept PyMongo's default. """ - def __new__(cls, name, version=None, platform=None): + def __new__(cls, name: str, version: Optional[str] = None, platform: Optional[str] = None) -> "DriverInfo": self = super(DriverInfo, cls).__new__(cls, name, version, platform) for key, value in self._asdict().items(): if value is not None and not isinstance(value, str): diff --git a/pymongo/encryption.py b/pymongo/encryption.py index 4b08492ee..b076f490f 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -15,15 +15,16 @@ """Support for explicit client-side field level encryption.""" import contextlib -import os -import subprocess import uuid import weakref +from typing import Any, Mapping, Optional, Sequence try: from pymongocrypt.auto_encrypter import AutoEncrypter from pymongocrypt.errors import MongoCryptError - from pymongocrypt.explicit_encrypter import ExplicitEncrypter + from pymongocrypt.explicit_encrypter import ( + ExplicitEncrypter + ) from pymongocrypt.mongocrypt import MongoCryptOptions from pymongocrypt.state_machine import MongoCryptCallback _HAVE_PYMONGOCRYPT = True @@ -32,29 +33,22 @@ except ImportError: MongoCryptCallback = object from bson import _dict_to_bson, decode, encode +from bson.binary import STANDARD, UUID_SUBTYPE, Binary from bson.codec_options import CodecOptions -from bson.binary import (Binary, - STANDARD, - UUID_SUBTYPE) from bson.errors import BSONError -from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS, - RawBSONDocument, +from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson) from bson.son import SON - -from pymongo.errors import (ConfigurationError, - EncryptionError, - InvalidOperation, - ServerSelectionTimeoutError) +from pymongo.daemon import _spawn_daemon from pymongo.encryption_options import AutoEncryptionOpts +from pymongo.errors import (ConfigurationError, EncryptionError, + InvalidOperation, ServerSelectionTimeoutError) from pymongo.mongo_client import MongoClient -from pymongo.pool import _configured_socket, PoolOptions +from pymongo.pool import PoolOptions, _configured_socket from pymongo.read_concern import ReadConcern from pymongo.ssl_support import get_ssl_context from pymongo.uri_parser import parse_host from pymongo.write_concern import WriteConcern -from pymongo.daemon import _spawn_daemon - _HTTPS_PORT = 443 _KMS_CONNECT_TIMEOUT = 10 # TODO: CDRIVER-3262 will define this value. @@ -80,9 +74,10 @@ def _wrap_encryption_errors(): raise EncryptionError(exc) -class _EncryptionIO(MongoCryptCallback): +class _EncryptionIO(MongoCryptCallback): # type: ignore def __init__(self, client, key_vault_coll, mongocryptd_client, opts): """Internal class to perform I/O on behalf of pymongocrypt.""" + self.client_ref: Any # Use a weak ref to break reference cycle. if client is not None: self.client_ref = weakref.ref(client) @@ -355,11 +350,17 @@ class Algorithm(object): "AEAD_AES_256_CBC_HMAC_SHA_512-Random") + class ClientEncryption(object): """Explicit client-side field level encryption.""" - def __init__(self, kms_providers, key_vault_namespace, key_vault_client, - codec_options, kms_tls_options=None): + def __init__(self, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: MongoClient, + codec_options: CodecOptions, + kms_tls_options: Optional[Mapping[str, Any]] = None + ) -> None: """Explicit client-side field level encryption. The ClientEncryption class encapsulates explicit operations on a key @@ -449,12 +450,15 @@ class ClientEncryption(object): opts = AutoEncryptionOpts(kms_providers, key_vault_namespace, kms_tls_options=kms_tls_options) - self._io_callbacks = _EncryptionIO(None, key_vault_coll, None, opts) + self._io_callbacks: Optional[_EncryptionIO] = _EncryptionIO(None, key_vault_coll, None, opts) self._encryption = ExplicitEncrypter( self._io_callbacks, MongoCryptOptions(kms_providers, None)) - def create_data_key(self, kms_provider, master_key=None, - key_alt_names=None): + def create_data_key(self, + kms_provider: str, + master_key: Optional[Mapping[str, Any]] = None, + key_alt_names: Optional[Sequence[str]] = None + ) -> Binary: """Create and insert a new data key into the key vault collection. :Parameters: @@ -526,7 +530,12 @@ class ClientEncryption(object): kms_provider, master_key=master_key, key_alt_names=key_alt_names) - def encrypt(self, value, algorithm, key_id=None, key_alt_name=None): + def encrypt(self, + value: Any, + algorithm: str, + key_id: Optional[Binary] = None, + key_alt_name: Optional[str] = None + ) -> Binary: """Encrypt a BSON value with a given key and algorithm. Note that exactly one of ``key_id`` or ``key_alt_name`` must be @@ -557,7 +566,7 @@ class ClientEncryption(object): doc, algorithm, key_id=key_id, key_alt_name=key_alt_name) return decode(encrypted_doc)['v'] - def decrypt(self, value): + def decrypt(self, value: Binary) -> Any: """Decrypt an encrypted value. :Parameters: @@ -578,17 +587,17 @@ class ClientEncryption(object): return decode(decrypted_doc, codec_options=self._codec_options)['v'] - def __enter__(self): + def __enter__(self) -> "ClientEncryption": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() def _check_closed(self): if self._encryption is None: raise InvalidOperation("Cannot use closed ClientEncryption") - def close(self): + def close(self) -> None: """Release resources. Note that using this class in a with-statement will automatically call diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index d0c2d5ce7..21a13f6a5 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -15,6 +15,7 @@ """Support for automatic client-side field level encryption.""" import copy +from typing import TYPE_CHECKING, Any, List, Mapping, Optional try: import pymongocrypt @@ -25,19 +26,25 @@ except ImportError: from pymongo.errors import ConfigurationError from pymongo.uri_parser import _parse_kms_tls_options +if TYPE_CHECKING: + from pymongo.mongo_client import MongoClient + class AutoEncryptionOpts(object): """Options to configure automatic client-side field level encryption.""" - def __init__(self, kms_providers, key_vault_namespace, - key_vault_client=None, - schema_map=None, - bypass_auto_encryption=False, - mongocryptd_uri='mongodb://localhost:27020', - mongocryptd_bypass_spawn=False, - mongocryptd_spawn_path='mongocryptd', - mongocryptd_spawn_args=None, - kms_tls_options=None): + def __init__(self, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: Optional["MongoClient"] = None, + schema_map: Optional[Mapping[str, Any]] = None, + bypass_auto_encryption: Optional[bool] = False, + mongocryptd_uri: str = 'mongodb://localhost:27020', + mongocryptd_bypass_spawn: bool = False, + mongocryptd_spawn_path: str = 'mongocryptd', + mongocryptd_spawn_args: Optional[List[str]] = None, + kms_tls_options: Optional[Mapping[str, Any]] = None + ) -> None: """Options to configure automatic client-side field level encryption. Automatic client-side field level encryption requires MongoDB 4.2 @@ -152,8 +159,9 @@ class AutoEncryptionOpts(object): self._mongocryptd_uri = mongocryptd_uri self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn self._mongocryptd_spawn_path = mongocryptd_spawn_path - self._mongocryptd_spawn_args = (copy.copy(mongocryptd_spawn_args) or - ['--idleShutdownTimeoutSecs=60']) + if mongocryptd_spawn_args is None: + mongocryptd_spawn_args = ['--idleShutdownTimeoutSecs=60'] + self._mongocryptd_spawn_args = mongocryptd_spawn_args if not isinstance(self._mongocryptd_spawn_args, list): raise TypeError('mongocryptd_spawn_args must be a list') if not any('idleShutdownTimeoutSecs' in s diff --git a/pymongo/errors.py b/pymongo/errors.py index 0ee35827a..89c45730c 100644 --- a/pymongo/errors.py +++ b/pymongo/errors.py @@ -13,6 +13,8 @@ # limitations under the License. """Exceptions raised by PyMongo.""" +from typing import (Any, Iterable, List, Mapping, Optional, Sequence, Tuple, + Union) from bson.errors import * @@ -23,18 +25,21 @@ except ImportError: try: from ssl import CertificateError as _CertificateError except ImportError: - class _CertificateError(ValueError): + class _CertificateError(ValueError): # type: ignore pass class PyMongoError(Exception): """Base class for all PyMongo exceptions.""" - def __init__(self, message='', error_labels=None): + def __init__(self, + message: str = '', + error_labels: Optional[Iterable[str]] = None + ) -> None: super(PyMongoError, self).__init__(message) self._message = message self._error_labels = set(error_labels or []) - def has_error_label(self, label): + def has_error_label(self, label: str) -> bool: """Return True if this error contains the given label. .. versionadded:: 3.7 @@ -70,10 +75,17 @@ class AutoReconnect(ConnectionFailure): Subclass of :exc:`~pymongo.errors.ConnectionFailure`. """ - def __init__(self, message='', errors=None): + errors: Union[Mapping[str, Any], Sequence] + details: Union[Mapping[str, Any], Sequence] + + def __init__(self, + message: str = '', + errors: Optional[Union[Mapping[str, Any], Sequence]] = None + ) -> None: error_labels = None - if errors is not None and isinstance(errors, dict): - error_labels = errors.get('errorLabels') + if errors is not None: + if isinstance(errors, dict): + error_labels = errors.get('errorLabels') super(AutoReconnect, self).__init__(message, error_labels) self.errors = self.details = errors or [] @@ -109,7 +121,10 @@ class NotPrimaryError(AutoReconnect): .. versionadded:: 3.12 """ - def __init__(self, message='', errors=None): + def __init__(self, + message: str = '', + errors: Optional[Union[Mapping[str, Any], List]] = None + ) -> None: super(NotPrimaryError, self).__init__( _format_detailed_error(message, errors), errors=errors) @@ -139,7 +154,12 @@ class OperationFailure(PyMongoError): The :attr:`details` attribute. """ - def __init__(self, error, code=None, details=None, max_wire_version=None): + def __init__(self, + error: str, + code: Optional[int] = None, + details: Optional[Mapping[str, Any]] = None, + max_wire_version: Optional[int] = None, + ) -> None: error_labels = None if details is not None: error_labels = details.get('errorLabels') @@ -154,13 +174,13 @@ class OperationFailure(PyMongoError): return self.__max_wire_version @property - def code(self): + def code(self) -> Optional[int]: """The error code returned by the server, if any. """ return self.__code @property - def details(self): + def details(self) -> Optional[Mapping[str, Any]]: """The complete error document returned by the server. Depending on the error that occurred, the error document @@ -225,14 +245,17 @@ class BulkWriteError(OperationFailure): .. versionadded:: 2.7 """ - def __init__(self, results): + details: Mapping[str, Any] + + def __init__(self, results: Mapping[str, Any]) -> None: super(BulkWriteError, self).__init__( "batch op errors occurred", 65, results) - def __reduce__(self): + def __reduce__(self) -> Tuple[Any, Any]: return self.__class__, (self.details,) + class InvalidOperation(PyMongoError): """Raised when a client attempts to perform an invalid operation.""" @@ -264,12 +287,12 @@ class EncryptionError(PyMongoError): .. versionadded:: 3.9 """ - def __init__(self, cause): + def __init__(self, cause: Exception) -> None: super(EncryptionError, self).__init__(str(cause)) self.__cause = cause @property - def cause(self): + def cause(self) -> Exception: """The exception that caused this encryption or decryption error.""" return self.__cause diff --git a/pymongo/event_loggers.py b/pymongo/event_loggers.py index 7d5501c37..f0857f8f4 100644 --- a/pymongo/event_loggers.py +++ b/pymongo/event_loggers.py @@ -26,8 +26,6 @@ or ``MongoClient(event_listeners=[CommandLogger()])`` """ - - import logging from pymongo import monitoring @@ -42,18 +40,18 @@ class CommandLogger(monitoring.CommandListener): logs them at the `INFO` severity level using :mod:`logging`. .. versionadded:: 3.11 """ - def started(self, event): + def started(self, event: monitoring.CommandStartedEvent) -> None: logging.info("Command {0.command_name} with request id " "{0.request_id} started on server " "{0.connection_id}".format(event)) - def succeeded(self, event): + def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: logging.info("Command {0.command_name} with request id " "{0.request_id} on server {0.connection_id} " "succeeded in {0.duration_micros} " "microseconds".format(event)) - def failed(self, event): + def failed(self, event: monitoring.CommandFailedEvent) -> None: logging.info("Command {0.command_name} with request id " "{0.request_id} on server {0.connection_id} " "failed in {0.duration_micros} " @@ -70,11 +68,11 @@ class ServerLogger(monitoring.ServerListener): .. versionadded:: 3.11 """ - def opened(self, event): + def opened(self, event: monitoring.ServerOpeningEvent) -> None: logging.info("Server {0.server_address} added to topology " "{0.topology_id}".format(event)) - def description_changed(self, event): + def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None: previous_server_type = event.previous_description.server_type new_server_type = event.new_description.server_type if new_server_type != previous_server_type: @@ -84,7 +82,7 @@ class ServerLogger(monitoring.ServerListener): "{0.previous_description.server_type_name} to " "{0.new_description.server_type_name}".format(event)) - def closed(self, event): + def closed(self, event: monitoring.ServerClosedEvent) -> None: logging.warning("Server {0.server_address} removed from topology " "{0.topology_id}".format(event)) @@ -99,17 +97,17 @@ class HeartbeatLogger(monitoring.ServerHeartbeatListener): .. versionadded:: 3.11 """ - def started(self, event): + def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None: logging.info("Heartbeat sent to server " "{0.connection_id}".format(event)) - def succeeded(self, event): + def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None: # The reply.document attribute was added in PyMongo 3.4. logging.info("Heartbeat to server {0.connection_id} " "succeeded with reply " "{0.reply.document}".format(event)) - def failed(self, event): + def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None: logging.warning("Heartbeat to server {0.connection_id} " "failed with error {0.reply}".format(event)) @@ -124,11 +122,11 @@ class TopologyLogger(monitoring.TopologyListener): .. versionadded:: 3.11 """ - def opened(self, event): + def opened(self, event: monitoring.TopologyOpenedEvent) -> None: logging.info("Topology with id {0.topology_id} " "opened".format(event)) - def description_changed(self, event): + def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None: logging.info("Topology description updated for " "topology id {0.topology_id}".format(event)) previous_topology_type = event.previous_description.topology_type @@ -146,7 +144,7 @@ class TopologyLogger(monitoring.TopologyListener): if not event.new_description.has_readable_server(): logging.warning("No readable servers available.") - def closed(self, event): + def closed(self, event: monitoring.TopologyClosedEvent) -> None: logging.info("Topology with id {0.topology_id} " "closed".format(event)) @@ -168,43 +166,43 @@ class ConnectionPoolLogger(monitoring.ConnectionPoolListener): .. versionadded:: 3.11 """ - def pool_created(self, event): + def pool_created(self, event: monitoring.PoolCreatedEvent) -> None: logging.info("[pool {0.address}] pool created".format(event)) def pool_ready(self, event): logging.info("[pool {0.address}] pool ready".format(event)) - def pool_cleared(self, event): + def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None: logging.info("[pool {0.address}] pool cleared".format(event)) - def pool_closed(self, event): + def pool_closed(self, event: monitoring.PoolClosedEvent) -> None: logging.info("[pool {0.address}] pool closed".format(event)) - def connection_created(self, event): + def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None: logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection created".format(event)) - def connection_ready(self, event): + def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None: logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection setup succeeded".format(event)) - def connection_closed(self, event): + def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None: logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection closed, reason: " "{0.reason}".format(event)) - def connection_check_out_started(self, event): + def connection_check_out_started(self, event: monitoring.ConnectionCheckOutStartedEvent) -> None: logging.info("[pool {0.address}] connection check out " "started".format(event)) - def connection_check_out_failed(self, event): + def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None: logging.info("[pool {0.address}] connection check out " "failed, reason: {0.reason}".format(event)) - def connection_checked_out(self, event): + def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None: logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection checked out of pool".format(event)) - def connection_checked_in(self, event): + def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None: logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection checked into pool".format(event)) diff --git a/pymongo/hello.py b/pymongo/hello.py index 0ad06e961..ba09d80e3 100644 --- a/pymongo/hello.py +++ b/pymongo/hello.py @@ -14,10 +14,15 @@ """Helpers for the 'hello' and legacy hello commands.""" +import copy +import datetime import itertools +from typing import Any, Generic, List, Mapping, Optional, Set, Tuple +from bson.objectid import ObjectId from pymongo import common from pymongo.server_type import SERVER_TYPE +from pymongo.typings import _DocumentType class HelloCompat: @@ -56,7 +61,7 @@ def _get_server_type(doc): return SERVER_TYPE.Standalone -class Hello(object): +class Hello(Generic[_DocumentType]): """Parse a hello response from the server. .. versionadded:: 3.12 @@ -64,9 +69,9 @@ class Hello(object): __slots__ = ('_doc', '_server_type', '_is_writable', '_is_readable', '_awaitable') - def __init__(self, doc, awaitable=False): + def __init__(self, doc: _DocumentType, awaitable: bool = False) -> None: self._server_type = _get_server_type(doc) - self._doc = doc + self._doc: _DocumentType = doc self._is_writable = self._server_type in ( SERVER_TYPE.RSPrimary, SERVER_TYPE.Standalone, @@ -79,19 +84,19 @@ class Hello(object): self._awaitable = awaitable @property - def document(self): + def document(self) -> _DocumentType: """The complete hello command response document. .. versionadded:: 3.4 """ - return self._doc.copy() + return copy.copy(self._doc) @property - def server_type(self): + def server_type(self) -> int: return self._server_type @property - def all_hosts(self): + def all_hosts(self) -> Set[Tuple[str, int]]: """List of hosts, passives, and arbiters known to this server.""" return set(map(common.clean_node, itertools.chain( self._doc.get('hosts', []), @@ -99,12 +104,12 @@ class Hello(object): self._doc.get('arbiters', [])))) @property - def tags(self): + def tags(self) -> Mapping[str, Any]: """Replica set member tags or empty dict.""" return self._doc.get('tags', {}) @property - def primary(self): + def primary(self) -> Optional[Tuple[str, int]]: """This server's opinion about who the primary is, or None.""" if self._doc.get('primary'): return common.partition_node(self._doc['primary']) @@ -112,70 +117,71 @@ class Hello(object): return None @property - def replica_set_name(self): + def replica_set_name(self) -> Optional[str]: """Replica set name or None.""" return self._doc.get('setName') @property - def max_bson_size(self): + def max_bson_size(self) -> int: return self._doc.get('maxBsonObjectSize', common.MAX_BSON_SIZE) @property - def max_message_size(self): + def max_message_size(self) -> int: return self._doc.get('maxMessageSizeBytes', 2 * self.max_bson_size) @property - def max_write_batch_size(self): + def max_write_batch_size(self) -> int: return self._doc.get('maxWriteBatchSize', common.MAX_WRITE_BATCH_SIZE) @property - def min_wire_version(self): + def min_wire_version(self) -> int: return self._doc.get('minWireVersion', common.MIN_WIRE_VERSION) @property - def max_wire_version(self): + def max_wire_version(self) -> int: return self._doc.get('maxWireVersion', common.MAX_WIRE_VERSION) @property - def set_version(self): + def set_version(self) -> Optional[int]: return self._doc.get('setVersion') @property - def election_id(self): + def election_id(self) -> Optional[ObjectId]: return self._doc.get('electionId') @property - def cluster_time(self): + def cluster_time(self) -> Optional[Mapping[str, Any]]: return self._doc.get('$clusterTime') @property - def logical_session_timeout_minutes(self): + def logical_session_timeout_minutes(self) -> Optional[int]: return self._doc.get('logicalSessionTimeoutMinutes') @property - def is_writable(self): + def is_writable(self) -> bool: return self._is_writable @property - def is_readable(self): + def is_readable(self) -> bool: return self._is_readable @property - def me(self): + def me(self) -> Optional[Tuple[str, int]]: me = self._doc.get('me') if me: return common.clean_node(me) + return None @property - def last_write_date(self): + def last_write_date(self) -> Optional[datetime.datetime]: return self._doc.get('lastWrite', {}).get('lastWriteDate') @property - def compressors(self): + def compressors(self) -> Optional[List[str]]: return self._doc.get('compression') @property - def sasl_supported_mechs(self): + def sasl_supported_mechs(self) -> List[str]: """Supported authentication mechanisms for the current user. For example:: @@ -187,22 +193,22 @@ class Hello(object): return self._doc.get('saslSupportedMechs', []) @property - def speculative_authenticate(self): + def speculative_authenticate(self) -> Optional[Mapping[str, Any]]: """The speculativeAuthenticate field.""" return self._doc.get('speculativeAuthenticate') @property - def topology_version(self): + def topology_version(self) -> Optional[Mapping[str, Any]]: return self._doc.get('topologyVersion') @property - def awaitable(self): + def awaitable(self) -> bool: return self._awaitable @property - def service_id(self): + def service_id(self) -> Optional[ObjectId]: return self._doc.get('serviceId') @property - def hello_ok(self): + def hello_ok(self) -> bool: return self._doc.get('helloOk', False) diff --git a/pymongo/helpers.py b/pymongo/helpers.py index a9d40d810..b2726dca6 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -16,18 +16,14 @@ import sys import traceback - from collections import abc +from typing import Any from bson.son import SON from pymongo import ASCENDING -from pymongo.errors import (CursorNotFound, - DuplicateKeyError, - ExecutionTimeout, - NotPrimaryError, - OperationFailure, - WriteError, - WriteConcernError, +from pymongo.errors import (CursorNotFound, DuplicateKeyError, + ExecutionTimeout, NotPrimaryError, + OperationFailure, WriteConcernError, WriteError, WTimeoutError) from pymongo.hello import HelloCompat @@ -95,7 +91,7 @@ def _index_document(index_list): if not len(index_list): raise ValueError("key_or_list must not be the empty list") - index = SON() + index: SON[str, Any] = SON() for (key, value) in index_list: if not isinstance(key, str): raise TypeError( diff --git a/pymongo/message.py b/pymongo/message.py index f632214a0..ac6000cfd 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -23,8 +23,8 @@ MongoDB. import datetime import random import struct - from io import BytesIO as _BytesIO +from typing import Any import bson from bson import (CodecOptions, @@ -32,30 +32,24 @@ from bson import (CodecOptions, _decode_selective, _dict_to_bson, _make_c_string) -from bson import codec_options from bson.int64 import Int64 -from bson.raw_bson import (_inflate_bson, DEFAULT_RAW_BSON_OPTIONS, - RawBSONDocument) +from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, + _inflate_bson) from bson.son import SON try: - from pymongo import _cmessage + from pymongo import _cmessage # type: ignore[attr-defined] _use_c = True except ImportError: _use_c = False -from pymongo.errors import (ConfigurationError, - CursorNotFound, - DocumentTooLarge, - ExecutionTimeout, - InvalidOperation, - NotPrimaryError, - OperationFailure, - ProtocolError) +from pymongo.errors import (ConfigurationError, CursorNotFound, + DocumentTooLarge, ExecutionTimeout, + InvalidOperation, NotPrimaryError, + OperationFailure, ProtocolError) from pymongo.hello import HelloCompat from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern - MAX_INT32 = 2147483647 MIN_INT32 = -2147483648 @@ -457,6 +451,7 @@ class _RawBatchGetMore(_GetMore): class _CursorAddress(tuple): """The server address (host, port) of a cursor, with namespace property.""" + __namespace: Any def __new__(cls, address, namespace): self = tuple.__new__(cls, address) @@ -762,6 +757,7 @@ class _BulkWriteContext(object): """A proxy for SocketInfo.unack_write that handles event publishing. """ if self.publish: + assert self.start_time is not None duration = datetime.datetime.now() - self.start_time cmd = self._start(cmd, request_id, docs) start = datetime.datetime.now() @@ -777,6 +773,7 @@ class _BulkWriteContext(object): self._succeed(request_id, reply, duration) except Exception as exc: if self.publish: + assert self.start_time is not None duration = (datetime.datetime.now() - start) + duration if isinstance(exc, OperationFailure): failure = _convert_write_result( @@ -795,6 +792,7 @@ class _BulkWriteContext(object): """A proxy for SocketInfo.write_command that handles event publishing. """ if self.publish: + assert self.start_time is not None duration = datetime.datetime.now() - self.start_time self._start(cmd, request_id, docs) start = datetime.datetime.now() @@ -1171,7 +1169,8 @@ class _OpReply(object): if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR): raise NotPrimaryError(error_object["$err"], error_object) elif error_object.get("code") == 50: - raise ExecutionTimeout(error_object.get("$err"), + default_msg = "operation exceeded time limit" + raise ExecutionTimeout(error_object.get("$err", default_msg), error_object.get("code"), error_object) raise OperationFailure("database error: %s" % diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 052ade385..975fc8761 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -34,46 +34,41 @@ access: import contextlib import threading import weakref - from collections import defaultdict +from typing import (TYPE_CHECKING, Any, Dict, FrozenSet, Generic, List, + Mapping, Optional, Sequence, Set, Tuple, Type, Union, cast) -from bson.codec_options import DEFAULT_CODEC_OPTIONS +import bson +from bson.codec_options import (DEFAULT_CODEC_OPTIONS, CodecOptions, + TypeRegistry) from bson.son import SON -from pymongo import (common, - database, - helpers, - message, - periodic_executor, - uri_parser, - client_session) -from pymongo.change_stream import ClusterChangeStream +from bson.timestamp import Timestamp +from pymongo import (client_session, common, database, helpers, message, + periodic_executor, uri_parser) +from pymongo.change_stream import ChangeStream, ClusterChangeStream from pymongo.client_options import ClientOptions from pymongo.command_cursor import CommandCursor -from pymongo.errors import (AutoReconnect, - BulkWriteError, - ConfigurationError, - ConnectionFailure, - InvalidOperation, - NotPrimaryError, - OperationFailure, - PyMongoError, +from pymongo.errors import (AutoReconnect, BulkWriteError, ConfigurationError, + ConnectionFailure, InvalidOperation, + NotPrimaryError, OperationFailure, PyMongoError, ServerSelectionTimeoutError) from pymongo.pool import ConnectionClosedReason -from pymongo.read_preferences import ReadPreference +from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.topology import (Topology, - _ErrorContext) -from pymongo.topology_description import TOPOLOGY_TYPE from pymongo.settings import TopologySettings -from pymongo.uri_parser import (_handle_option_deprecations, - _handle_security_options, - _normalize_options, - _check_options) -from pymongo.write_concern import DEFAULT_WRITE_CONCERN +from pymongo.topology import Topology, _ErrorContext +from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription +from pymongo.typings import _CollationIn, _DocumentType, _Pipeline +from pymongo.uri_parser import (_check_options, _handle_option_deprecations, + _handle_security_options, _normalize_options) +from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern + +if TYPE_CHECKING: + from pymongo.read_concern import ReadConcern -class MongoClient(common.BaseObject): +class MongoClient(common.BaseObject, Generic[_DocumentType]): """ A client-side representation of a MongoDB cluster. @@ -89,15 +84,15 @@ class MongoClient(common.BaseObject): # No host/port; these are retrieved from TopologySettings. _constructor_args = ('document_class', 'tz_aware', 'connect') - def __init__( - self, - host=None, - port=None, - document_class=dict, - tz_aware=None, - connect=None, - type_registry=None, - **kwargs): + def __init__(self, + host: Optional[Union[str, Sequence[str]]] = None, + port: Optional[int] = None, + document_class: Type[_DocumentType] = dict, + tz_aware: Optional[bool] = None, + connect: Optional[bool] = None, + type_registry: Optional[TypeRegistry] = None, + **kwargs: Any, + ) -> None: """Client for a MongoDB instance, a replica set, or a set of mongoses. The client object is thread-safe and has connection-pooling built in. @@ -621,7 +616,7 @@ class MongoClient(common.BaseObject): client.__my_database__ """ - self.__init_kwargs = {'host': host, + self.__init_kwargs: Dict[str, Any] = {'host': host, 'port': port, 'document_class': document_class, 'tz_aware': tz_aware, @@ -722,7 +717,7 @@ class MongoClient(common.BaseObject): self.__default_database_name = dbase self.__lock = threading.Lock() - self.__kill_cursors_queue = [] + self.__kill_cursors_queue: List = [] self._event_listeners = options.pool_options._event_listeners super(MongoClient, self).__init__(options.codec_options, @@ -765,7 +760,7 @@ class MongoClient(common.BaseObject): # We strongly reference the executor and it weakly references us via # this closure. When the client is freed, stop the executor soon. - self_ref = weakref.ref(self, executor.close) + self_ref: Any = weakref.ref(self, executor.close) self._kill_cursors_executor = executor if connect: @@ -798,9 +793,17 @@ class MongoClient(common.BaseObject): return getattr(server.description, attr_name) - def watch(self, pipeline=None, full_document=None, resume_after=None, - max_await_time_ms=None, batch_size=None, collation=None, - start_at_operation_time=None, session=None, start_after=None): + def watch(self, + pipeline: Optional[_Pipeline] = None, + full_document: Optional[str] = None, + resume_after: Optional[Mapping[str, Any]] = None, + max_await_time_ms: Optional[int] = None, + batch_size: Optional[int] = None, + collation: Optional[_CollationIn] = None, + start_at_operation_time: Optional[Timestamp] = None, + session: Optional[client_session.ClientSession] = None, + start_after: Optional[Mapping[str, Any]] = None, + ) -> ChangeStream[_DocumentType]: """Watch changes on this cluster. Performs an aggregation with an implicit initial ``$changeStream`` @@ -891,7 +894,7 @@ class MongoClient(common.BaseObject): start_after) @property - def topology_description(self): + def topology_description(self) -> TopologyDescription: """The description of the connected MongoDB deployment. >>> client.topology_description @@ -913,7 +916,7 @@ class MongoClient(common.BaseObject): return self._topology.description @property - def address(self): + def address(self) -> Optional[Tuple[str, int]]: """(host, port) of the current standalone, primary, or mongos, or None. Accessing :attr:`address` raises :exc:`~.errors.InvalidOperation` if @@ -940,7 +943,7 @@ class MongoClient(common.BaseObject): return self._server_property('address') @property - def primary(self): + def primary(self) -> Optional[Tuple[str, int]]: """The (host, port) of the current primary of the replica set. Returns ``None`` if this client is not connected to a replica set, @@ -953,7 +956,7 @@ class MongoClient(common.BaseObject): return self._topology.get_primary() @property - def secondaries(self): + def secondaries(self) -> Set[Tuple[str, int]]: """The secondary members known to this client. A sequence of (host, port) pairs. Empty if this client is not @@ -966,7 +969,7 @@ class MongoClient(common.BaseObject): return self._topology.get_secondaries() @property - def arbiters(self): + def arbiters(self) -> Set[Tuple[str, int]]: """Arbiters in the replica set. A sequence of (host, port) pairs. Empty if this client is not @@ -976,7 +979,7 @@ class MongoClient(common.BaseObject): return self._topology.get_arbiters() @property - def is_primary(self): + def is_primary(self) -> bool: """If this client is connected to a server that can accept writes. True if the current server is a standalone, mongos, or the primary of @@ -987,7 +990,7 @@ class MongoClient(common.BaseObject): return self._server_property('is_writable') @property - def is_mongos(self): + def is_mongos(self) -> bool: """If this client is connected to mongos. If the client is not connected, this will block until a connection is established or raise ServerSelectionTimeoutError if no server is available.. @@ -995,7 +998,7 @@ class MongoClient(common.BaseObject): return self._server_property('server_type') == SERVER_TYPE.Mongos @property - def nodes(self): + def nodes(self) -> FrozenSet[Tuple[str, Optional[int]]]: """Set of all currently connected servers. .. warning:: When connected to a replica set the value of :attr:`nodes` @@ -1009,7 +1012,7 @@ class MongoClient(common.BaseObject): return frozenset(s.address for s in description.known_servers) @property - def options(self): + def options(self) -> ClientOptions: """The configuration options for this client. :Returns: @@ -1040,7 +1043,7 @@ class MongoClient(common.BaseObject): # command. pass - def close(self): + def close(self) -> None: """Cleanup client resources and disconnect from MongoDB. End all server sessions created by this client by sending one or more @@ -1214,7 +1217,7 @@ class MongoClient(common.BaseObject): def _retry_internal(self, retryable, func, session, bulk): """Internal retryable write helper.""" max_wire_version = 0 - last_error = None + last_error: Optional[Exception] = None retrying = False def is_retrying(): @@ -1239,6 +1242,7 @@ class MongoClient(common.BaseObject): if is_retrying(): # A retry is not possible because this server does # not support sessions raise the last error. + assert last_error is not None raise last_error retryable = False return func(session, sock_info, retryable) @@ -1247,6 +1251,7 @@ class MongoClient(common.BaseObject): # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry # attempt. Raise the original exception instead. + assert last_error is not None raise last_error # A ServerSelectionTimeoutError error indicates that there may # be a persistent outage. Attempting to retry in this case will @@ -1280,7 +1285,7 @@ class MongoClient(common.BaseObject): retryable = (retryable and self.options.retry_reads and not (session and session.in_transaction)) - last_error = None + last_error: Optional[Exception] = None retrying = False while True: @@ -1292,6 +1297,7 @@ class MongoClient(common.BaseObject): if retrying and not retryable: # A retry is not possible because this server does # not support retryable reads, raise the last error. + assert last_error is not None raise last_error return func(session, server, sock_info, read_pref) except ServerSelectionTimeoutError: @@ -1299,6 +1305,7 @@ class MongoClient(common.BaseObject): # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry # attempt. Raise the original exception instead. + assert last_error is not None raise last_error # A ServerSelectionTimeoutError error indicates that there may # be a persistent outage. Attempting to retry in this case will @@ -1322,15 +1329,15 @@ class MongoClient(common.BaseObject): with self._tmp_session(session) as s: return self._retry_with_session(retryable, func, s, None) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): return self._topology == other._topology return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def __hash__(self): + def __hash__(self) -> int: return hash(self._topology) def _repr_helper(self): @@ -1366,7 +1373,7 @@ class MongoClient(common.BaseObject): def __repr__(self): return ("MongoClient(%s)" % (self._repr_helper(),)) - def __getattr__(self, name): + def __getattr__(self, name: str) -> database.Database[_DocumentType]: """Get a database by name. Raises :class:`~pymongo.errors.InvalidName` if an invalid @@ -1381,7 +1388,7 @@ class MongoClient(common.BaseObject): " database, use client[%r]." % (name, name, name)) return self.__getitem__(name) - def __getitem__(self, name): + def __getitem__(self, name: str) -> database.Database[_DocumentType]: """Get a database by name. Raises :class:`~pymongo.errors.InvalidName` if an invalid @@ -1539,9 +1546,10 @@ class MongoClient(common.BaseObject): self, server_session, opts, implicit) def start_session(self, - causal_consistency=None, - default_transaction_options=None, - snapshot=False): + causal_consistency: Optional[bool] = None, + default_transaction_options: Optional[client_session.TransactionOptions] = None, + snapshot: Optional[bool] = False, + ) -> client_session.ClientSession[_DocumentType]: """Start a logical session. This method takes the same parameters as @@ -1630,7 +1638,9 @@ class MongoClient(common.BaseObject): if session is not None: session._process_response(reply) - def server_info(self, session=None): + def server_info(self, + session: Optional[client_session.ClientSession] = None + ) -> Dict[str, Any]: """Get information about the MongoDB server we're connected to. :Parameters: @@ -1644,7 +1654,10 @@ class MongoClient(common.BaseObject): read_preference=ReadPreference.PRIMARY, session=session) - def list_databases(self, session=None, **kwargs): + def list_databases(self, + session: Optional[client_session.ClientSession] = None, + **kwargs: Any + ) -> CommandCursor[Dict[str, Any]]: """Get a cursor over the databases of the connected server. :Parameters: @@ -1673,7 +1686,9 @@ class MongoClient(common.BaseObject): } return CommandCursor(admin["$cmd"], cursor, None) - def list_database_names(self, session=None): + def list_database_names(self, + session: Optional[client_session.ClientSession] = None + ) -> List[str]: """Get a list of the names of all databases on the connected server. :Parameters: @@ -1685,7 +1700,10 @@ class MongoClient(common.BaseObject): return [doc["name"] for doc in self.list_databases(session, nameOnly=True)] - def drop_database(self, name_or_database, session=None): + def drop_database(self, + name_or_database: Union[str, database.Database], + session: Optional[client_session.ClientSession] = None + ) -> None: """Drop a database. Raises :class:`TypeError` if `name_or_database` is not an instance of @@ -1727,8 +1745,13 @@ class MongoClient(common.BaseObject): parse_write_concern_error=True, session=session) - def get_default_database(self, default=None, codec_options=None, - read_preference=None, write_concern=None, read_concern=None): + def get_default_database(self, + default: Optional[str] = None, + codec_options: Optional[CodecOptions] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional["ReadConcern"] = None, + ) -> database.Database[_DocumentType]: """Get the database named in the MongoDB connection URI. >>> uri = 'mongodb://host/my_database' @@ -1773,12 +1796,18 @@ class MongoClient(common.BaseObject): raise ConfigurationError( 'No default database name defined or provided.') + name = cast(str, self.__default_database_name or default) return database.Database( - self, self.__default_database_name or default, codec_options, + self, name, codec_options, read_preference, write_concern, read_concern) - def get_database(self, name=None, codec_options=None, read_preference=None, - write_concern=None, read_concern=None): + def get_database(self, + name: Optional[str] = None, + codec_options: Optional[CodecOptions] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional["ReadConcern"] = None, + ) -> database.Database[_DocumentType]: """Get a :class:`~pymongo.database.Database` with the given name and options. @@ -1838,16 +1867,16 @@ class MongoClient(common.BaseObject): read_preference=ReadPreference.PRIMARY, write_concern=DEFAULT_WRITE_CONCERN) - def __enter__(self): + def __enter__(self) -> "MongoClient[_DocumentType]": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() - def __iter__(self): + def __iter__(self) -> "MongoClient[_DocumentType]": return self - def __next__(self): + def __next__(self) -> None: raise TypeError("'MongoClient' object is not iterable") next = __next__ diff --git a/pymongo/monitor.py b/pymongo/monitor.py index 039ec5194..388ba6168 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -18,10 +18,10 @@ import atexit import threading import time import weakref +from typing import Any, Mapping, cast from pymongo import common, periodic_executor -from pymongo.errors import (NotPrimaryError, - OperationFailure, +from pymongo.errors import (NotPrimaryError, OperationFailure, _OperationCancelled) from pymongo.hello import Hello from pymongo.periodic_executor import _shutdown_executors @@ -50,7 +50,7 @@ class MonitorBase(object): monitor = self_ref() if monitor is None: return False # Stop the executor. - monitor._run() + monitor._run() # type:ignore[attr-defined] return True executor = periodic_executor.PeriodicExecutor( @@ -214,8 +214,8 @@ class Monitor(MonitorBase): return self._check_once() except (OperationFailure, NotPrimaryError) as exc: # Update max cluster time even when hello fails. - self._topology.receive_cluster_time( - exc.details.get('$clusterTime')) + details = cast(Mapping[str, Any], exc.details) + self._topology.receive_cluster_time(details.get('$clusterTime')) raise except ReferenceError: raise diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index b877e19a2..6f57200a3 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -180,12 +180,21 @@ will not add that listener to existing client instances. handler first. """ +import datetime from collections import abc, namedtuple +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional -from pymongo.hello import HelloCompat +from bson.objectid import ObjectId +from pymongo.hello import Hello, HelloCompat from pymongo.helpers import _handle_exception +from pymongo.typings import _Address -_Listeners = namedtuple('Listeners', +if TYPE_CHECKING: + from pymongo.server_description import ServerDescription + from pymongo.topology_description import TopologyDescription + + +_Listeners = namedtuple('_Listeners', ('command_listeners', 'server_listeners', 'server_heartbeat_listeners', 'topology_listeners', 'cmap_listeners')) @@ -193,6 +202,9 @@ _Listeners = namedtuple('Listeners', _LISTENERS = _Listeners([], [], [], [], []) +_DocumentOut = Mapping[str, Any] + + class _EventListener(object): """Abstract base class for all event listeners.""" @@ -204,7 +216,7 @@ class CommandListener(_EventListener): and `CommandFailedEvent`. """ - def started(self, event): + def started(self, event: "CommandStartedEvent") -> None: """Abstract method to handle a `CommandStartedEvent`. :Parameters: @@ -212,7 +224,7 @@ class CommandListener(_EventListener): """ raise NotImplementedError - def succeeded(self, event): + def succeeded(self, event: "CommandSucceededEvent") -> None: """Abstract method to handle a `CommandSucceededEvent`. :Parameters: @@ -220,7 +232,7 @@ class CommandListener(_EventListener): """ raise NotImplementedError - def failed(self, event): + def failed(self, event: "CommandFailedEvent") -> None: """Abstract method to handle a `CommandFailedEvent`. :Parameters: @@ -245,7 +257,7 @@ class ConnectionPoolListener(_EventListener): .. versionadded:: 3.9 """ - def pool_created(self, event): + def pool_created(self, event: "PoolCreatedEvent") -> None: """Abstract method to handle a :class:`PoolCreatedEvent`. Emitted when a Connection Pool is created. @@ -255,7 +267,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def pool_ready(self, event): + def pool_ready(self, event: "PoolReadyEvent") -> None: """Abstract method to handle a :class:`PoolReadyEvent`. Emitted when a Connection Pool is marked ready. @@ -267,7 +279,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def pool_cleared(self, event): + def pool_cleared(self, event: "PoolClearedEvent") -> None: """Abstract method to handle a `PoolClearedEvent`. Emitted when a Connection Pool is cleared. @@ -277,7 +289,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def pool_closed(self, event): + def pool_closed(self, event: "PoolClosedEvent") -> None: """Abstract method to handle a `PoolClosedEvent`. Emitted when a Connection Pool is closed. @@ -287,7 +299,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_created(self, event): + def connection_created(self, event: "ConnectionCreatedEvent") -> None: """Abstract method to handle a :class:`ConnectionCreatedEvent`. Emitted when a Connection Pool creates a Connection object. @@ -297,7 +309,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_ready(self, event): + def connection_ready(self, event: "ConnectionReadyEvent") -> None: """Abstract method to handle a :class:`ConnectionReadyEvent`. Emitted when a Connection has finished its setup, and is now ready to @@ -308,7 +320,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_closed(self, event): + def connection_closed(self, event: "ConnectionClosedEvent") -> None: """Abstract method to handle a :class:`ConnectionClosedEvent`. Emitted when a Connection Pool closes a Connection. @@ -318,7 +330,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_check_out_started(self, event): + def connection_check_out_started(self, event: "ConnectionCheckOutStartedEvent") -> None: """Abstract method to handle a :class:`ConnectionCheckOutStartedEvent`. Emitted when the driver starts attempting to check out a connection. @@ -328,7 +340,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_check_out_failed(self, event): + def connection_check_out_failed(self, event: "ConnectionCheckOutFailedEvent") -> None: """Abstract method to handle a :class:`ConnectionCheckOutFailedEvent`. Emitted when the driver's attempt to check out a connection fails. @@ -338,7 +350,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_checked_out(self, event): + def connection_checked_out(self, event: "ConnectionCheckedOutEvent") -> None: """Abstract method to handle a :class:`ConnectionCheckedOutEvent`. Emitted when the driver successfully checks out a Connection. @@ -348,7 +360,7 @@ class ConnectionPoolListener(_EventListener): """ raise NotImplementedError - def connection_checked_in(self, event): + def connection_checked_in(self, event: "ConnectionCheckedInEvent") -> None: """Abstract method to handle a :class:`ConnectionCheckedInEvent`. Emitted when the driver checks in a Connection back to the Connection @@ -369,7 +381,7 @@ class ServerHeartbeatListener(_EventListener): .. versionadded:: 3.3 """ - def started(self, event): + def started(self, event: "ServerHeartbeatStartedEvent") -> None: """Abstract method to handle a `ServerHeartbeatStartedEvent`. :Parameters: @@ -377,7 +389,7 @@ class ServerHeartbeatListener(_EventListener): """ raise NotImplementedError - def succeeded(self, event): + def succeeded(self, event: "ServerHeartbeatSucceededEvent") -> None: """Abstract method to handle a `ServerHeartbeatSucceededEvent`. :Parameters: @@ -385,7 +397,7 @@ class ServerHeartbeatListener(_EventListener): """ raise NotImplementedError - def failed(self, event): + def failed(self, event: "ServerHeartbeatFailedEvent") -> None: """Abstract method to handle a `ServerHeartbeatFailedEvent`. :Parameters: @@ -402,7 +414,7 @@ class TopologyListener(_EventListener): .. versionadded:: 3.3 """ - def opened(self, event): + def opened(self, event: "TopologyOpenedEvent") -> None: """Abstract method to handle a `TopologyOpenedEvent`. :Parameters: @@ -410,7 +422,7 @@ class TopologyListener(_EventListener): """ raise NotImplementedError - def description_changed(self, event): + def description_changed(self, event: "TopologyDescriptionChangedEvent") -> None: """Abstract method to handle a `TopologyDescriptionChangedEvent`. :Parameters: @@ -418,7 +430,7 @@ class TopologyListener(_EventListener): """ raise NotImplementedError - def closed(self, event): + def closed(self, event: "TopologyClosedEvent") -> None: """Abstract method to handle a `TopologyClosedEvent`. :Parameters: @@ -435,7 +447,7 @@ class ServerListener(_EventListener): .. versionadded:: 3.3 """ - def opened(self, event): + def opened(self, event: "ServerOpeningEvent") -> None: """Abstract method to handle a `ServerOpeningEvent`. :Parameters: @@ -443,7 +455,7 @@ class ServerListener(_EventListener): """ raise NotImplementedError - def description_changed(self, event): + def description_changed(self, event: "ServerDescriptionChangedEvent") -> None: """Abstract method to handle a `ServerDescriptionChangedEvent`. :Parameters: @@ -451,7 +463,7 @@ class ServerListener(_EventListener): """ raise NotImplementedError - def closed(self, event): + def closed(self, event: "ServerClosedEvent") -> None: """Abstract method to handle a `ServerClosedEvent`. :Parameters: @@ -478,7 +490,7 @@ def _validate_event_listeners(option, listeners): return listeners -def register(listener): +def register(listener: _EventListener) -> None: """Register a global event listener. :Parameters: @@ -525,8 +537,14 @@ class _CommandEvent(object): __slots__ = ("__cmd_name", "__rqst_id", "__conn_id", "__op_id", "__service_id") - def __init__(self, command_name, request_id, connection_id, operation_id, - service_id=None): + def __init__( + self, + command_name: str, + request_id: int, + connection_id: _Address, + operation_id: Optional[int], + service_id: Optional[ObjectId] = None, + ) -> None: self.__cmd_name = command_name self.__rqst_id = request_id self.__conn_id = connection_id @@ -534,22 +552,22 @@ class _CommandEvent(object): self.__service_id = service_id @property - def command_name(self): + def command_name(self) -> str: """The command name.""" return self.__cmd_name @property - def request_id(self): + def request_id(self) -> int: """The request id for this operation.""" return self.__rqst_id @property - def connection_id(self): + def connection_id(self) -> _Address: """The address (host, port) of the server this command was sent to.""" return self.__conn_id @property - def service_id(self): + def service_id(self) -> Optional[ObjectId]: """The service_id this command was sent to, or ``None``. .. versionadded:: 3.12 @@ -557,7 +575,7 @@ class _CommandEvent(object): return self.__service_id @property - def operation_id(self): + def operation_id(self) -> Optional[int]: """An id for this series of events or None.""" return self.__op_id @@ -576,28 +594,36 @@ class CommandStartedEvent(_CommandEvent): """ __slots__ = ("__cmd", "__db") - def __init__(self, command, database_name, *args, service_id=None): + def __init__( + self, + command: _DocumentOut, + database_name: str, + request_id: int, + connection_id: _Address, + operation_id: Optional[int], + service_id: Optional[ObjectId] = None, + ) -> None: if not command: raise ValueError("%r is not a valid command" % (command,)) # Command name must be first key. command_name = next(iter(command)) super(CommandStartedEvent, self).__init__( - command_name, *args, service_id=service_id) + command_name, request_id, connection_id, operation_id, service_id=service_id) cmd_name, cmd_doc = command_name.lower(), command[command_name] if (cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, command)): - self.__cmd = {} + self.__cmd: Mapping[str, Any] = {} else: self.__cmd = command self.__db = database_name @property - def command(self): + def command(self) -> _DocumentOut: """The command document.""" return self.__cmd @property - def database_name(self): + def database_name(self) -> str: """The name of the database this command was run against.""" return self.__db @@ -625,8 +651,16 @@ class CommandSucceededEvent(_CommandEvent): """ __slots__ = ("__duration_micros", "__reply") - def __init__(self, duration, reply, command_name, - request_id, connection_id, operation_id, service_id=None): + def __init__( + self, + duration: datetime.timedelta, + reply: _DocumentOut, + command_name: str, + request_id: int, + connection_id: _Address, + operation_id: Optional[int], + service_id: Optional[ObjectId] = None, + ) -> None: super(CommandSucceededEvent, self).__init__( command_name, request_id, connection_id, operation_id, service_id=service_id) @@ -634,17 +668,17 @@ class CommandSucceededEvent(_CommandEvent): cmd_name = command_name.lower() if (cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, reply)): - self.__reply = {} + self.__reply: Mapping[str, Any] = {} else: self.__reply = reply @property - def duration_micros(self): + def duration_micros(self) -> int: """The duration of this operation in microseconds.""" return self.__duration_micros @property - def reply(self): + def reply(self) -> _DocumentOut: """The server failure document for this operation.""" return self.__reply @@ -672,18 +706,27 @@ class CommandFailedEvent(_CommandEvent): """ __slots__ = ("__duration_micros", "__failure") - def __init__(self, duration, failure, *args, service_id=None): - super(CommandFailedEvent, self).__init__(*args, service_id=service_id) + def __init__( + self, + duration: datetime.timedelta, + failure: _DocumentOut, + command_name: str, + request_id: int, + connection_id: _Address, + operation_id: Optional[int], + service_id: Optional[ObjectId] = None, + ) -> None: + super(CommandFailedEvent, self).__init__(command_name, request_id, connection_id, operation_id, service_id=service_id) self.__duration_micros = _to_micros(duration) self.__failure = failure @property - def duration_micros(self): + def duration_micros(self) -> int: """The duration of this operation in microseconds.""" return self.__duration_micros @property - def failure(self): + def failure(self) -> _DocumentOut: """The server failure document for this operation.""" return self.__failure @@ -700,11 +743,11 @@ class _PoolEvent(object): """Base class for pool events.""" __slots__ = ("__address",) - def __init__(self, address): + def __init__(self, address: _Address) -> None: self.__address = address @property - def address(self): + def address(self) -> _Address: """The address (host, port) pair of the server the pool is attempting to connect to. """ @@ -725,12 +768,12 @@ class PoolCreatedEvent(_PoolEvent): """ __slots__ = ("__options",) - def __init__(self, address, options): + def __init__(self, address: _Address, options: Dict[str, Any]) -> None: super(PoolCreatedEvent, self).__init__(address) self.__options = options @property - def options(self): + def options(self) -> Dict[str, Any]: """Any non-default pool options that were set on this Connection Pool. """ return self.__options @@ -764,12 +807,12 @@ class PoolClearedEvent(_PoolEvent): """ __slots__ = ("__service_id",) - def __init__(self, address, service_id=None): + def __init__(self, address: _Address, service_id: Optional[ObjectId] = None) -> None: super(PoolClearedEvent, self).__init__(address) self.__service_id = service_id @property - def service_id(self): + def service_id(self) -> Optional[ObjectId]: """Connections with this service_id are cleared. When service_id is ``None``, all connections in the pool are cleared. @@ -839,19 +882,19 @@ class _ConnectionEvent(object): """Private base class for some connection events.""" __slots__ = ("__address", "__connection_id") - def __init__(self, address, connection_id): + def __init__(self, address: _Address, connection_id: int) -> None: self.__address = address self.__connection_id = connection_id @property - def address(self): + def address(self) -> _Address: """The address (host, port) pair of the server this connection is attempting to connect to. """ return self.__address @property - def connection_id(self): + def connection_id(self) -> int: """The ID of the Connection.""" return self.__connection_id @@ -958,19 +1001,19 @@ class ConnectionCheckOutFailedEvent(object): """ __slots__ = ("__address", "__reason") - def __init__(self, address, reason): + def __init__(self, address: _Address, reason: str) -> None: self.__address = address self.__reason = reason @property - def address(self): + def address(self) -> _Address: """The address (host, port) pair of the server this connection is attempting to connect to. """ return self.__address @property - def reason(self): + def reason(self) -> str: """A reason explaining why connection check out failed. The reason must be one of the strings from the @@ -1014,17 +1057,17 @@ class _ServerEvent(object): __slots__ = ("__server_address", "__topology_id") - def __init__(self, server_address, topology_id): + def __init__(self, server_address: _Address, topology_id: ObjectId) -> None: self.__server_address = server_address self.__topology_id = topology_id @property - def server_address(self): + def server_address(self) -> _Address: """The address (host, port) pair of the server""" return self.__server_address @property - def topology_id(self): + def topology_id(self) -> ObjectId: """A unique identifier for the topology this server is a part of.""" return self.__topology_id @@ -1041,19 +1084,19 @@ class ServerDescriptionChangedEvent(_ServerEvent): __slots__ = ('__previous_description', '__new_description') - def __init__(self, previous_description, new_description, *args): + def __init__(self, previous_description: "ServerDescription", new_description: "ServerDescription", *args: Any) -> None: super(ServerDescriptionChangedEvent, self).__init__(*args) self.__previous_description = previous_description self.__new_description = new_description @property - def previous_description(self): + def previous_description(self) -> "ServerDescription": """The previous :class:`~pymongo.server_description.ServerDescription`.""" return self.__previous_description @property - def new_description(self): + def new_description(self) -> "ServerDescription": """The new :class:`~pymongo.server_description.ServerDescription`.""" return self.__new_description @@ -1087,11 +1130,11 @@ class TopologyEvent(object): __slots__ = ('__topology_id') - def __init__(self, topology_id): + def __init__(self, topology_id: ObjectId) -> None: self.__topology_id = topology_id @property - def topology_id(self): + def topology_id(self) -> ObjectId: """A unique identifier for the topology this server is a part of.""" return self.__topology_id @@ -1108,19 +1151,19 @@ class TopologyDescriptionChangedEvent(TopologyEvent): __slots__ = ('__previous_description', '__new_description') - def __init__(self, previous_description, new_description, *args): + def __init__(self, previous_description: "TopologyDescription", new_description: "TopologyDescription", *args: Any) -> None: super(TopologyDescriptionChangedEvent, self).__init__(*args) self.__previous_description = previous_description self.__new_description = new_description @property - def previous_description(self): + def previous_description(self) -> "TopologyDescription": """The previous :class:`~pymongo.topology_description.TopologyDescription`.""" return self.__previous_description @property - def new_description(self): + def new_description(self) -> "TopologyDescription": """The new :class:`~pymongo.topology_description.TopologyDescription`.""" return self.__new_description @@ -1154,11 +1197,11 @@ class _ServerHeartbeatEvent(object): __slots__ = ('__connection_id') - def __init__(self, connection_id): + def __init__(self, connection_id: _Address) -> None: self.__connection_id = connection_id @property - def connection_id(self): + def connection_id(self) -> _Address: """The address (host, port) of the server this heartbeat was sent to.""" return self.__connection_id @@ -1184,24 +1227,24 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent): __slots__ = ('__duration', '__reply', '__awaited') - def __init__(self, duration, reply, connection_id, awaited=False): + def __init__(self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False) -> None: super(ServerHeartbeatSucceededEvent, self).__init__(connection_id) self.__duration = duration self.__reply = reply self.__awaited = awaited @property - def duration(self): + def duration(self) -> float: """The duration of this heartbeat in microseconds.""" return self.__duration @property - def reply(self): + def reply(self) -> Hello: """An instance of :class:`~pymongo.hello.Hello`.""" return self.__reply @property - def awaited(self): + def awaited(self) -> bool: """Whether the heartbeat was awaited. If true, then :meth:`duration` reflects the sum of the round trip time @@ -1225,24 +1268,24 @@ class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent): __slots__ = ('__duration', '__reply', '__awaited') - def __init__(self, duration, reply, connection_id, awaited=False): + def __init__(self, duration: float, reply: Exception, connection_id: _Address, awaited: bool = False) -> None: super(ServerHeartbeatFailedEvent, self).__init__(connection_id) self.__duration = duration self.__reply = reply self.__awaited = awaited @property - def duration(self): + def duration(self) -> float: """The duration of this heartbeat in microseconds.""" return self.__duration @property - def reply(self): + def reply(self) -> Exception: """A subclass of :exc:`Exception`.""" return self.__reply @property - def awaited(self): + def awaited(self) -> bool: """Whether the heartbeat was awaited. If true, then :meth:`duration` reflects the sum of the round trip time diff --git a/pymongo/network.py b/pymongo/network.py index a14e9924a..48e5084e3 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -20,21 +20,16 @@ import socket import struct import time - from bson import _decode_all_selective - from pymongo import helpers, message from pymongo.common import MAX_MESSAGE_SIZE -from pymongo.compression_support import decompress, _NO_COMPRESSION -from pymongo.errors import (NotPrimaryError, - OperationFailure, - ProtocolError, +from pymongo.compression_support import _NO_COMPRESSION, decompress +from pymongo.errors import (NotPrimaryError, OperationFailure, ProtocolError, _OperationCancelled) from pymongo.message import _UNPACK_REPLY, _OpMsg from pymongo.monitoring import _is_speculative_authenticate from pymongo.socket_checker import _errno_from_exception - _UNPACK_HEADER = struct.Struct(" None: """Create an InsertOne instance. For use with :meth:`~pymongo.collection.Collection.bulk_write`. @@ -43,21 +45,25 @@ class InsertOne(object): def __repr__(self): return "InsertOne(%r)" % (self._doc,) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if type(other) == type(self): return other._doc == self._doc return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other +_IndexList = Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]] +_IndexKeyHint = Union[str, _IndexList] + + class DeleteOne(object): """Represents a delete_one operation.""" __slots__ = ("_filter", "_collation", "_hint") - def __init__(self, filter, collation=None, hint=None): + def __init__(self, filter: Mapping[str, Any], collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None) -> None: """Create a DeleteOne instance. For use with :meth:`~pymongo.collection.Collection.bulk_write`. @@ -95,13 +101,13 @@ class DeleteOne(object): def __repr__(self): return "DeleteOne(%r, %r)" % (self._filter, self._collation) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if type(other) == type(self): return ((other._filter, other._collation) == (self._filter, self._collation)) return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other @@ -110,7 +116,7 @@ class DeleteMany(object): __slots__ = ("_filter", "_collation", "_hint") - def __init__(self, filter, collation=None, hint=None): + def __init__(self, filter: Mapping[str, Any], collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None) -> None: """Create a DeleteMany instance. For use with :meth:`~pymongo.collection.Collection.bulk_write`. @@ -148,13 +154,13 @@ class DeleteMany(object): def __repr__(self): return "DeleteMany(%r, %r)" % (self._filter, self._collation) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if type(other) == type(self): return ((other._filter, other._collation) == (self._filter, self._collation)) return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other @@ -163,8 +169,8 @@ class ReplaceOne(object): __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint") - def __init__(self, filter, replacement, upsert=False, collation=None, - hint=None): + def __init__(self, filter: Mapping[str, Any], replacement: Mapping[str, Any], upsert: bool = False, collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None) -> None: """Create a ReplaceOne instance. For use with :meth:`~pymongo.collection.Collection.bulk_write`. @@ -207,7 +213,7 @@ class ReplaceOne(object): bulkobj.add_replace(self._filter, self._doc, self._upsert, collation=self._collation, hint=self._hint) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if type(other) == type(self): return ( (other._filter, other._doc, other._upsert, other._collation, @@ -215,7 +221,7 @@ class ReplaceOne(object): self._collation, other._hint)) return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other def __repr__(self): @@ -241,7 +247,6 @@ class _UpdateOp(object): if not isinstance(hint, str): hint = helpers._index_document(hint) - self._filter = filter self._doc = doc self._upsert = upsert @@ -272,8 +277,8 @@ class UpdateOne(_UpdateOp): __slots__ = () - def __init__(self, filter, update, upsert=False, collation=None, - array_filters=None, hint=None): + def __init__(self, filter: Mapping[str, Any], update: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, collation: Optional[_CollationIn] = None, + array_filters: Optional[List[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None) -> None: """Represents an update_one operation. For use with :meth:`~pymongo.collection.Collection.bulk_write`. @@ -319,8 +324,8 @@ class UpdateMany(_UpdateOp): __slots__ = () - def __init__(self, filter, update, upsert=False, collation=None, - array_filters=None, hint=None): + def __init__(self, filter: Mapping[str, Any], update: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, collation: Optional[_CollationIn] = None, + array_filters: Optional[List[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None) -> None: """Create an UpdateMany instance. For use with :meth:`~pymongo.collection.Collection.bulk_write`. @@ -366,7 +371,7 @@ class IndexModel(object): __slots__ = ("__document",) - def __init__(self, keys, **kwargs): + def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None: """Create an Index instance. For use with :meth:`~pymongo.collection.Collection.create_indexes`. @@ -437,7 +442,7 @@ class IndexModel(object): self.__document['collation'] = collation @property - def document(self): + def document(self) -> Dict[str, Any]: """An index document suitable for passing to the createIndexes command. """ diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index e1690ee9b..36e094c4c 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -17,6 +17,7 @@ import threading import time import weakref +from typing import Any, Optional class PeriodicExecutor(object): @@ -41,7 +42,7 @@ class PeriodicExecutor(object): self._min_interval = min_interval self._target = target self._stopped = False - self._thread = None + self._thread: Optional[threading.Thread] = None self._name = name self._skip_sleep = False @@ -52,7 +53,7 @@ class PeriodicExecutor(object): return '<%s(name=%s) object at 0x%x>' % ( self.__class__.__name__, self._name, id(self)) - def open(self): + def open(self) -> None: """Start. Multiple calls have no effect. Not safe to call from multiple threads at once. @@ -64,13 +65,14 @@ class PeriodicExecutor(object): # join should not block indefinitely because there is no # other work done outside the while loop in self._run. try: + assert self._thread is not None self._thread.join() except ReferenceError: # Thread terminated. pass self._thread_will_exit = False self._stopped = False - started = False + started: Any = False try: started = self._thread and self._thread.is_alive() except ReferenceError: @@ -84,7 +86,7 @@ class PeriodicExecutor(object): _register_executor(self) thread.start() - def close(self, dummy=None): + def close(self, dummy: Any = None) -> None: """Stop. To restart, call open(). The dummy parameter allows an executor's close method to be a weakref @@ -92,7 +94,7 @@ class PeriodicExecutor(object): """ self._stopped = True - def join(self, timeout=None): + def join(self, timeout: Optional[int] = None) -> None: if self._thread is not None: try: self._thread.join(timeout) @@ -100,14 +102,14 @@ class PeriodicExecutor(object): # Thread already terminated, or not yet started. pass - def wake(self): + def wake(self) -> None: """Execute the target function soon.""" self._event = True - def update_interval(self, new_interval): + def update_interval(self, new_interval: int) -> None: self._interval = new_interval - def skip_sleep(self): + def skip_sleep(self) -> None: self._skip_sleep = True def __should_stop(self): diff --git a/pymongo/pool.py b/pymongo/pool.py index 70920d5b2..c53c9f473 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -18,50 +18,38 @@ import copy import ipaddress import os import platform -import ssl import socket +import ssl import sys import threading import time import weakref +from typing import Any from bson import DEFAULT_CODEC_OPTIONS from bson.son import SON -from pymongo import auth, helpers, __version__ +from pymongo import __version__, auth, helpers from pymongo.client_session import _validate_session_write_concern -from pymongo.common import (MAX_BSON_SIZE, - MAX_CONNECTING, - MAX_IDLE_TIME_SEC, - MAX_MESSAGE_SIZE, - MAX_POOL_SIZE, - MAX_WIRE_VERSION, - MAX_WRITE_BATCH_SIZE, - MIN_POOL_SIZE, - ORDERED_TYPES, +from pymongo.common import (MAX_BSON_SIZE, MAX_CONNECTING, MAX_IDLE_TIME_SEC, + MAX_MESSAGE_SIZE, MAX_POOL_SIZE, MAX_WIRE_VERSION, + MAX_WRITE_BATCH_SIZE, MIN_POOL_SIZE, ORDERED_TYPES, WAIT_QUEUE_TIMEOUT) -from pymongo.errors import (AutoReconnect, - _CertificateError, - ConnectionFailure, - ConfigurationError, - InvalidOperation, - DocumentTooLarge, - NetworkTimeout, - NotPrimaryError, - OperationFailure, - PyMongoError) -from pymongo.hello import HelloCompat, Hello +from pymongo.errors import (AutoReconnect, ConfigurationError, + ConnectionFailure, DocumentTooLarge, + InvalidOperation, NetworkTimeout, NotPrimaryError, + OperationFailure, PyMongoError, _CertificateError) +from pymongo.hello import Hello, HelloCompat from pymongo.monitoring import (ConnectionCheckOutFailedReason, ConnectionClosedReason) -from pymongo.network import (command, - receive_message) +from pymongo.network import command, receive_message from pymongo.read_preferences import ReadPreference from pymongo.server_api import _add_to_command from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker -from pymongo.ssl_support import ( - SSLError as _SSLError, - HAS_SNI as _HAVE_SNI, - IPADDR_SAFE as _IPADDR_SAFE) +from pymongo.ssl_support import HAS_SNI as _HAVE_SNI +from pymongo.ssl_support import IPADDR_SAFE as _IPADDR_SAFE +from pymongo.ssl_support import SSLError as _SSLError + # For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are # not permitted for SNI hostname. @@ -73,7 +61,7 @@ def is_ip_address(address): return False try: - from fcntl import fcntl, F_GETFD, F_SETFD, FD_CLOEXEC + from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl def _set_non_inheritable_non_atomic(fd): """Set the close-on-exec flag on the given file descriptor.""" flags = fcntl(fd, F_GETFD) @@ -82,7 +70,7 @@ except ImportError: # Windows, various platforms we don't claim to support # (Jython, IronPython, ...), systems that don't provide # everything we need from fcntl, etc. - def _set_non_inheritable_non_atomic(dummy): + def _set_non_inheritable_non_atomic(fd): """Dummy function for platforms that don't provide fcntl.""" pass @@ -145,7 +133,7 @@ else: _set_tcp_option(sock, 'TCP_KEEPINTVL', _MAX_TCP_KEEPINTVL) _set_tcp_option(sock, 'TCP_KEEPCNT', _MAX_TCP_KEEPCNT) -_METADATA = SON([ +_METADATA: SON[str, Any] = SON([ ('driver', SON([('name', 'PyMongo'), ('version', __version__)])), ]) @@ -205,7 +193,7 @@ else: if platform.python_implementation().startswith('PyPy'): _METADATA['platform'] = ' '.join( (platform.python_implementation(), - '.'.join(map(str, sys.pypy_version_info)), + '.'.join(map(str, sys.pypy_version_info)), # type: ignore '(Python %s)' % '.'.join(map(str, sys.version_info)))) elif sys.platform.startswith('java'): _METADATA['platform'] = ' '.join( @@ -688,7 +676,7 @@ class SocketInfo(object): session = _validate_session_write_concern(session, write_concern) # Ensure command name remains in first place. - if not isinstance(spec, ORDERED_TYPES): + if not isinstance(spec, ORDERED_TYPES): # type:ignore[arg-type] spec = SON(spec) if not (write_concern is None or write_concern.acknowledged or @@ -1088,7 +1076,7 @@ class Pool: # LIFO pool. Sockets are ordered on idle time. Sockets claimed # and returned to pool from the left side. Stale sockets removed # from the right side. - self.sockets = collections.deque() + self.sockets: collections.deque = collections.deque() self.lock = threading.Lock() self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. @@ -1165,8 +1153,8 @@ class Pool: if service_id is None: sockets, self.sockets = self.sockets, collections.deque() else: - discard = collections.deque() - keep = collections.deque() + discard: collections.deque = collections.deque() + keep: collections.deque = collections.deque() for sock_info in self.sockets: if sock_info.service_id == service_id: discard.append(sock_info) diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index f7c53a59e..c5a5f0936 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -20,29 +20,28 @@ import socket as _socket import ssl as _stdlibssl import sys as _sys import time as _time - from errno import EINTR as _EINTR - from ipaddress import ip_address as _ip_address -from cryptography.x509 import load_der_x509_certificate as _load_der_x509_certificate -from OpenSSL import crypto as _crypto, SSL as _SSL -from service_identity.pyopenssl import ( - verify_hostname as _verify_hostname, - verify_ip_address as _verify_ip_address) +from cryptography.x509 import \ + load_der_x509_certificate as _load_der_x509_certificate +from OpenSSL import SSL as _SSL +from OpenSSL import crypto as _crypto from service_identity import ( - CertificateError as _SICertificateError, - VerificationError as _SIVerificationError) + CertificateError as _SICertificateError +) +from service_identity import VerificationError as _SIVerificationError +from service_identity.pyopenssl import ( # + verify_hostname as _verify_hostname +) +from service_identity.pyopenssl import verify_ip_address as _verify_ip_address -from pymongo.errors import ( - _CertificateError, - ConfigurationError as _ConfigurationError) -from pymongo.ocsp_support import ( - _load_trusted_ca_certs, - _ocsp_callback) +from pymongo.errors import ConfigurationError as _ConfigurationError +from pymongo.errors import _CertificateError from pymongo.ocsp_cache import _OCSPCache -from pymongo.socket_checker import ( - _errno_from_exception, SocketChecker as _SocketChecker) +from pymongo.ocsp_support import _load_trusted_ca_certs, _ocsp_callback +from pymongo.socket_checker import SocketChecker as _SocketChecker +from pymongo.socket_checker import _errno_from_exception try: import certifi @@ -132,7 +131,7 @@ class _sslConn(_SSL.Connection): def recv_into(self, *args, **kwargs): try: - return self._call(super(_sslConn, self).recv_into, *args, **kwargs) + return self._call(super(_sslConn, self).recv_into, *args, **kwargs) # type: ignore except _SSL.SysCallError as exc: # Suppress ragged EOFs to match the stdlib. if self.suppress_ragged_eofs and _ragged_eof(exc): @@ -147,7 +146,7 @@ class _sslConn(_SSL.Connection): while total_sent < total_length: try: sent = self._call( - super(_sslConn, self).send, view[total_sent:], flags) + super(_sslConn, self).send, view[total_sent:], flags) # type: ignore # XXX: It's not clear if this can actually happen. PyOpenSSL # doesn't appear to have any interrupt handling, nor any interrupt # errors for OpenSSL connections. @@ -296,7 +295,7 @@ class SSLContext(object): """Attempt to load CA certs from Windows trust store.""" cert_store = self._ctx.get_cert_store() oid = _stdlibssl.Purpose.SERVER_AUTH.oid - for cert, encoding, trust in _stdlibssl.enum_certificates(store): + for cert, encoding, trust in _stdlibssl.enum_certificates(store): # type: ignore if encoding == "x509_asn": if trust is True or oid in trust: cert_store.add_cert( diff --git a/pymongo/read_concern.py b/pymongo/read_concern.py index 7e9cc4485..aaf67ef5a 100644 --- a/pymongo/read_concern.py +++ b/pymongo/read_concern.py @@ -14,6 +14,8 @@ """Tools for working with read concerns.""" +from typing import Any, Dict, Optional + class ReadConcern(object): """ReadConcern @@ -29,7 +31,7 @@ class ReadConcern(object): """ - def __init__(self, level=None): + def __init__(self, level: Optional[str] = None) -> None: if level is None or isinstance(level, str): self.__level = level else: @@ -37,18 +39,18 @@ class ReadConcern(object): 'level must be a string or None.') @property - def level(self): + def level(self) -> Optional[str]: """The read concern level.""" return self.__level @property - def ok_for_legacy(self): + def ok_for_legacy(self) -> bool: """Return ``True`` if this read concern is compatible with old wire protocol versions.""" return self.level is None or self.level == 'local' @property - def document(self): + def document(self) -> Dict[str, Any]: """The document representation of this read concern. .. note:: @@ -60,7 +62,7 @@ class ReadConcern(object): doc['level'] = self.level return doc - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, ReadConcern): return self.document == other.document return NotImplemented diff --git a/pymongo/read_preferences.py b/pymongo/read_preferences.py index 2471d5834..cc1317fb8 100644 --- a/pymongo/read_preferences.py +++ b/pymongo/read_preferences.py @@ -15,13 +15,13 @@ """Utilities for choosing which member of a replica set to read from.""" from collections import abc +from typing import Any, Dict, Mapping, Optional, Sequence from pymongo import max_staleness_selectors from pymongo.errors import ConfigurationError from pymongo.server_selectors import (member_with_tags_server_selector, secondary_with_tags_server_selector) - _PRIMARY = 0 _PRIMARY_PREFERRED = 1 _SECONDARY = 2 @@ -44,9 +44,9 @@ def _validate_tag_sets(tag_sets): if tag_sets is None: return tag_sets - if not isinstance(tag_sets, list): + if not isinstance(tag_sets, (list, tuple)): raise TypeError(( - "Tag sets %r invalid, must be a list") % (tag_sets,)) + "Tag sets %r invalid, must be a sequence") % (tag_sets,)) if len(tag_sets) == 0: raise ValueError(( "Tag sets %r invalid, must be None or contain at least one set of" @@ -59,7 +59,7 @@ def _validate_tag_sets(tag_sets): "bson.son.SON or other type that inherits from " "collection.Mapping" % (tags,)) - return tag_sets + return list(tag_sets) def _invalid_max_staleness_msg(max_staleness): @@ -93,6 +93,10 @@ def _validate_hedge(hedge): return hedge +_Hedge = Mapping[str, Any] +_TagSets = Sequence[Mapping[str, Any]] + + class _ServerMode(object): """Base class for all read preferences. """ @@ -100,7 +104,7 @@ class _ServerMode(object): __slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge") - def __init__(self, mode, tag_sets=None, max_staleness=-1, hedge=None): + def __init__(self, mode: int, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None: self.__mongos_mode = _MONGOS_MODES[mode] self.__mode = mode self.__tag_sets = _validate_tag_sets(tag_sets) @@ -108,22 +112,22 @@ class _ServerMode(object): self.__hedge = _validate_hedge(hedge) @property - def name(self): + def name(self) -> str: """The name of this read preference. """ return self.__class__.__name__ @property - def mongos_mode(self): + def mongos_mode(self) -> str: """The mongos mode of this read preference. """ return self.__mongos_mode @property - def document(self): + def document(self) -> Dict[str, Any]: """Read preference as a document. """ - doc = {'mode': self.__mongos_mode} + doc: Dict[str, Any] = {'mode': self.__mongos_mode} if self.__tag_sets not in (None, [{}]): doc['tags'] = self.__tag_sets if self.__max_staleness != -1: @@ -133,13 +137,13 @@ class _ServerMode(object): return doc @property - def mode(self): + def mode(self) -> int: """The mode of this read preference instance. """ return self.__mode @property - def tag_sets(self): + def tag_sets(self) -> _TagSets: """Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to read only from members whose ``dc`` tag has the value ``"ny"``. To specify a priority-order for tag sets, provide a list of @@ -154,14 +158,14 @@ class _ServerMode(object): return list(self.__tag_sets) if self.__tag_sets else [{}] @property - def max_staleness(self): + def max_staleness(self) -> int: """The maximum estimated length of time (in seconds) a replica set secondary can fall behind the primary in replication before it will no longer be selected for operations, or -1 for no maximum.""" return self.__max_staleness @property - def hedge(self): + def hedge(self) -> Optional[_Hedge]: """The read preference ``hedge`` parameter. A dictionary that configures how the server will perform hedged reads. @@ -185,7 +189,7 @@ class _ServerMode(object): return self.__hedge @property - def min_wire_version(self): + def min_wire_version(self) -> int: """The wire protocol version the server must support. Some read preferences impose version requirements on all servers (e.g. @@ -201,7 +205,7 @@ class _ServerMode(object): return "%s(tag_sets=%r, max_staleness=%r, hedge=%r)" % ( self.name, self.__tag_sets, self.__max_staleness, self.__hedge) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, _ServerMode): return (self.mode == other.mode and self.tag_sets == other.tag_sets and @@ -209,7 +213,7 @@ class _ServerMode(object): self.hedge == other.hedge) return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other def __getstate__(self): @@ -243,17 +247,17 @@ class Primary(_ServerMode): __slots__ = () - def __init__(self): + def __init__(self) -> None: super(Primary, self).__init__(_PRIMARY) - def __call__(self, selection): + def __call__(self, selection: Any) -> Any: """Apply this read preference to a Selection.""" return selection.primary_selection def __repr__(self): return "Primary()" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, _ServerMode): return other.mode == _PRIMARY return NotImplemented @@ -289,11 +293,11 @@ class PrimaryPreferred(_ServerMode): __slots__ = () - def __init__(self, tag_sets=None, max_staleness=-1, hedge=None): + def __init__(self, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None: super(PrimaryPreferred, self).__init__( _PRIMARY_PREFERRED, tag_sets, max_staleness, hedge) - def __call__(self, selection): + def __call__(self, selection: Any) -> Any: """Apply this read preference to Selection.""" if selection.primary: return selection.primary_selection @@ -329,11 +333,11 @@ class Secondary(_ServerMode): __slots__ = () - def __init__(self, tag_sets=None, max_staleness=-1, hedge=None): + def __init__(self, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None: super(Secondary, self).__init__( _SECONDARY, tag_sets, max_staleness, hedge) - def __call__(self, selection): + def __call__(self, selection: Any) -> Any: """Apply this read preference to Selection.""" return secondary_with_tags_server_selector( self.tag_sets, @@ -370,11 +374,11 @@ class SecondaryPreferred(_ServerMode): __slots__ = () - def __init__(self, tag_sets=None, max_staleness=-1, hedge=None): + def __init__(self, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None: super(SecondaryPreferred, self).__init__( _SECONDARY_PREFERRED, tag_sets, max_staleness, hedge) - def __call__(self, selection): + def __call__(self, selection: Any) -> Any: """Apply this read preference to Selection.""" secondaries = secondary_with_tags_server_selector( self.tag_sets, @@ -412,11 +416,11 @@ class Nearest(_ServerMode): __slots__ = () - def __init__(self, tag_sets=None, max_staleness=-1, hedge=None): + def __init__(self, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None: super(Nearest, self).__init__( _NEAREST, tag_sets, max_staleness, hedge) - def __call__(self, selection): + def __call__(self, selection: Any) -> Any: """Apply this read preference to Selection.""" return member_with_tags_server_selector( self.tag_sets, @@ -467,7 +471,7 @@ _ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest) -def make_read_preference(mode, tag_sets, max_staleness=-1): +def make_read_preference(mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1) -> _ServerMode: if mode == _PRIMARY: if tag_sets not in (None, [{}]): raise ConfigurationError("Read preference primary " @@ -476,7 +480,7 @@ def make_read_preference(mode, tag_sets, max_staleness=-1): raise ConfigurationError("Read preference primary cannot be " "combined with maxStalenessSeconds") return Primary() - return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) + return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore _MODES = ( @@ -545,7 +549,7 @@ class ReadPreference(object): NEAREST = Nearest() -def read_pref_mode_from_name(name): +def read_pref_mode_from_name(name: str) -> int: """Get the read preference mode from mongos/uri name. """ return _MONGOS_MODES.index(name) @@ -553,10 +557,12 @@ def read_pref_mode_from_name(name): class MovingAverage(object): """Tracks an exponentially-weighted moving average.""" - def __init__(self): + average: Optional[float] + + def __init__(self) -> None: self.average = None - def add_sample(self, sample): + def add_sample(self, sample: float) -> None: if sample < 0: # Likely system time change while waiting for hello response # and not using time.monotonic. Ignore it, the next one will @@ -569,9 +575,9 @@ class MovingAverage(object): # average with alpha = 0.2. self.average = 0.8 * self.average + 0.2 * sample - def get(self): + def get(self) -> Optional[float]: """Get the calculated average, or None if no samples yet.""" return self.average - def reset(self): + def reset(self) -> None: self.average = None diff --git a/pymongo/results.py b/pymongo/results.py index 037480324..637bf73b0 100644 --- a/pymongo/results.py +++ b/pymongo/results.py @@ -13,6 +13,7 @@ # limitations under the License. """Result class definitions.""" +from typing import Any, Dict, List, Mapping, Optional, Sequence, cast from pymongo.errors import InvalidOperation @@ -22,7 +23,7 @@ class _WriteResult(object): __slots__ = ("__acknowledged",) - def __init__(self, acknowledged): + def __init__(self, acknowledged: bool) -> None: self.__acknowledged = acknowledged def _raise_if_unacknowledged(self, property_name): @@ -34,7 +35,7 @@ class _WriteResult(object): "error." % (property_name,)) @property - def acknowledged(self): + def acknowledged(self) -> bool: """Is this the result of an acknowledged write operation? The :attr:`acknowledged` attribute will be ``False`` when using @@ -59,12 +60,12 @@ class InsertOneResult(_WriteResult): __slots__ = ("__inserted_id", "__acknowledged") - def __init__(self, inserted_id, acknowledged): + def __init__(self, inserted_id: Any, acknowledged: bool) -> None: self.__inserted_id = inserted_id super(InsertOneResult, self).__init__(acknowledged) @property - def inserted_id(self): + def inserted_id(self) -> Any: """The inserted document's _id.""" return self.__inserted_id @@ -75,12 +76,12 @@ class InsertManyResult(_WriteResult): __slots__ = ("__inserted_ids", "__acknowledged") - def __init__(self, inserted_ids, acknowledged): + def __init__(self, inserted_ids: List[Any], acknowledged: bool) -> None: self.__inserted_ids = inserted_ids super(InsertManyResult, self).__init__(acknowledged) @property - def inserted_ids(self): + def inserted_ids(self) -> List: """A list of _ids of the inserted documents, in the order provided. .. note:: If ``False`` is passed for the `ordered` parameter to @@ -99,17 +100,17 @@ class UpdateResult(_WriteResult): __slots__ = ("__raw_result", "__acknowledged") - def __init__(self, raw_result, acknowledged): + def __init__(self, raw_result: Dict[str, Any], acknowledged: bool) -> None: self.__raw_result = raw_result super(UpdateResult, self).__init__(acknowledged) @property - def raw_result(self): + def raw_result(self) -> Dict[str, Any]: """The raw result document returned by the server.""" return self.__raw_result @property - def matched_count(self): + def matched_count(self) -> int: """The number of documents matched for this update.""" self._raise_if_unacknowledged("matched_count") if self.upserted_id is not None: @@ -117,13 +118,13 @@ class UpdateResult(_WriteResult): return self.__raw_result.get("n", 0) @property - def modified_count(self): + def modified_count(self) -> int: """The number of documents modified. """ self._raise_if_unacknowledged("modified_count") - return self.__raw_result.get("nModified") + return cast(int, self.__raw_result.get("nModified")) @property - def upserted_id(self): + def upserted_id(self) -> Any: """The _id of the inserted document if an upsert took place. Otherwise ``None``. """ @@ -137,17 +138,17 @@ class DeleteResult(_WriteResult): __slots__ = ("__raw_result", "__acknowledged") - def __init__(self, raw_result, acknowledged): + def __init__(self, raw_result: Dict[str, Any], acknowledged: bool) -> None: self.__raw_result = raw_result super(DeleteResult, self).__init__(acknowledged) @property - def raw_result(self): + def raw_result(self) -> Dict[str, Any]: """The raw result document returned by the server.""" return self.__raw_result @property - def deleted_count(self): + def deleted_count(self) -> int: """The number of documents deleted.""" self._raise_if_unacknowledged("deleted_count") return self.__raw_result.get("n", 0) @@ -158,7 +159,7 @@ class BulkWriteResult(_WriteResult): __slots__ = ("__bulk_api_result", "__acknowledged") - def __init__(self, bulk_api_result, acknowledged): + def __init__(self, bulk_api_result: Dict[str, Any], acknowledged: bool) -> None: """Create a BulkWriteResult instance. :Parameters: @@ -171,44 +172,45 @@ class BulkWriteResult(_WriteResult): super(BulkWriteResult, self).__init__(acknowledged) @property - def bulk_api_result(self): + def bulk_api_result(self) -> Dict[str, Any]: """The raw bulk API result.""" return self.__bulk_api_result @property - def inserted_count(self): + def inserted_count(self) -> int: """The number of documents inserted.""" self._raise_if_unacknowledged("inserted_count") - return self.__bulk_api_result.get("nInserted") + return cast(int, self.__bulk_api_result.get("nInserted")) @property - def matched_count(self): + def matched_count(self) -> int: """The number of documents matched for an update.""" self._raise_if_unacknowledged("matched_count") - return self.__bulk_api_result.get("nMatched") + return cast(int, self.__bulk_api_result.get("nMatched")) @property - def modified_count(self): + def modified_count(self) -> int: """The number of documents modified.""" self._raise_if_unacknowledged("modified_count") - return self.__bulk_api_result.get("nModified") + return cast(int, self.__bulk_api_result.get("nModified")) @property - def deleted_count(self): + def deleted_count(self) -> int: """The number of documents deleted.""" self._raise_if_unacknowledged("deleted_count") - return self.__bulk_api_result.get("nRemoved") + return cast(int, self.__bulk_api_result.get("nRemoved")) @property - def upserted_count(self): + def upserted_count(self) -> int: """The number of documents upserted.""" self._raise_if_unacknowledged("upserted_count") - return self.__bulk_api_result.get("nUpserted") + return cast(int, self.__bulk_api_result.get("nUpserted")) @property - def upserted_ids(self): + def upserted_ids(self) -> Optional[Dict[int, Any]]: """A map of operation index to the _id of the upserted document.""" self._raise_if_unacknowledged("upserted_ids") if self.__bulk_api_result: return dict((upsert["index"], upsert["_id"]) for upsert in self.bulk_api_result["upserted"]) + return None diff --git a/pymongo/saslprep.py b/pymongo/saslprep.py index 08a780c05..99445b06f 100644 --- a/pymongo/saslprep.py +++ b/pymongo/saslprep.py @@ -13,13 +13,13 @@ # limitations under the License. """An implementation of RFC4013 SASLprep.""" - +from typing import Any, Optional try: import stringprep except ImportError: HAVE_STRINGPREP = False - def saslprep(data): + def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) -> str: """SASLprep dummy""" if isinstance(data, str): raise TypeError( @@ -29,6 +29,7 @@ except ImportError: else: HAVE_STRINGPREP = True import unicodedata + # RFC4013 section 2.3 prohibited output. _PROHIBITED = ( # A strict reading of RFC 4013 requires table c12 here, but @@ -44,7 +45,7 @@ else: stringprep.in_table_c8, stringprep.in_table_c9) - def saslprep(data, prohibit_unassigned_code_points=True): + def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) -> str: """An implementation of RFC4013 SASLprep. :Parameters: @@ -60,6 +61,8 @@ else: :Returns: The SASLprep'ed version of `data`. """ + prohibited: Any + if not isinstance(data, str): return data diff --git a/pymongo/server.py b/pymongo/server.py index cb9442d00..74093b05e 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -17,11 +17,10 @@ from datetime import datetime from bson import _decode_all_selective - from pymongo.errors import NotPrimaryError, OperationFailure from pymongo.helpers import _check_command_response from pymongo.message import _convert_exception, _OpMsg -from pymongo.response import Response, PinnedResponse +from pymongo.response import PinnedResponse, Response from pymongo.server_type import SERVER_TYPE _CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}} @@ -59,6 +58,8 @@ class Server(object): Reconnect with open(). """ if self._publish: + assert self._listener is not None + assert self._events is not None self._events.put((self._listener.publish_server_closed, (self._description.address, self._topology_id))) self._monitor.close() @@ -169,6 +170,8 @@ class Server(object): docs = _decode_all_selective( decrypted, operation.codec_options, user_fields) + response: Response + if client._should_pin_cursor(operation.session) or operation.exhaust: sock_info.pin_cursor() if isinstance(reply, _OpMsg): diff --git a/pymongo/server_description.py b/pymongo/server_description.py index 2cbf6d63c..0a9b79916 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -15,10 +15,13 @@ """Represent one server the driver is connected to.""" import time +from typing import Any, Dict, Mapping, Optional, Set, Tuple, cast from bson import EPOCH_NAIVE -from pymongo.server_type import SERVER_TYPE +from bson.objectid import ObjectId from pymongo.hello import Hello +from pymongo.server_type import SERVER_TYPE +from pymongo.typings import _Address class ServerDescription(object): @@ -41,11 +44,12 @@ class ServerDescription(object): '_topology_version') def __init__( - self, - address, - hello=None, - round_trip_time=None, - error=None): + self, + address: _Address, + hello: Optional[Hello] = None, + round_trip_time: Optional[float] = None, + error: Optional[Exception] = None, + ) -> None: self._address = address if not hello: hello = Hello({}) @@ -72,9 +76,11 @@ class ServerDescription(object): self._error = error self._topology_version = hello.topology_version if error: - if hasattr(error, 'details') and isinstance(error.details, dict): - self._topology_version = error.details.get('topologyVersion') + details = getattr(error, 'details', None) + if isinstance(details, dict): + self._topology_version = details.get('topologyVersion') + self._last_write_date: Optional[float] if hello.last_write_date: # Convert from datetime to seconds. delta = hello.last_write_date - EPOCH_NAIVE @@ -83,17 +89,17 @@ class ServerDescription(object): self._last_write_date = None @property - def address(self): + def address(self) -> _Address: """The address (host, port) of this server.""" return self._address @property - def server_type(self): + def server_type(self) -> int: """The type of this server.""" return self._server_type @property - def server_type_name(self): + def server_type_name(self) -> str: """The server type as a human readable string. .. versionadded:: 3.4 @@ -101,78 +107,78 @@ class ServerDescription(object): return SERVER_TYPE._fields[self._server_type] @property - def all_hosts(self): + def all_hosts(self) -> Set[Tuple[str, int]]: """List of hosts, passives, and arbiters known to this server.""" return self._all_hosts @property - def tags(self): + def tags(self) -> Mapping[str, Any]: return self._tags @property - def replica_set_name(self): + def replica_set_name(self) -> Optional[str]: """Replica set name or None.""" return self._replica_set_name @property - def primary(self): + def primary(self) -> Optional[Tuple[str, int]]: """This server's opinion about who the primary is, or None.""" return self._primary @property - def max_bson_size(self): + def max_bson_size(self) -> int: return self._max_bson_size @property - def max_message_size(self): + def max_message_size(self) -> int: return self._max_message_size @property - def max_write_batch_size(self): + def max_write_batch_size(self) -> int: return self._max_write_batch_size @property - def min_wire_version(self): + def min_wire_version(self) -> int: return self._min_wire_version @property - def max_wire_version(self): + def max_wire_version(self) -> int: return self._max_wire_version @property - def set_version(self): + def set_version(self) -> Optional[int]: return self._set_version @property - def election_id(self): + def election_id(self) -> Optional[ObjectId]: return self._election_id @property - def cluster_time(self): + def cluster_time(self)-> Optional[Mapping[str, Any]]: return self._cluster_time @property - def election_tuple(self): + def election_tuple(self) -> Tuple[Optional[int], Optional[ObjectId]]: return self._set_version, self._election_id @property - def me(self): + def me(self) -> Optional[Tuple[str, int]]: return self._me @property - def logical_session_timeout_minutes(self): + def logical_session_timeout_minutes(self) -> Optional[int]: return self._ls_timeout_minutes @property - def last_write_date(self): + def last_write_date(self) -> Optional[float]: return self._last_write_date @property - def last_update_time(self): + def last_update_time(self) -> float: return self._last_update_time @property - def round_trip_time(self): + def round_trip_time(self) -> Optional[float]: """The current average latency or None.""" # This override is for unittesting only! if self._address in self._host_to_round_trip_time: @@ -181,28 +187,28 @@ class ServerDescription(object): return self._round_trip_time @property - def error(self): + def error(self) -> Optional[Exception]: """The last error attempting to connect to the server, or None.""" return self._error @property - def is_writable(self): + def is_writable(self) -> bool: return self._is_writable @property - def is_readable(self): + def is_readable(self) -> bool: return self._is_readable @property - def mongos(self): + def mongos(self) -> bool: return self._server_type == SERVER_TYPE.Mongos @property - def is_server_type_known(self): + def is_server_type_known(self) -> bool: return self.server_type != SERVER_TYPE.Unknown @property - def retryable_writes_supported(self): + def retryable_writes_supported(self) -> bool: """Checks if this server supports retryable writes.""" return (( self._ls_timeout_minutes is not None and @@ -210,20 +216,20 @@ class ServerDescription(object): or self._server_type == SERVER_TYPE.LoadBalancer) @property - def retryable_reads_supported(self): + def retryable_reads_supported(self) -> bool: """Checks if this server supports retryable writes.""" return self._max_wire_version >= 6 @property - def topology_version(self): + def topology_version(self) -> Optional[Mapping[str, Any]]: return self._topology_version - def to_unknown(self, error=None): + def to_unknown(self, error: Optional[Exception] = None) -> "ServerDescription": unknown = ServerDescription(self.address, error=error) unknown._topology_version = self.topology_version return unknown - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, ServerDescription): return ((self._address == other.address) and (self._server_type == other.server_type) and @@ -242,7 +248,7 @@ class ServerDescription(object): return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other def __repr__(self): @@ -254,4 +260,4 @@ class ServerDescription(object): self.round_trip_time, errmsg) # For unittesting only. Use under no circumstances! - _host_to_round_trip_time = {} + _host_to_round_trip_time: Dict = {} diff --git a/pymongo/server_type.py b/pymongo/server_type.py index 101f9dba4..ee53b6b97 100644 --- a/pymongo/server_type.py +++ b/pymongo/server_type.py @@ -14,10 +14,19 @@ """Type codes for MongoDB servers.""" -from collections import namedtuple +from typing import NamedTuple -SERVER_TYPE = namedtuple('ServerType', - ['Unknown', 'Mongos', 'RSPrimary', 'RSSecondary', - 'RSArbiter', 'RSOther', 'RSGhost', - 'Standalone', 'LoadBalancer'])(*range(9)) +class _ServerType(NamedTuple): + Unknown: int + Mongos: int + RSPrimary: int + RSSecondary: int + RSArbiter: int + RSOther: int + RSGhost: int + Standalone: int + LoadBalancer: int + + +SERVER_TYPE = _ServerType(*range(9)) diff --git a/pymongo/socket_checker.py b/pymongo/socket_checker.py index 48f168be4..9eb3d5f08 100644 --- a/pymongo/socket_checker.py +++ b/pymongo/socket_checker.py @@ -16,7 +16,9 @@ import errno import select +import socket import sys +from typing import Any, Optional # PYTHON-2320: Jython does not fully support poll on SSL sockets, # https://bugs.jython.org/issue2900 @@ -34,17 +36,19 @@ def _errno_from_exception(exc): class SocketChecker(object): - def __init__(self): + def __init__(self) -> None: + self._poller: Optional[select.poll] if _HAVE_POLL: self._poller = select.poll() else: self._poller = None - def select(self, sock, read=False, write=False, timeout=0): + def select(self, sock: Any, read: bool = False, write: bool = False, timeout: int = 0) -> bool: """Select for reads or writes with a timeout in seconds (or None). Returns True if the socket is readable/writable, False on timeout. """ + res: Any while True: try: if self._poller: @@ -74,12 +78,12 @@ class SocketChecker(object): # ready: subsets of the first three arguments. Return # True if any of the lists are not empty. return any(res) - except (_SelectError, IOError) as exc: + except (_SelectError, IOError) as exc: # type: ignore if _errno_from_exception(exc) in (errno.EINTR, errno.EAGAIN): continue raise - def socket_closed(self, sock): + def socket_closed(self, sock: Any) -> bool: """Return True if we know socket has been closed, False otherwise. """ try: diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py index 69e075aec..d9ee7b7c8 100644 --- a/pymongo/srv_resolver.py +++ b/pymongo/srv_resolver.py @@ -26,6 +26,7 @@ except ImportError: from pymongo.common import CONNECT_TIMEOUT from pymongo.errors import ConfigurationError + # dnspython can return bytes or str from various parts # of its API depending on version. We always want str. def maybe_decode(text): @@ -38,7 +39,7 @@ def maybe_decode(text): def _resolve(*args, **kwargs): if hasattr(resolver, 'resolve'): # dnspython >= 2 - return resolver.resolve(*args, **kwargs) + return resolver.resolve(*args, **kwargs) # type: ignore # dnspython 1.X return resolver.query(*args, **kwargs) diff --git a/pymongo/ssl_context.py b/pymongo/ssl_context.py index 2f35676f8..e54610514 100644 --- a/pymongo/ssl_context.py +++ b/pymongo/ssl_context.py @@ -32,6 +32,7 @@ IS_PYOPENSSL = False SSLError = _ssl.SSLError from ssl import SSLContext + if hasattr(_ssl, "VERIFY_CRL_CHECK_LEAF"): from ssl import VERIFY_CRL_CHECK_LEAF # Python 3.7 uses OpenSSL's hostname matching implementation diff --git a/pymongo/ssl_support.py b/pymongo/ssl_support.py index 5826f9580..b3428197b 100644 --- a/pymongo/ssl_support.py +++ b/pymongo/ssl_support.py @@ -24,7 +24,7 @@ try: import pymongo.pyopenssl_context as _ssl except ImportError: try: - import pymongo.ssl_context as _ssl + import pymongo.ssl_context as _ssl # type: ignore[no-redef] except ImportError: HAVE_SSL = False @@ -74,7 +74,7 @@ if HAVE_SSL: raise ConfigurationError( "tlsCRLFile cannot be used with PyOpenSSL") # Match the server's behavior. - ctx.verify_flags = getattr(_ssl, "VERIFY_CRL_CHECK_LEAF", 0) + setattr(ctx, 'verify_flags', getattr(_ssl, "VERIFY_CRL_CHECK_LEAF", 0)) ctx.load_verify_locations(crlfile) if ca_certs is not None: ctx.load_verify_locations(ca_certs) @@ -83,11 +83,11 @@ if HAVE_SSL: ctx.verify_mode = verify_mode return ctx else: - class SSLError(Exception): + class SSLError(Exception): # type: ignore pass HAS_SNI = False IPADDR_SAFE = False - def get_ssl_context(*dummy): + def get_ssl_context(*dummy): # type: ignore """No ssl module, raise ConfigurationError.""" raise ConfigurationError("The ssl module is not available.") diff --git a/pymongo/topology.py b/pymongo/topology.py index 021a1dee6..b2d31ed31 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -21,35 +21,27 @@ import threading import time import warnings import weakref +from typing import Any -from pymongo import (common, - helpers, - periodic_executor) +from pymongo import common, helpers, periodic_executor from pymongo.client_session import _ServerSessionPool -from pymongo.errors import (ConnectionFailure, - ConfigurationError, - NetworkTimeout, - NotPrimaryError, - OperationFailure, - PyMongoError, - ServerSelectionTimeoutError, - WriteError, - InvalidOperation) +from pymongo.errors import (ConfigurationError, ConnectionFailure, + InvalidOperation, NetworkTimeout, NotPrimaryError, + OperationFailure, PyMongoError, + ServerSelectionTimeoutError, WriteError) from pymongo.hello import Hello from pymongo.monitor import SrvMonitor from pymongo.pool import PoolOptions from pymongo.server import Server from pymongo.server_description import ServerDescription -from pymongo.server_selectors import (any_server_selector, +from pymongo.server_selectors import (Selection, any_server_selector, arbiter_server_selector, - secondary_server_selector, readable_server_selector, - writable_server_selector, - Selection) -from pymongo.topology_description import (updated_topology_description, - _updated_topology_description_srv_polling, - TopologyDescription, - SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE) + secondary_server_selector, + writable_server_selector) +from pymongo.topology_description import ( + SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE, TopologyDescription, + _updated_topology_description_srv_polling, updated_topology_description) def process_events_queue(queue_ref): @@ -80,12 +72,13 @@ class Topology(object): # Create events queue if there are publishers. self._events = None - self.__events_executor = None + self.__events_executor: Any = None if self._publish_server or self._publish_tp: self._events = queue.Queue(maxsize=100) if self._publish_tp: + assert self._events is not None self._events.put((self._listeners.publish_topology_opened, (self._topology_id,))) self._settings = topology_settings @@ -99,6 +92,7 @@ class Topology(object): self._description = topology_description if self._publish_tp: + assert self._events is not None initial_td = TopologyDescription(TOPOLOGY_TYPE.Unknown, {}, None, None, None, self._settings) self._events.put(( @@ -107,6 +101,7 @@ class Topology(object): for seed in topology_settings.seeds: if self._publish_server: + assert self._events is not None self._events.put((self._listeners.publish_server_opened, (seed, self._topology_id))) @@ -296,6 +291,7 @@ class Topology(object): suppress_event = ((self._publish_server or self._publish_tp) and sd_old == server_description) if self._publish_server and not suppress_event: + assert self._events is not None self._events.put(( self._listeners.publish_server_description_changed, (sd_old, server_description, @@ -306,6 +302,7 @@ class Topology(object): self._receive_cluster_time_no_lock(server_description.cluster_time) if self._publish_tp and not suppress_event: + assert self._events is not None self._events.put(( self._listeners.publish_topology_description_changed, (td_old, self._description, self._topology_id))) @@ -354,6 +351,7 @@ class Topology(object): self._update_servers() if self._publish_tp: + assert self._events is not None self._events.put(( self._listeners.publish_topology_description_changed, (td_old, self._description, self._topology_id))) @@ -485,6 +483,7 @@ class Topology(object): # Publish only after releasing the lock. if self._publish_tp: + assert self._events is not None self._events.put((self._listeners.publish_topology_closed, (self._topology_id,))) if self._publish_server or self._publish_tp: diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py index c13d00a64..241ef5afb 100644 --- a/pymongo/topology_description.py +++ b/pymongo/topology_description.py @@ -14,34 +14,48 @@ """Represent a deployment of MongoDB servers.""" -from collections import namedtuple from random import sample +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple +from bson.objectid import ObjectId from pymongo import common from pymongo.errors import ConfigurationError -from pymongo.read_preferences import ReadPreference, _AggWritePref +from pymongo.read_preferences import ReadPreference, _AggWritePref, _ServerMode from pymongo.server_description import ServerDescription from pymongo.server_selectors import Selection from pymongo.server_type import SERVER_TYPE +from pymongo.typings import _Address # Enumeration for various kinds of MongoDB cluster topologies. -TOPOLOGY_TYPE = namedtuple('TopologyType', [ - 'Single', 'ReplicaSetNoPrimary', 'ReplicaSetWithPrimary', 'Sharded', - 'Unknown', 'LoadBalanced'])(*range(6)) +class _TopologyType(NamedTuple): + Single: int + ReplicaSetNoPrimary: int + ReplicaSetWithPrimary: int + Sharded: int + Unknown: int + LoadBalanced: int + + +TOPOLOGY_TYPE = _TopologyType(*range(6)) # Topologies compatible with SRV record polling. -SRV_POLLING_TOPOLOGIES = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded) +SRV_POLLING_TOPOLOGIES: Tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded) + + +_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]] class TopologyDescription(object): - def __init__(self, - topology_type, - server_descriptions, - replica_set_name, - max_set_version, - max_election_id, - topology_settings): + def __init__( + self, + topology_type: int, + server_descriptions: Dict[_Address, ServerDescription], + replica_set_name: Optional[str], + max_set_version: Optional[int], + max_election_id: Optional[ObjectId], + topology_settings: Any, + ) -> None: """Representation of a deployment of MongoDB servers. :Parameters: @@ -81,7 +95,7 @@ class TopologyDescription(object): for s in readable_servers): self._ls_timeout_minutes = None else: - self._ls_timeout_minutes = min(s.logical_session_timeout_minutes + self._ls_timeout_minutes = min(s.logical_session_timeout_minutes # type: ignore for s in readable_servers) def _init_incompatible_err(self): @@ -104,23 +118,23 @@ class TopologyDescription(object): if server_too_new: self._incompatible_err = ( - "Server at %s:%d requires wire version %d, but this " + "Server at %s:%d requires wire version %d, but this " # type: ignore "version of PyMongo only supports up to %d." - % (s.address[0], s.address[1], + % (s.address[0], s.address[1] or 0, s.min_wire_version, common.MAX_SUPPORTED_WIRE_VERSION)) elif server_too_old: self._incompatible_err = ( - "Server at %s:%d reports wire version %d, but this " + "Server at %s:%d reports wire version %d, but this " # type: ignore "version of PyMongo requires at least %d (MongoDB %s)." - % (s.address[0], s.address[1], + % (s.address[0], s.address[1] or 0, s.max_wire_version, common.MIN_SUPPORTED_WIRE_VERSION, common.MIN_SUPPORTED_SERVER_VERSION)) break - def check_compatible(self): + def check_compatible(self) -> None: """Raise ConfigurationError if any server is incompatible. A server is incompatible if its wire protocol version range does not @@ -129,15 +143,15 @@ class TopologyDescription(object): if self._incompatible_err: raise ConfigurationError(self._incompatible_err) - def has_server(self, address): + def has_server(self, address: _Address) -> bool: return address in self._server_descriptions - def reset_server(self, address): + def reset_server(self, address: _Address) -> "TopologyDescription": """A copy of this description, with one server marked Unknown.""" unknown_sd = self._server_descriptions[address].to_unknown() return updated_topology_description(self, unknown_sd) - def reset(self): + def reset(self) -> "TopologyDescription": """A copy of this description, with all servers marked Unknown.""" if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary @@ -156,18 +170,18 @@ class TopologyDescription(object): self._max_election_id, self._topology_settings) - def server_descriptions(self): + def server_descriptions(self) -> Dict[_Address, ServerDescription]: """Dict of (address, :class:`~pymongo.server_description.ServerDescription`).""" return self._server_descriptions.copy() @property - def topology_type(self): + def topology_type(self) -> int: """The type of this topology.""" return self._topology_type @property - def topology_type_name(self): + def topology_type_name(self) -> str: """The topology type as a human readable string. .. versionadded:: 3.4 @@ -175,44 +189,44 @@ class TopologyDescription(object): return TOPOLOGY_TYPE._fields[self._topology_type] @property - def replica_set_name(self): + def replica_set_name(self) -> Optional[str]: """The replica set name.""" return self._replica_set_name @property - def max_set_version(self): + def max_set_version(self) -> Optional[int]: """Greatest setVersion seen from a primary, or None.""" return self._max_set_version @property - def max_election_id(self): + def max_election_id(self) -> Optional[ObjectId]: """Greatest electionId seen from a primary, or None.""" return self._max_election_id @property - def logical_session_timeout_minutes(self): + def logical_session_timeout_minutes(self) -> Optional[int]: """Minimum logical session timeout, or None.""" return self._ls_timeout_minutes @property - def known_servers(self): + def known_servers(self) -> List[ServerDescription]: """List of Servers of types besides Unknown.""" return [s for s in self._server_descriptions.values() if s.is_server_type_known] @property - def has_known_servers(self): + def has_known_servers(self) -> bool: """Whether there are any Servers of types besides Unknown.""" return any(s for s in self._server_descriptions.values() if s.is_server_type_known) @property - def readable_servers(self): + def readable_servers(self) -> List[ServerDescription]: """List of readable Servers.""" return [s for s in self._server_descriptions.values() if s.is_readable] @property - def common_wire_version(self): + def common_wire_version(self) -> Optional[int]: """Minimum of all servers' max wire versions, or None.""" servers = self.known_servers if servers: @@ -221,11 +235,11 @@ class TopologyDescription(object): return None @property - def heartbeat_frequency(self): + def heartbeat_frequency(self) -> int: return self._topology_settings.heartbeat_frequency @property - def srv_max_hosts(self): + def srv_max_hosts(self) -> int: return self._topology_settings._srv_max_hosts def _apply_local_threshold(self, selection): @@ -238,7 +252,12 @@ class TopologyDescription(object): return [s for s in selection.server_descriptions if (s.round_trip_time - fastest) <= threshold] - def apply_selector(self, selector, address=None, custom_selector=None): + def apply_selector( + self, + selector: Any, + address: Optional[_Address] = None, + custom_selector: Optional[_ServerSelector] = None + ) -> List[ServerDescription]: """List of servers matching the provided selector(s). :Parameters: @@ -288,7 +307,7 @@ class TopologyDescription(object): custom_selector(selection.server_descriptions)) return self._apply_local_threshold(selection) - def has_readable_server(self, read_preference=ReadPreference.PRIMARY): + def has_readable_server(self, read_preference: _ServerMode =ReadPreference.PRIMARY) -> bool: """Does this topology have any readable servers available matching the given read preference? @@ -305,7 +324,7 @@ class TopologyDescription(object): common.validate_read_preference("read_preference", read_preference) return any(self.apply_selector(read_preference)) - def has_writable_server(self): + def has_writable_server(self) -> bool: """Does this topology have a writable server available? .. note:: When connected directly to a single server this method @@ -336,7 +355,9 @@ _SERVER_TYPE_TO_TOPOLOGY_TYPE = { } -def updated_topology_description(topology_description, server_description): +def updated_topology_description( + topology_description: TopologyDescription, server_description: ServerDescription +) -> "TopologyDescription": """Return an updated copy of a TopologyDescription. :Parameters: diff --git a/pymongo/typings.py b/pymongo/typings.py new file mode 100644 index 000000000..ae5aec321 --- /dev/null +++ b/pymongo/typings.py @@ -0,0 +1,29 @@ +# Copyright 2022-Present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Type aliases used by PyMongo""" +from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional, + Tuple, Type, TypeVar, Union) + +if TYPE_CHECKING: + from bson.raw_bson import RawBSONDocument + from pymongo.collation import Collation + + +# Common Shared Types. +_Address = Tuple[str, Optional[int]] +_CollationIn = Union[Mapping[str, Any], "Collation"] +_DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"] +_Pipeline = List[Mapping[str, Any]] +_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any]) diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index 8c43d5177..c213f4217 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -15,20 +15,19 @@ """Tools to parse and validate a MongoDB URI.""" import re -import warnings import sys - +import warnings +from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, + Union, cast) from urllib.parse import unquote_plus from pymongo.client_options import _parse_ssl_options -from pymongo.common import ( - SRV_SERVICE_NAME, - get_validated_options, INTERNAL_URI_OPTION_NAME_MAP, - URI_OPTIONS_DEPRECATION_MAP, _CaseInsensitiveDictionary) +from pymongo.common import (INTERNAL_URI_OPTION_NAME_MAP, SRV_SERVICE_NAME, + URI_OPTIONS_DEPRECATION_MAP, + _CaseInsensitiveDictionary, get_validated_options) from pymongo.errors import ConfigurationError, InvalidURI from pymongo.srv_resolver import _HAVE_DNSPYTHON, _SrvResolver - SCHEME = 'mongodb://' SCHEME_LEN = len(SCHEME) SRV_SCHEME = 'mongodb+srv://' @@ -52,7 +51,7 @@ def _unquoted_percent(s): return True return False -def parse_userinfo(userinfo): +def parse_userinfo(userinfo: str) -> Tuple[str, str]: """Validates the format of user information in a MongoDB URI. Reserved characters that are gen-delimiters (":", "/", "?", "#", "[", "]", "@") as per RFC 3986 must be escaped. @@ -76,7 +75,7 @@ def parse_userinfo(userinfo): return unquote_plus(user), unquote_plus(passwd) -def parse_ipv6_literal_host(entity, default_port): +def parse_ipv6_literal_host(entity: str, default_port: Optional[int]) -> Tuple[str, Optional[Union[str, int]]]: """Validates an IPv6 literal host:port string. Returns a 2-tuple of IPv6 literal followed by port where @@ -98,7 +97,7 @@ def parse_ipv6_literal_host(entity, default_port): return entity[1: i], entity[i + 2:] -def parse_host(entity, default_port=DEFAULT_PORT): +def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> Tuple[str, Optional[int]]: """Validates a host string Returns a 2-tuple of host followed by port where port is default_port @@ -111,7 +110,7 @@ def parse_host(entity, default_port=DEFAULT_PORT): specified in entity. """ host = entity - port = default_port + port: Optional[Union[str, int]] = default_port if entity[0] == '[': host, port = parse_ipv6_literal_host(entity, default_port) elif entity.endswith(".sock"): @@ -279,7 +278,7 @@ def _normalize_options(options): return options -def validate_options(opts, warn=False): +def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]: """Validates and normalizes options passed in a MongoDB URI. Returns a new dictionary of validated and normalized options. If warn is @@ -295,7 +294,7 @@ def validate_options(opts, warn=False): return get_validated_options(opts, warn) -def split_options(opts, validate=True, warn=False, normalize=True): +def split_options(opts: str, validate: bool = True, warn: bool = False, normalize: bool = True) -> MutableMapping[str, Any]: """Takes the options portion of a MongoDB URI, validates each option and returns the options in a dictionary. @@ -340,7 +339,7 @@ def split_options(opts, validate=True, warn=False, normalize=True): return options -def split_hosts(hosts, default_port=DEFAULT_PORT): +def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> List[Tuple[str, Optional[int]]]: """Takes a string of the form host1[:port],host2[:port]... and splits it into (host, port) tuples. If [:port] isn't present the default_port is used. @@ -393,9 +392,16 @@ def _check_options(nodes, options): 'Cannot specify replicaSet with loadBalanced=true') -def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, - normalize=True, connect_timeout=None, srv_service_name=None, - srv_max_hosts=None): +def parse_uri( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, + srv_max_hosts: Optional[int] = None +) -> Dict[str, Any]: """Parse and validate a MongoDB URI. Returns a dict of the form:: diff --git a/pymongo/write_concern.py b/pymongo/write_concern.py index 2075240f0..5168948ee 100644 --- a/pymongo/write_concern.py +++ b/pymongo/write_concern.py @@ -14,6 +14,8 @@ """Tools for working with write concerns.""" +from typing import Any, Dict, Optional, Union + from pymongo.errors import ConfigurationError @@ -45,8 +47,8 @@ class WriteConcern(object): __slots__ = ("__document", "__acknowledged", "__server_default") - def __init__(self, w=None, wtimeout=None, j=None, fsync=None): - self.__document = {} + def __init__(self, w: Optional[Union[int, str]] = None, wtimeout: Optional[int] = None, j: Optional[bool] = None, fsync: Optional[bool] = None) -> None: + self.__document: Dict[str, Any] = {} self.__acknowledged = True if wtimeout is not None: @@ -84,12 +86,12 @@ class WriteConcern(object): self.__server_default = not self.__document @property - def is_server_default(self): + def is_server_default(self) -> bool: """Does this WriteConcern match the server default.""" return self.__server_default @property - def document(self): + def document(self) -> Dict[str, Any]: """The document representation of this write concern. .. note:: @@ -99,7 +101,7 @@ class WriteConcern(object): return self.__document.copy() @property - def acknowledged(self): + def acknowledged(self) -> bool: """If ``True`` write operations will wait for acknowledgement before returning. """ @@ -109,12 +111,12 @@ class WriteConcern(object): return ("WriteConcern(%s)" % ( ", ".join("%s=%s" % kvt for kvt in self.__document.items()),)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, WriteConcern): return self.__document == other.document return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: if isinstance(other, WriteConcern): return self.__document != other.document return NotImplemented diff --git a/test/performance/perf_test.py b/test/performance/perf_test.py index b752453f1..84c6baf60 100644 --- a/test/performance/perf_test.py +++ b/test/performance/perf_test.py @@ -25,7 +25,7 @@ import warnings try: import simplejson as json except ImportError: - import json # type: ignore + import json # type: ignore[no-redef] sys.path[0:0] = [""] diff --git a/test/test_cursor.py b/test/test_cursor.py index 8bea12228..0b8ba049c 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -889,7 +889,7 @@ class TestCursor(IntegrationTest): # Every attribute should be the same. cursor2 = cursor.clone() - self.assertEqual(cursor.__dict__, cursor2.__dict__) + self.assertDictEqual(cursor.__dict__, cursor2.__dict__) # Shallow copies can so can mutate cursor2 = copy.copy(cursor) diff --git a/test/test_grid_file.py b/test/test_grid_file.py index a53e40c4c..6d7cc7ba3 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -238,7 +238,7 @@ class TestGridFile(IntegrationTest): cursor_dict.pop('_Cursor__session') cursor_clone_dict = cursor_clone.__dict__.copy() cursor_clone_dict.pop('_Cursor__session') - self.assertEqual(cursor_dict, cursor_clone_dict) + self.assertDictEqual(cursor_dict, cursor_clone_dict) self.assertRaises(NotImplementedError, cursor.add_option, 0) self.assertRaises(NotImplementedError, cursor.remove_option, 0) diff --git a/tools/clean.py b/tools/clean.py index 55896781a..53729d640 100644 --- a/tools/clean.py +++ b/tools/clean.py @@ -33,7 +33,7 @@ except: pass try: - from pymongo import _cmessage # type: ignore + from pymongo import _cmessage # type: ignore[attr-defined] sys.exit("could still import _cmessage") except ImportError: pass