PYTHON-3060 Add typings to pymongo package (#831)

This commit is contained in:
Steven Silvester 2022-02-02 21:12:36 -06:00 committed by GitHub
parent abfa0d35bc
commit dd6c140d43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
57 changed files with 1578 additions and 1099 deletions

View File

@ -61,10 +61,10 @@ import re
import struct
import sys
import uuid
from codecs import utf_8_decode as _utf_8_decode # type: ignore
from codecs import utf_8_encode as _utf_8_encode # type: ignore
from codecs import utf_8_decode as _utf_8_decode # type: ignore[attr-defined]
from codecs import utf_8_encode as _utf_8_encode # type: ignore[attr-defined]
from collections import abc as _abc
from typing import (TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Generator,
from typing import (IO, TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Generator,
Iterator, List, Mapping, MutableMapping, NoReturn,
Sequence, Tuple, Type, TypeVar, Union, cast)
@ -88,11 +88,13 @@ from bson.tz_util import utc
# Import RawBSONDocument for type-checking only to avoid circular dependency.
if TYPE_CHECKING:
from array import array
from mmap import mmap
from bson.raw_bson import RawBSONDocument
try:
from bson import _cbson # type: ignore
from bson import _cbson # type: ignore[attr-defined]
_USE_C = True
except ImportError:
_USE_C = False
@ -851,6 +853,7 @@ _CODEC_OPTIONS_TYPE_ERROR = TypeError(
_DocumentIn = Mapping[str, Any]
_DocumentOut = Union[MutableMapping[str, Any], "RawBSONDocument"]
_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"]
def encode(document: _DocumentIn, check_keys: bool = False, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> bytes:
@ -880,7 +883,7 @@ def encode(document: _DocumentIn, check_keys: bool = False, codec_options: Codec
return _dict_to_bson(document, check_keys, codec_options)
def decode(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> _DocumentOut:
def decode(data: _ReadableBuffer, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Dict[str, Any]:
"""Decode BSON to a document.
By default, returns a BSON document represented as a Python
@ -912,7 +915,7 @@ def decode(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) ->
return _bson_to_dict(data, codec_options)
def decode_all(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> List[_DocumentOut]:
def decode_all(data: _ReadableBuffer, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> List[Dict[str, Any]]:
"""Decode BSON data to multiple documents.
`data` must be a bytes-like object implementing the buffer protocol that
@ -1075,7 +1078,7 @@ def decode_iter(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS
yield _bson_to_dict(elements, codec_options)
def decode_file_iter(file_obj: BinaryIO, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Iterator[_DocumentOut]:
def decode_file_iter(file_obj: Union[BinaryIO, IO], codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Iterator[_DocumentOut]:
"""Decode bson data from a file to multiple documents as a generator.
Works similarly to the decode_all function, but reads from the file object
@ -1158,7 +1161,7 @@ class BSON(bytes):
"""
return cls(encode(document, check_keys, codec_options))
def decode(self, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> _DocumentOut: # type: ignore[override]
def decode(self, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Dict[str, Any]: # type: ignore[override]
"""Decode this BSON data.
By default, returns a BSON document represented as a Python

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Tuple, Type
from typing import Any, Tuple, Type, Union, TYPE_CHECKING
from uuid import UUID
"""Tools for representing BSON binary data.
@ -57,6 +57,11 @@ by :mod:`bson` using this subtype when using
"""
if TYPE_CHECKING:
from array import array as _array
from mmap import mmap as _mmap
class UuidRepresentation:
UNSPECIFIED = 0
"""An unspecified UUID representation.
@ -211,7 +216,7 @@ class Binary(bytes):
_type_marker = 5
__subtype: int
def __new__(cls: Type["Binary"], data: bytes, subtype: int = BINARY_SUBTYPE) -> "Binary":
def __new__(cls: Type["Binary"], data: Union[memoryview, bytes, "_mmap", "_array"], subtype: int = BINARY_SUBTYPE) -> "Binary":
if not isinstance(subtype, int):
raise TypeError("subtype must be an instance of int")
if subtype >= 256 or subtype < 0:

View File

@ -874,10 +874,10 @@ class GridOutCursor(Cursor):
__next__ = next
def add_option(self, *args: Any, **kwargs: Any) -> None:
def add_option(self, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
raise NotImplementedError("Method does not exist for GridOutCursor")
def remove_option(self, *args: Any, **kwargs: Any) -> None:
def remove_option(self, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
raise NotImplementedError("Method does not exist for GridOutCursor")
def _clone_base(self, session: ClientSession) -> "GridOutCursor":

View File

@ -1,11 +1,33 @@
[mypy]
check_untyped_defs = true
disallow_subclassing_any = true
disallow_incomplete_defs = true
no_implicit_optional = true
pretty = true
show_error_context = true
show_error_codes = true
strict_equality = true
warn_unused_configs = true
warn_unused_ignores = true
warn_redundant_casts = true
[mypy-kerberos.*]
ignore_missing_imports = True
[mypy-mockupdb]
ignore_missing_imports = True
[mypy-pymongo_auth_aws.*]
ignore_missing_imports = True
[mypy-pymongocrypt.*]
ignore_missing_imports = True
[mypy-service_identity.*]
ignore_missing_imports = True
[mypy-snappy.*]
ignore_missing_imports = True
[mypy-winkerberos.*]
ignore_missing_imports = True

View File

@ -14,6 +14,8 @@
"""Python driver for MongoDB."""
from typing import Tuple, Union
ASCENDING = 1
"""Ascending sort order."""
DESCENDING = -1
@ -53,35 +55,33 @@ TEXT = "text"
.. _text index: http://docs.mongodb.org/manual/core/index-text/
"""
version_tuple = (4, 1, 0, '.dev0')
version_tuple: Tuple[Union[int, str], ...] = (4, 1, 0, '.dev0')
def get_version_string():
def get_version_string() -> str:
if isinstance(version_tuple[-1], str):
return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1]
return '.'.join(map(str, version_tuple))
__version__ = version = get_version_string()
__version__: str = get_version_string()
version = __version__
"""Current version of PyMongo."""
from pymongo.collection import ReturnDocument
from pymongo.common import (MIN_SUPPORTED_WIRE_VERSION,
MAX_SUPPORTED_WIRE_VERSION)
from pymongo.common import (MAX_SUPPORTED_WIRE_VERSION,
MIN_SUPPORTED_WIRE_VERSION)
from pymongo.cursor import CursorType
from pymongo.mongo_client import MongoClient
from pymongo.operations import (IndexModel,
InsertOne,
DeleteOne,
DeleteMany,
UpdateOne,
UpdateMany,
ReplaceOne)
from pymongo.operations import (DeleteMany, DeleteOne, IndexModel, InsertOne,
ReplaceOne, UpdateMany, UpdateOne)
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
def has_c():
def has_c() -> bool:
"""Is the C extension installed?"""
try:
from pymongo import _cmessage
from pymongo import _cmessage # type: ignore[attr-defined]
return True
except ImportError:
return False

View File

@ -15,11 +15,10 @@
"""Perform aggregation operations on a collection or database."""
from bson.son import SON
from pymongo import common
from pymongo.collation import validate_collation_or_none
from pymongo.errors import ConfigurationError
from pymongo.read_preferences import _AggWritePref, ReadPreference
from pymongo.read_preferences import ReadPreference, _AggWritePref
class _AggregationCommand(object):
@ -37,7 +36,7 @@ class _AggregationCommand(object):
self._target = target
common.validate_list('pipeline', pipeline)
pipeline = common.validate_list('pipeline', pipeline)
self._pipeline = pipeline
self._performs_write = False
if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
@ -82,7 +81,6 @@ class _AggregationCommand(object):
"""The namespace in which the aggregate command is run."""
raise NotImplementedError
@property
def _cursor_collection(self, cursor_doc):
"""The Collection used for the aggregate command cursor."""
raise NotImplementedError

View File

@ -19,9 +19,9 @@ import hashlib
import hmac
import os
import socket
from base64 import standard_b64decode, standard_b64encode
from collections import namedtuple
from typing import Callable, Mapping
from urllib.parse import quote
from bson.binary import Binary
@ -97,7 +97,7 @@ GSSAPIProperties = namedtuple('GSSAPIProperties',
"""Mechanism properties for GSSAPI authentication."""
_AWSProperties = namedtuple('AWSProperties', ['aws_session_token'])
_AWSProperties = namedtuple('_AWSProperties', ['aws_session_token'])
"""Mechanism properties for MONGODB-AWS authentication."""
@ -140,9 +140,9 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database):
properties = extra.get('authmechanismproperties', {})
aws_session_token = properties.get('AWS_SESSION_TOKEN')
props = _AWSProperties(aws_session_token=aws_session_token)
aws_props = _AWSProperties(aws_session_token=aws_session_token)
# user can be None for temporary link-local EC2 credentials.
return MongoCredential(mech, '$external', user, passwd, props, None)
return MongoCredential(mech, '$external', user, passwd, aws_props, None)
elif mech == 'PLAIN':
source_database = source or database or '$external'
return MongoCredential(mech, source_database, user, passwd, None, None)
@ -471,7 +471,7 @@ def _authenticate_default(credentials, sock_info):
return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-1')
_AUTH_MAP = {
_AUTH_MAP: Mapping[str, Callable] = {
'GSSAPI': _authenticate_gssapi,
'MONGODB-CR': _authenticate_mongo_cr,
'MONGODB-X509': _authenticate_x509,
@ -532,7 +532,7 @@ class _X509Context(_AuthContext):
return cmd
_SPECULATIVE_AUTH_MAP = {
_SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = {
'MONGODB-X509': _X509Context,
'SCRAM-SHA-1': functools.partial(_ScramContext, mechanism='SCRAM-SHA-1'),
'SCRAM-SHA-256': functools.partial(_ScramContext,
@ -544,6 +544,6 @@ _SPECULATIVE_AUTH_MAP = {
def authenticate(credentials, sock_info):
"""Authenticate sock_info."""
mechanism = credentials.mechanism
auth_func = _AUTH_MAP.get(mechanism)
auth_func = _AUTH_MAP[mechanism]
auth_func(credentials, sock_info)

View File

@ -16,12 +16,11 @@
try:
import pymongo_auth_aws
from pymongo_auth_aws import (AwsCredential,
AwsSaslContext,
from pymongo_auth_aws import (AwsCredential, AwsSaslContext,
PyMongoAuthAwsError)
_HAVE_MONGODB_AWS = True
except ImportError:
class AwsSaslContext(object):
class AwsSaslContext(object): # type: ignore
def __init__(self, credentials):
pass
_HAVE_MONGODB_AWS = False
@ -32,7 +31,7 @@ from bson.son import SON
from pymongo.errors import ConfigurationError, OperationFailure
class _AwsSaslContext(AwsSaslContext):
class _AwsSaslContext(AwsSaslContext): # type: ignore
# Dependency injection:
def binary_type(self):
"""Return the bson.binary.Binary type."""

View File

@ -17,31 +17,23 @@
.. versionadded:: 2.7
"""
import copy
from itertools import islice
from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from pymongo.client_session import _validate_session_write_concern
from pymongo.common import (validate_is_mapping,
validate_is_document_type,
validate_ok_for_replace,
validate_ok_for_update)
from pymongo.helpers import _RETRYABLE_ERROR_CODES, _get_wce_doc
from pymongo.collation import validate_collation_or_none
from pymongo.errors import (BulkWriteError,
ConfigurationError,
InvalidOperation,
OperationFailure)
from pymongo.message import (_INSERT, _UPDATE, _DELETE,
_randint,
_BulkWriteContext,
_EncryptedBulkWriteContext)
from pymongo.common import (validate_is_document_type, validate_is_mapping,
validate_ok_for_replace, validate_ok_for_update)
from pymongo.errors import (BulkWriteError, ConfigurationError,
InvalidOperation, OperationFailure)
from pymongo.helpers import _RETRYABLE_ERROR_CODES, _get_wce_doc
from pymongo.message import (_DELETE, _INSERT, _UPDATE, _BulkWriteContext,
_EncryptedBulkWriteContext, _randint)
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
_DELETE_ALL = 0
_DELETE_ONE = 1

View File

@ -15,21 +15,20 @@
"""Watch changes on a collection, a database, or the entire cluster."""
import copy
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterator, Mapping,
Optional, Union)
from bson import _bson_to_dict
from bson.raw_bson import RawBSONDocument
from bson.timestamp import Timestamp
from pymongo import common
from pymongo.aggregation import (_CollectionAggregationCommand,
_DatabaseAggregationCommand)
from pymongo.collation import validate_collation_or_none
from pymongo.command_cursor import CommandCursor
from pymongo.errors import (ConnectionFailure,
CursorNotFound,
InvalidOperation,
OperationFailure,
PyMongoError)
from pymongo.errors import (ConnectionFailure, CursorNotFound,
InvalidOperation, OperationFailure, PyMongoError)
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
# The change streams spec considers the following server errors from the
# getMore command non-resumable. All other getMore errors are resumable.
@ -55,7 +54,14 @@ _RESUMABLE_GETMORE_ERRORS = frozenset([
])
class ChangeStream(object):
if TYPE_CHECKING:
from pymongo.client_session import ClientSession
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.mongo_client import MongoClient
class ChangeStream(Generic[_DocumentType]):
"""The internal abstract base class for change stream cursors.
Should not be called directly by application developers. Use
@ -66,14 +72,22 @@ class ChangeStream(object):
.. versionadded:: 3.6
.. seealso:: The MongoDB documentation on `changeStreams <https://dochub.mongodb.org/core/changeStreams>`_.
"""
def __init__(self, target, pipeline, full_document, resume_after,
max_await_time_ms, batch_size, collation,
start_at_operation_time, session, start_after):
def __init__(
self,
target: Union["MongoClient[_DocumentType]", "Database[_DocumentType]", "Collection[_DocumentType]"],
pipeline: Optional[_Pipeline],
full_document: Optional[str],
resume_after: Optional[Mapping[str, Any]],
max_await_time_ms: Optional[int],
batch_size: Optional[int],
collation: Optional[_CollationIn],
start_at_operation_time: Optional[Timestamp],
session: Optional["ClientSession"],
start_after: Optional[Mapping[str, Any]],
) -> None:
if pipeline is None:
pipeline = []
elif not isinstance(pipeline, list):
raise TypeError("pipeline must be a list")
pipeline = common.validate_list('pipeline', pipeline)
common.validate_string_or_none('full_document', full_document)
validate_collation_or_none(collation)
common.validate_non_negative_integer_or_none("batchSize", batch_size)
@ -84,7 +98,7 @@ class ChangeStream(object):
self._decode_custom = True
# Keep the type registry so that we support encoding custom types
# in the pipeline.
self._target = target.with_options(
self._target = target.with_options( # type: ignore
codec_options=target.codec_options.with_options(
document_class=RawBSONDocument))
else:
@ -117,7 +131,7 @@ class ChangeStream(object):
def _change_stream_options(self):
"""Return the options dict for the $changeStream pipeline stage."""
options = {}
options: Dict[str, Any] = {}
if self._full_document is not None:
options['fullDocument'] = self._full_document
@ -144,7 +158,7 @@ class ChangeStream(object):
def _aggregation_pipeline(self):
"""Return the full aggregation pipeline for this ChangeStream."""
options = self._change_stream_options()
full_pipeline = [{'$changeStream': options}]
full_pipeline: list = [{'$changeStream': options}]
full_pipeline.extend(self._pipeline)
return full_pipeline
@ -197,15 +211,15 @@ class ChangeStream(object):
pass
self._cursor = self._create_cursor()
def close(self):
def close(self) -> None:
"""Close this ChangeStream."""
self._cursor.close()
def __iter__(self):
def __iter__(self) -> "ChangeStream[_DocumentType]":
return self
@property
def resume_token(self):
def resume_token(self) -> Optional[Mapping[str, Any]]:
"""The cached resume token that will be used to resume after the most
recently returned change.
@ -213,7 +227,7 @@ class ChangeStream(object):
"""
return copy.deepcopy(self._resume_token)
def next(self):
def next(self) -> _DocumentType:
"""Advance the cursor.
This method blocks until the next change document is returned or an
@ -255,7 +269,7 @@ class ChangeStream(object):
__next__ = next
@property
def alive(self):
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
.. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
@ -265,7 +279,7 @@ class ChangeStream(object):
"""
return self._cursor.alive
def try_next(self):
def try_next(self) -> Optional[_DocumentType]:
"""Advance the cursor without blocking indefinitely.
This method returns the next change document without waiting
@ -354,14 +368,14 @@ class ChangeStream(object):
return _bson_to_dict(change.raw, self._orig_codec_options)
return change
def __enter__(self):
def __enter__(self) -> "ChangeStream":
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
class CollectionChangeStream(ChangeStream):
class CollectionChangeStream(ChangeStream, Generic[_DocumentType]):
"""A change stream that watches changes on a single collection.
Should not be called directly by application developers. Use
@ -378,7 +392,7 @@ class CollectionChangeStream(ChangeStream):
return self._target.database.client
class DatabaseChangeStream(ChangeStream):
class DatabaseChangeStream(ChangeStream, Generic[_DocumentType]):
"""A change stream that watches changes on all collections in a database.
Should not be called directly by application developers. Use
@ -395,7 +409,7 @@ class DatabaseChangeStream(ChangeStream):
return self._target.client
class ClusterChangeStream(DatabaseChangeStream):
class ClusterChangeStream(DatabaseChangeStream, Generic[_DocumentType]):
"""A change stream that watches changes on all collections in the cluster.
Should not be called directly by application developers. Use

View File

@ -15,9 +15,9 @@
"""Tools to parse mongo client options."""
from bson.codec_options import _parse_codec_options
from pymongo import common
from pymongo.auth import _build_credentials_tuple
from pymongo.common import validate_boolean
from pymongo import common
from pymongo.compression_support import CompressionSettings
from pymongo.errors import ConfigurationError
from pymongo.monitoring import _EventListeners

View File

@ -134,25 +134,23 @@ Classes
import collections
import time
import uuid
from collections.abc import Mapping as _Mapping
from typing import (TYPE_CHECKING, Any, Callable, ContextManager, Generic,
Mapping, Optional, TypeVar)
from bson.binary import Binary
from bson.int64 import Int64
from bson.son import SON
from bson.timestamp import Timestamp
from pymongo.cursor import _SocketManager
from pymongo.errors import (ConfigurationError,
ConnectionFailure,
InvalidOperation,
OperationFailure,
PyMongoError,
from pymongo.errors import (ConfigurationError, ConnectionFailure,
InvalidOperation, OperationFailure, PyMongoError,
WTimeoutError)
from pymongo.helpers import _RETRYABLE_ERROR_CODES
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.server_type import SERVER_TYPE
from pymongo.typings import _DocumentType
from pymongo.write_concern import WriteConcern
@ -172,10 +170,12 @@ class SessionOptions(object):
.. versionchanged:: 3.12
Added the ``snapshot`` parameter.
"""
def __init__(self,
causal_consistency=None,
default_transaction_options=None,
snapshot=False):
def __init__(
self,
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional["TransactionOptions"] = None,
snapshot: Optional[bool] = False,
) -> None:
if snapshot:
if causal_consistency:
raise ConfigurationError('snapshot reads do not support '
@ -194,12 +194,12 @@ class SessionOptions(object):
self._snapshot = snapshot
@property
def causal_consistency(self):
def causal_consistency(self) -> bool:
"""Whether causal consistency is configured."""
return self._causal_consistency
@property
def default_transaction_options(self):
def default_transaction_options(self) -> Optional["TransactionOptions"]:
"""The default TransactionOptions to use for transactions started on
this session.
@ -208,7 +208,7 @@ class SessionOptions(object):
return self._default_transaction_options
@property
def snapshot(self):
def snapshot(self) -> Optional[bool]:
"""Whether snapshot reads are configured.
.. versionadded:: 3.12
@ -243,8 +243,13 @@ class TransactionOptions(object):
.. versionadded:: 3.7
"""
def __init__(self, read_concern=None, write_concern=None,
read_preference=None, max_commit_time_ms=None):
def __init__(
self,
read_concern: Optional[ReadConcern] = None,
write_concern: Optional[WriteConcern] = None,
read_preference: Optional[_ServerMode] = None,
max_commit_time_ms: Optional[int] = None
) -> None:
self._read_concern = read_concern
self._write_concern = write_concern
self._read_preference = read_preference
@ -274,23 +279,23 @@ class TransactionOptions(object):
"max_commit_time_ms must be an integer or None")
@property
def read_concern(self):
def read_concern(self) -> Optional[ReadConcern]:
"""This transaction's :class:`~pymongo.read_concern.ReadConcern`."""
return self._read_concern
@property
def write_concern(self):
def write_concern(self) -> Optional[WriteConcern]:
"""This transaction's :class:`~pymongo.write_concern.WriteConcern`."""
return self._write_concern
@property
def read_preference(self):
def read_preference(self) -> Optional[_ServerMode]:
"""This transaction's :class:`~pymongo.read_preferences.ReadPreference`.
"""
return self._read_preference
@property
def max_commit_time_ms(self):
def max_commit_time_ms(self) -> Optional[int]:
"""The maxTimeMS to use when running a commitTransaction command.
.. versionadded:: 3.9
@ -427,7 +432,13 @@ def _within_time_limit(start_time):
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
class ClientSession(object):
_T = TypeVar("_T")
if TYPE_CHECKING:
from pymongo.mongo_client import MongoClient
class ClientSession(Generic[_DocumentType]):
"""A session for ordering sequential operations.
:class:`ClientSession` instances are **not thread-safe or fork-safe**.
@ -439,9 +450,11 @@ class ClientSession(object):
:class:`ClientSession`, call
:meth:`~pymongo.mongo_client.MongoClient.start_session`.
"""
def __init__(self, client, server_session, options, implicit):
def __init__(
self, client: "MongoClient[_DocumentType]", server_session: Any, options: SessionOptions, implicit: bool
) -> None:
# A MongoClient, a _ServerSession, a SessionOptions, and a set.
self._client = client
self._client: MongoClient[_DocumentType] = client
self._server_session = server_session
self._options = options
self._cluster_time = None
@ -451,7 +464,7 @@ class ClientSession(object):
self._implicit = implicit
self._transaction = _Transaction(None, client)
def end_session(self):
def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
It is an error to use the session after the session has ended.
@ -474,39 +487,39 @@ class ClientSession(object):
if self._server_session is None:
raise InvalidOperation("Cannot use ended session")
def __enter__(self):
def __enter__(self) -> "ClientSession[_DocumentType]":
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self._end_session(lock=True)
@property
def client(self):
def client(self) -> "MongoClient[_DocumentType]":
"""The :class:`~pymongo.mongo_client.MongoClient` this session was
created from.
"""
return self._client
@property
def options(self):
def options(self) -> SessionOptions:
"""The :class:`SessionOptions` this session was created with."""
return self._options
@property
def session_id(self):
def session_id(self) -> Mapping[str, Any]:
"""A BSON document, the opaque server session identifier."""
self._check_ended()
return self._server_session.session_id
@property
def cluster_time(self):
def cluster_time(self) -> Optional[Mapping[str, Any]]:
"""The cluster time returned by the last operation executed
in this session.
"""
return self._cluster_time
@property
def operation_time(self):
def operation_time(self) -> Optional[Timestamp]:
"""The operation time returned by the last operation executed
in this session.
"""
@ -522,8 +535,14 @@ class ClientSession(object):
return val
return getattr(self.client, name)
def with_transaction(self, callback, read_concern=None, write_concern=None,
read_preference=None, max_commit_time_ms=None):
def with_transaction(
self,
callback: Callable[["ClientSession"], _T],
read_concern: Optional[ReadConcern] = None,
write_concern: Optional[WriteConcern] = None,
read_preference: Optional[_ServerMode] = None,
max_commit_time_ms: Optional[int] = None,
) -> _T:
"""Execute a callback in a transaction.
This method starts a transaction on this session, executes ``callback``
@ -649,8 +668,13 @@ class ClientSession(object):
# Commit succeeded.
return ret
def start_transaction(self, read_concern=None, write_concern=None,
read_preference=None, max_commit_time_ms=None):
def start_transaction(
self,
read_concern: Optional[ReadConcern] = None,
write_concern: Optional[WriteConcern] = None,
read_preference: Optional[_ServerMode] = None,
max_commit_time_ms: Optional[int] = None,
) -> ContextManager:
"""Start a multi-statement transaction.
Takes the same arguments as :class:`TransactionOptions`.
@ -685,7 +709,7 @@ class ClientSession(object):
self._start_retryable_write()
return _TransactionContext(self)
def commit_transaction(self):
def commit_transaction(self) -> None:
"""Commit a multi-statement transaction.
.. versionadded:: 3.7
@ -729,7 +753,7 @@ class ClientSession(object):
finally:
self._transaction.state = _TxnState.COMMITTED
def abort_transaction(self):
def abort_transaction(self) -> None:
"""Abort a multi-statement transaction.
.. versionadded:: 3.7
@ -804,7 +828,7 @@ class ClientSession(object):
if cluster_time["clusterTime"] > self._cluster_time["clusterTime"]:
self._cluster_time = cluster_time
def advance_cluster_time(self, cluster_time):
def advance_cluster_time(self, cluster_time: Mapping[str, Any]) -> None:
"""Update the cluster time for this session.
:Parameters:
@ -827,7 +851,7 @@ class ClientSession(object):
if operation_time > self._operation_time:
self._operation_time = operation_time
def advance_operation_time(self, operation_time):
def advance_operation_time(self, operation_time: Timestamp) -> None:
"""Update the operation time for this session.
:Parameters:
@ -856,12 +880,12 @@ class ClientSession(object):
self._transaction.recovery_token = recovery_token
@property
def has_ended(self):
def has_ended(self) -> bool:
"""True if this session is finished."""
return self._server_session is None
@property
def in_transaction(self):
def in_transaction(self) -> bool:
"""True if this session has an active multi-statement transaction.
.. versionadded:: 3.10

View File

@ -16,6 +16,7 @@
.. _collations: http://userguide.icu-project.org/collation/concepts
"""
from typing import Any, Dict, Mapping, Optional, Union
from pymongo import common
@ -151,18 +152,18 @@ class Collation(object):
__slots__ = ("__document",)
def __init__(self, locale,
caseLevel=None,
caseFirst=None,
strength=None,
numericOrdering=None,
alternate=None,
maxVariable=None,
normalization=None,
backwards=None,
**kwargs):
def __init__(self, locale: str,
caseLevel: Optional[bool] = None,
caseFirst: Optional[str] = None,
strength: Optional[int] = None,
numericOrdering: Optional[bool] = None,
alternate: Optional[str] = None,
maxVariable: Optional[str] = None,
normalization: Optional[bool] = None,
backwards: Optional[bool] = None,
**kwargs: Any) -> None:
locale = common.validate_string('locale', locale)
self.__document = {'locale': locale}
self.__document: Dict[str, Any] = {'locale': locale}
if caseLevel is not None:
self.__document['caseLevel'] = common.validate_boolean(
'caseLevel', caseLevel)
@ -190,7 +191,7 @@ class Collation(object):
self.__document.update(kwargs)
@property
def document(self):
def document(self) -> Dict[str, Any]:
"""The document representation of this collation.
.. note::
@ -204,16 +205,16 @@ class Collation(object):
return 'Collation(%s)' % (
', '.join('%s=%r' % (key, document[key]) for key in document),)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, Collation):
return self.document == other.document
return NotImplemented
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other
def validate_collation_or_none(value):
def validate_collation_or_none(value: Optional[Union[Mapping[str, Any], Collation]]) -> Optional[Dict[str, Any]]:
if value is None:
return None
if isinstance(value, Collation):

View File

@ -14,44 +14,45 @@
"""Collection level utilities for Mongo."""
import datetime
import warnings
from collections import abc
from typing import (TYPE_CHECKING, Any, Generic, Iterable, List, Mapping,
MutableMapping, Optional, Sequence, Tuple, Union)
from bson.code import Code
from bson.codec_options import CodecOptions
from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from bson.codec_options import CodecOptions
from bson.son import SON
from pymongo import (common,
helpers,
message)
from bson.timestamp import Timestamp
from pymongo import common, helpers, message
from pymongo.aggregation import (_CollectionAggregationCommand,
_CollectionRawAggregationCommand)
from pymongo.bulk import _Bulk
from pymongo.command_cursor import CommandCursor, RawBatchCommandCursor
from pymongo.collation import validate_collation_or_none
from pymongo.change_stream import CollectionChangeStream
from pymongo.collation import validate_collation_or_none
from pymongo.command_cursor import CommandCursor, RawBatchCommandCursor
from pymongo.cursor import Cursor, RawBatchCursor
from pymongo.errors import (ConfigurationError,
InvalidName,
InvalidOperation,
from pymongo.errors import (ConfigurationError, InvalidName, InvalidOperation,
OperationFailure)
from pymongo.helpers import _check_write_command_response
from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS
from pymongo.operations import IndexModel
from pymongo.read_preferences import ReadPreference
from pymongo.results import (BulkWriteResult,
DeleteResult,
InsertOneResult,
InsertManyResult,
UpdateResult)
from pymongo.operations import (DeleteMany, DeleteOne, IndexModel, InsertOne,
ReplaceOne, UpdateMany, UpdateOne)
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import (BulkWriteResult, DeleteResult, InsertManyResult,
InsertOneResult, UpdateResult)
from pymongo.typings import _CollationIn, _DocumentIn, _DocumentType, _Pipeline
from pymongo.write_concern import WriteConcern
_FIND_AND_MODIFY_DOC_FIELDS = {'value': 1}
_WriteOp = Union[InsertOne, DeleteOne, DeleteMany, ReplaceOne, UpdateOne, UpdateMany]
# Hint supports index name, "myIndex", or list of index pairs: [('x', 1), ('y', -1)]
_IndexList = Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]]
_IndexKeyHint = Union[str, _IndexList]
class ReturnDocument(object):
"""An enum used with
:meth:`~pymongo.collection.Collection.find_one_and_replace` and
@ -65,13 +66,28 @@ class ReturnDocument(object):
"""Return the updated/replaced or inserted document."""
class Collection(common.BaseObject):
if TYPE_CHECKING:
from pymongo.client_session import ClientSession
from pymongo.database import Database
from pymongo.read_concern import ReadConcern
class Collection(common.BaseObject, Generic[_DocumentType]):
"""A Mongo collection.
"""
def __init__(self, database, name, create=False, codec_options=None,
read_preference=None, write_concern=None, read_concern=None,
session=None, **kwargs):
def __init__(
self,
database: "Database[_DocumentType]",
name: str,
create: Optional[bool] = False,
codec_options: Optional[CodecOptions] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional["ReadConcern"] = None,
session: Optional["ClientSession"] = None,
**kwargs: Any,
) -> None:
"""Get / create a Mongo collection.
Raises :class:`TypeError` if `name` is not an instance of
@ -169,7 +185,7 @@ class Collection(common.BaseObject):
"null character")
collation = validate_collation_or_none(kwargs.pop('collation', None))
self.__database = database
self.__database: Database[_DocumentType] = database
self.__name = name
self.__full_name = "%s.%s" % (self.__database.name, self.__name)
if create or kwargs or collation:
@ -252,7 +268,7 @@ class Collection(common.BaseObject):
write_concern=self._write_concern_for(session),
collation=collation, session=session)
def __getattr__(self, name):
def __getattr__(self, name: str) -> "Collection[_DocumentType]":
"""Get a sub-collection of this collection by name.
Raises InvalidName if an invalid collection name is used.
@ -268,7 +284,7 @@ class Collection(common.BaseObject):
name, full_name, full_name))
return self.__getitem__(name)
def __getitem__(self, name):
def __getitem__(self, name: str) -> "Collection[_DocumentType]":
return Collection(self.__database,
"%s.%s" % (self.__name, name),
False,
@ -280,25 +296,25 @@ class Collection(common.BaseObject):
def __repr__(self):
return "Collection(%r, %r)" % (self.__database, self.__name)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, Collection):
return (self.__database == other.database and
self.__name == other.name)
return NotImplemented
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other
def __hash__(self):
def __hash__(self) -> int:
return hash((self.__database, self.__name))
def __bool__(self):
def __bool__(self) -> bool:
raise NotImplementedError("Collection objects do not implement truth "
"value testing or bool(). Please compare "
"with None instead: collection is not None")
@property
def full_name(self):
def full_name(self) -> str:
"""The full name of this :class:`Collection`.
The full name is of the form `database_name.collection_name`.
@ -306,19 +322,24 @@ class Collection(common.BaseObject):
return self.__full_name
@property
def name(self):
def name(self) -> str:
"""The name of this :class:`Collection`."""
return self.__name
@property
def database(self):
def database(self) -> "Database[_DocumentType]":
"""The :class:`~pymongo.database.Database` that this
:class:`Collection` is a part of.
"""
return self.__database
def with_options(self, codec_options=None, read_preference=None,
write_concern=None, read_concern=None):
def with_options(
self,
codec_options: Optional[CodecOptions] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional["ReadConcern"] = None,
) -> "Collection[_DocumentType]":
"""Get a clone of this collection changing the specified settings.
>>> coll1.read_preference
@ -356,8 +377,13 @@ class Collection(common.BaseObject):
write_concern or self.write_concern,
read_concern or self.read_concern)
def bulk_write(self, requests, ordered=True,
bypass_document_validation=False, session=None):
def bulk_write(
self,
requests: Sequence[_WriteOp],
ordered: bool = True,
bypass_document_validation: bool = False,
session: Optional["ClientSession"] = None
) -> BulkWriteResult:
"""Send a batch of write operations to the server.
Requests are passed as a list of write operation instances (
@ -470,8 +496,10 @@ class Collection(common.BaseObject):
if not isinstance(doc, RawBSONDocument):
return doc.get('_id')
def insert_one(self, document, bypass_document_validation=False,
session=None):
def insert_one(self, document: _DocumentIn,
bypass_document_validation: bool = False,
session: Optional["ClientSession"] = None
) -> InsertOneResult:
"""Insert a single document.
>>> db.test.count_documents({'x': 1})
@ -520,8 +548,12 @@ class Collection(common.BaseObject):
bypass_doc_val=bypass_document_validation, session=session),
write_concern.acknowledged)
def insert_many(self, documents, ordered=True,
bypass_document_validation=False, session=None):
def insert_many(self,
documents: Iterable[_DocumentIn],
ordered: bool = True,
bypass_document_validation: bool = False,
session: Optional["ClientSession"] = None
) -> InsertManyResult:
"""Insert an iterable of documents.
>>> db.test.count_documents({})
@ -565,7 +597,7 @@ class Collection(common.BaseObject):
or isinstance(documents, abc.Mapping)
or not documents):
raise TypeError("documents must be a non-empty list")
inserted_ids = []
inserted_ids: List[ObjectId] = []
def gen():
"""A generator that validates documents and handles _ids."""
for document in documents:
@ -671,9 +703,16 @@ class Collection(common.BaseObject):
(write_concern or self.write_concern).acknowledged and not multi,
_update, session)
def replace_one(self, filter, replacement, upsert=False,
bypass_document_validation=False, collation=None,
hint=None, session=None, let=None):
def replace_one(self,
filter: Mapping[str, Any],
replacement: Mapping[str, Any],
upsert: bool = False,
bypass_document_validation: bool = False,
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional["ClientSession"] = None,
let: Optional[Mapping[str, Any]] = None
) -> UpdateResult:
"""Replace a single document matching the filter.
>>> for doc in db.test.find({}):
@ -755,10 +794,17 @@ class Collection(common.BaseObject):
collation=collation, hint=hint, session=session, let=let),
write_concern.acknowledged)
def update_one(self, filter, update, upsert=False,
bypass_document_validation=False,
collation=None, array_filters=None, hint=None,
session=None, let=None):
def update_one(self,
filter: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
bypass_document_validation: bool = False,
collation: Optional[_CollationIn] = None,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional["ClientSession"] = None,
let: Optional[Mapping[str, Any]] = None
) -> UpdateResult:
"""Update a single document matching the filter.
>>> for doc in db.test.find():
@ -800,8 +846,8 @@ class Collection(common.BaseObject):
- `session` (optional): a
:class:`~pymongo.client_session.ClientSession`.
- `let` (optional): Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
aggregate expression context (e.g. "$$var").
:Returns:
@ -836,9 +882,17 @@ class Collection(common.BaseObject):
hint=hint, session=session, let=let),
write_concern.acknowledged)
def update_many(self, filter, update, upsert=False, array_filters=None,
bypass_document_validation=False, collation=None,
hint=None, session=None, let=None):
def update_many(self,
filter: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
bypass_document_validation: Optional[bool] = None,
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional["ClientSession"] = None,
let: Optional[Mapping[str, Any]] = None
) -> UpdateResult:
"""Update one or more documents that match the filter.
>>> for doc in db.test.find():
@ -880,8 +934,8 @@ class Collection(common.BaseObject):
- `session` (optional): a
:class:`~pymongo.client_session.ClientSession`.
- `let` (optional): Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
aggregate expression context (e.g. "$$var").
:Returns:
@ -916,7 +970,7 @@ class Collection(common.BaseObject):
hint=hint, session=session, let=let),
write_concern.acknowledged)
def drop(self, session=None):
def drop(self, session: Optional["ClientSession"] = None) -> None:
"""Alias for :meth:`~pymongo.database.Database.drop_collection`.
:Parameters:
@ -1005,8 +1059,13 @@ class Collection(common.BaseObject):
(write_concern or self.write_concern).acknowledged and not multi,
_delete, session)
def delete_one(self, filter, collation=None, hint=None, session=None,
let=None):
def delete_one(self,
filter: Mapping[str, Any],
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional["ClientSession"] = None,
let: Optional[Mapping[str, Any]] = None
) -> DeleteResult:
"""Delete a single document matching the filter.
>>> db.test.count_documents({'x': 1})
@ -1030,8 +1089,8 @@ class Collection(common.BaseObject):
- `session` (optional): a
:class:`~pymongo.client_session.ClientSession`.
- `let` (optional): Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
aggregate expression context (e.g. "$$var").
:Returns:
@ -1055,8 +1114,13 @@ class Collection(common.BaseObject):
collation=collation, hint=hint, session=session, let=let),
write_concern.acknowledged)
def delete_many(self, filter, collation=None, hint=None, session=None,
let=None):
def delete_many(self,
filter: Mapping[str, Any],
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional["ClientSession"] = None,
let: Optional[Mapping[str, Any]] = None
) -> DeleteResult:
"""Delete one or more documents matching the filter.
>>> db.test.count_documents({'x': 1})
@ -1080,8 +1144,8 @@ class Collection(common.BaseObject):
- `session` (optional): a
:class:`~pymongo.client_session.ClientSession`.
- `let` (optional): Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
aggregate expression context (e.g. "$$var").
:Returns:
@ -1105,7 +1169,7 @@ class Collection(common.BaseObject):
collation=collation, hint=hint, session=session, let=let),
write_concern.acknowledged)
def find_one(self, filter=None, *args, **kwargs):
def find_one(self, filter: Optional[Any] = None, *args: Any, **kwargs: Any) -> Optional[_DocumentType]:
"""Get a single document from the database.
All arguments to :meth:`find` are also valid arguments for
@ -1139,7 +1203,7 @@ class Collection(common.BaseObject):
return result
return None
def find(self, *args, **kwargs):
def find(self, *args: Any, **kwargs: Any) -> Cursor[_DocumentType]:
"""Query the database.
The `filter` argument is a prototype document that all results
@ -1328,7 +1392,7 @@ class Collection(common.BaseObject):
"""
return Cursor(self, *args, **kwargs)
def find_raw_batches(self, *args, **kwargs):
def find_raw_batches(self, *args: Any, **kwargs: Any) -> RawBatchCursor[_DocumentType]:
"""Query the database and retrieve batches of raw BSON.
Similar to the :meth:`find` method but returns a
@ -1396,7 +1460,7 @@ class Collection(common.BaseObject):
batch = result['cursor']['firstBatch']
return batch[0] if batch else None
def estimated_document_count(self, **kwargs):
def estimated_document_count(self, **kwargs: Any) -> int:
"""Get an estimate of the number of documents in this collection using
collection metadata.
@ -1445,7 +1509,7 @@ class Collection(common.BaseObject):
return self.__database.client._retryable_read(
_cmd, self.read_preference, None)
def count_documents(self, filter, session=None, **kwargs):
def count_documents(self, filter: Mapping[str, Any], session: Optional["ClientSession"] = None, **kwargs: Any) -> int:
"""Count the number of documents in this collection.
.. note:: For a fast count of the total documents in a collection see
@ -1523,7 +1587,7 @@ class Collection(common.BaseObject):
return self.__database.client._retryable_read(
_cmd, self._read_preference_for(session), session)
def create_indexes(self, indexes, session=None, **kwargs):
def create_indexes(self, indexes: Sequence[IndexModel], session: Optional["ClientSession"] = None, **kwargs: Any) -> List[str]:
"""Create one or more indexes on this collection.
>>> from pymongo import IndexModel, ASCENDING, DESCENDING
@ -1598,7 +1662,7 @@ class Collection(common.BaseObject):
session=session)
return names
def create_index(self, keys, session=None, **kwargs):
def create_index(self, keys: _IndexKeyHint, session: Optional["ClientSession"] = None, **kwargs: Any) -> str:
"""Creates an index on this collection.
Takes either a single key or a list of (key, direction) pairs.
@ -1701,7 +1765,7 @@ class Collection(common.BaseObject):
index = IndexModel(keys, **kwargs)
return self.__create_indexes([index], session, **cmd_options)[0]
def drop_indexes(self, session=None, **kwargs):
def drop_indexes(self, session: Optional["ClientSession"] = None, **kwargs: Any) -> None:
"""Drops all indexes on this collection.
Can be used on non-existant collections or collections with no indexes.
@ -1727,7 +1791,7 @@ class Collection(common.BaseObject):
"""
self.drop_index("*", session=session, **kwargs)
def drop_index(self, index_or_name, session=None, **kwargs):
def drop_index(self, index_or_name: _IndexKeyHint, session: Optional["ClientSession"] = None, **kwargs: Any) -> None:
"""Drops the specified index on this collection.
Can be used on non-existant collections or collections with no
@ -1780,7 +1844,7 @@ class Collection(common.BaseObject):
write_concern=self._write_concern_for(session),
session=session)
def list_indexes(self, session=None):
def list_indexes(self, session: Optional["ClientSession"] = None) -> CommandCursor[MutableMapping[str, Any]]:
"""Get a cursor over the index documents for this collection.
>>> for index in db.test.list_indexes():
@ -1829,7 +1893,7 @@ class Collection(common.BaseObject):
return self.__database.client._retryable_read(
_cmd, read_pref, session)
def index_information(self, session=None):
def index_information(self, session: Optional["ClientSession"] = None) -> MutableMapping[str, Any]:
"""Get information on this collection's indexes.
Returns a dictionary where the keys are index names (as
@ -1863,7 +1927,7 @@ class Collection(common.BaseObject):
info[index.pop("name")] = index
return info
def options(self, session=None):
def options(self, session: Optional["ClientSession"] = None) -> MutableMapping[str, Any]:
"""Get the options set on this collection.
Returns a dictionary of options and their values - see
@ -1896,6 +1960,7 @@ class Collection(common.BaseObject):
return {}
options = result.get("options", {})
assert options is not None
if "create" in options:
del options["create"]
@ -1911,7 +1976,7 @@ class Collection(common.BaseObject):
cmd.get_cursor, cmd.get_read_preference(session), session,
retryable=not cmd._performs_write)
def aggregate(self, pipeline, session=None, let=None, **kwargs):
def aggregate(self, pipeline: _Pipeline, session: Optional["ClientSession"] = None, let: Optional[Mapping[str, Any]] = None, **kwargs: Any) -> CommandCursor[_DocumentType]:
"""Perform an aggregation using the aggregation framework on this
collection.
@ -1993,7 +2058,9 @@ class Collection(common.BaseObject):
let=let,
**kwargs)
def aggregate_raw_batches(self, pipeline, session=None, **kwargs):
def aggregate_raw_batches(
self, pipeline: _Pipeline, session: Optional["ClientSession"] = None, **kwargs: Any
) -> RawBatchCursor[_DocumentType]:
"""Perform an aggregation and retrieve batches of raw BSON.
Similar to the :meth:`aggregate` method but returns a
@ -2030,9 +2097,17 @@ class Collection(common.BaseObject):
explicit_session=session is not None,
**kwargs)
def watch(self, pipeline=None, full_document=None, resume_after=None,
max_await_time_ms=None, batch_size=None, collation=None,
start_at_operation_time=None, session=None, start_after=None):
def watch(self,
pipeline: Optional[_Pipeline] = None,
full_document: Optional[str] = None,
resume_after: Optional[Mapping[str, Any]] = None,
max_await_time_ms: Optional[int] = None,
batch_size: Optional[int] = None,
collation: Optional[_CollationIn] = None,
start_at_operation_time: Optional[Timestamp] = None,
session: Optional["ClientSession"] = None,
start_after: Optional[Mapping[str, Any]] = None,
) -> CollectionChangeStream[_DocumentType]:
"""Watch changes on this collection.
Performs an aggregation with an implicit initial ``$changeStream``
@ -2132,7 +2207,7 @@ class Collection(common.BaseObject):
batch_size, collation, start_at_operation_time, session,
start_after)
def rename(self, new_name, session=None, **kwargs):
def rename(self, new_name: str, session: Optional["ClientSession"] = None, **kwargs: Any) -> MutableMapping[str, Any]:
"""Rename this collection.
If operating in auth mode, client must be authorized as an
@ -2183,7 +2258,9 @@ class Collection(common.BaseObject):
parse_write_concern_error=True,
session=s, client=self.__database.client)
def distinct(self, key, filter=None, session=None, **kwargs):
def distinct(
self, key: str, filter: Optional[Mapping[str, Any]] = None, session: Optional["ClientSession"] = None, **kwargs: Any
) -> List:
"""Get a list of distinct values for `key` among all documents
in this collection.
@ -2283,7 +2360,7 @@ class Collection(common.BaseObject):
raise ConfigurationError(
'arrayFilters is unsupported for unacknowledged '
'writes.')
cmd["arrayFilters"] = array_filters
cmd["arrayFilters"] = list(array_filters)
if hint is not None:
if sock_info.max_wire_version < 8:
raise ConfigurationError(
@ -2307,9 +2384,15 @@ class Collection(common.BaseObject):
return self.__database.client._retryable_write(
write_concern.acknowledged, _find_and_modify, session)
def find_one_and_delete(self, filter,
projection=None, sort=None, hint=None,
session=None, let=None, **kwargs):
def find_one_and_delete(self,
filter: Mapping[str, Any],
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
sort: Optional[_IndexList] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional["ClientSession"] = None,
let: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> _DocumentType:
"""Finds a single document and deletes it, returning the document.
>>> db.test.count_documents({'x': 1})
@ -2357,8 +2440,8 @@ class Collection(common.BaseObject):
as keyword arguments (for example maxTimeMS can be used with
recent server versions).
- `let` (optional): Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
aggregate expression context (e.g. "$$var").
.. versionchanged:: 4.1
@ -2384,10 +2467,18 @@ class Collection(common.BaseObject):
return self.__find_and_modify(filter, projection, sort, let=let,
hint=hint, session=session, **kwargs)
def find_one_and_replace(self, filter, replacement,
projection=None, sort=None, upsert=False,
return_document=ReturnDocument.BEFORE,
hint=None, session=None, let=None, **kwargs):
def find_one_and_replace(self,
filter: Mapping[str, Any],
replacement: Mapping[str, Any],
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
sort: Optional[_IndexList] = None,
upsert: bool = False,
return_document: bool = ReturnDocument.BEFORE,
hint: Optional[_IndexKeyHint] = None,
session: Optional["ClientSession"] = None,
let: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> _DocumentType:
"""Finds a single document and replaces it, returning either the
original or the replaced document.
@ -2438,8 +2529,8 @@ class Collection(common.BaseObject):
- `session` (optional): a
:class:`~pymongo.client_session.ClientSession`.
- `let` (optional): Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
aggregate expression context (e.g. "$$var").
- `**kwargs` (optional): additional command arguments can be passed
as keyword arguments (for example maxTimeMS can be used with
@ -2470,11 +2561,19 @@ class Collection(common.BaseObject):
sort, upsert, return_document, let=let,
hint=hint, session=session, **kwargs)
def find_one_and_update(self, filter, update,
projection=None, sort=None, upsert=False,
return_document=ReturnDocument.BEFORE,
array_filters=None, hint=None, session=None,
let=None, **kwargs):
def find_one_and_update(self,
filter: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
sort: Optional[_IndexList] = None,
upsert: bool = False,
return_document: bool = ReturnDocument.BEFORE,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional["ClientSession"] = None,
let: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> _DocumentType:
"""Finds a single document and updates it, returning either the
original or the updated document.
@ -2564,8 +2663,8 @@ class Collection(common.BaseObject):
- `session` (optional): a
:class:`~pymongo.client_session.ClientSession`.
- `let` (optional): Map of parameter names and values. Values must be
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
aggregate expression context (e.g. "$$var").
- `**kwargs` (optional): additional command arguments can be passed
as keyword arguments (for example maxTimeMS can be used with
@ -2600,15 +2699,15 @@ class Collection(common.BaseObject):
array_filters, hint=hint, let=let,
session=session, **kwargs)
def __iter__(self):
def __iter__(self) -> "Collection[_DocumentType]":
return self
def __next__(self):
def __next__(self) -> None:
raise TypeError("'Collection' object is not iterable")
next = __next__
def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> None:
"""This is only here so that some API misusages are easier to debug.
"""
if "." not in self.__name:

View File

@ -15,28 +15,38 @@
"""CommandCursor class to iterate over command results."""
from collections import deque
from typing import (TYPE_CHECKING, Any, Generic, Iterator, Mapping, Optional,
Tuple)
from bson import _convert_raw_document_lists_to_streams
from pymongo.cursor import _SocketManager, _CURSOR_CLOSED_ERRORS
from pymongo.errors import (ConnectionFailure,
InvalidOperation,
from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _SocketManager
from pymongo.errors import (ConnectionFailure, InvalidOperation,
OperationFailure)
from pymongo.message import (_CursorAddress,
_GetMore,
_RawBatchGetMore)
from pymongo.message import _CursorAddress, _GetMore, _RawBatchGetMore
from pymongo.response import PinnedResponse
from pymongo.typings import _DocumentType
if TYPE_CHECKING:
from pymongo.client_session import ClientSession
from pymongo.collection import Collection
class CommandCursor(object):
class CommandCursor(Generic[_DocumentType]):
"""A cursor / iterator over command cursors."""
_getmore_class = _GetMore
def __init__(self, collection, cursor_info, address,
batch_size=0, max_await_time_ms=None, session=None,
explicit_session=False):
def __init__(self,
collection: "Collection[_DocumentType]",
cursor_info: Mapping[str, Any],
address: Optional[Tuple[str, Optional[int]]],
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional["ClientSession"] = None,
explicit_session: bool = False,
) -> None:
"""Create a new command cursor."""
self.__sock_mgr = None
self.__collection = collection
self.__sock_mgr: Any = None
self.__collection: Collection[_DocumentType] = collection
self.__id = cursor_info['id']
self.__data = deque(cursor_info['firstBatch'])
self.__postbatchresumetoken = cursor_info.get('postBatchResumeToken')
@ -60,7 +70,7 @@ class CommandCursor(object):
and max_await_time_ms is not None):
raise TypeError("max_await_time_ms must be an integer or None")
def __del__(self):
def __del__(self) -> None:
self.__die()
def __die(self, synchronous=False):
@ -92,12 +102,12 @@ class CommandCursor(object):
self.__session._end_session(lock=synchronous)
self.__session = None
def close(self):
def close(self) -> None:
"""Explicitly close / kill this cursor.
"""
self.__die(True)
def batch_size(self, batch_size):
def batch_size(self, batch_size: int) -> "CommandCursor[_DocumentType]":
"""Limits the number of documents returned in one batch. Each batch
requires a round trip to the server. It can be adjusted to optimize
performance and limit data transfer.
@ -222,7 +232,7 @@ class CommandCursor(object):
return len(self.__data)
@property
def alive(self):
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
Even if :attr:`alive` is ``True``, :meth:`next` can raise
@ -239,12 +249,12 @@ class CommandCursor(object):
return bool(len(self.__data) or (not self.__killed))
@property
def cursor_id(self):
def cursor_id(self) -> int:
"""Returns the id of the cursor."""
return self.__id
@property
def address(self):
def address(self) -> Optional[Tuple[str, Optional[int]]]:
"""The (host, port) of the server used, or None.
.. versionadded:: 3.0
@ -252,18 +262,19 @@ class CommandCursor(object):
return self.__address
@property
def session(self):
def session(self) -> Optional["ClientSession"]:
"""The cursor's :class:`~pymongo.client_session.ClientSession`, or None.
.. versionadded:: 3.6
"""
if self.__explicit_session:
return self.__session
return None
def __iter__(self):
def __iter__(self) -> Iterator[_DocumentType]:
return self
def next(self):
def next(self) -> _DocumentType:
"""Advance the cursor."""
# Block until a document is returnable.
while self.alive:
@ -284,19 +295,25 @@ class CommandCursor(object):
else:
return None
def __enter__(self):
def __enter__(self) -> "CommandCursor[_DocumentType]":
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
class RawBatchCommandCursor(CommandCursor):
class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]):
_getmore_class = _RawBatchGetMore
def __init__(self, collection, cursor_info, address,
batch_size=0, max_await_time_ms=None, session=None,
explicit_session=False):
def __init__(self,
collection: "Collection[_DocumentType]",
cursor_info: Mapping[str, Any],
address: Optional[Tuple[str, Optional[int]]],
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional["ClientSession"] = None,
explicit_session: bool = False,
) -> None:
"""Create a new cursor / iterator over raw batches of BSON data.
Should not be called directly by application developers -

View File

@ -17,8 +17,9 @@
import datetime
import warnings
from collections import abc, OrderedDict
from collections import OrderedDict, abc
from typing import (Any, Callable, Dict, List, Mapping, MutableMapping,
Optional, Sequence, Tuple, Type, Union, cast)
from urllib.parse import unquote_plus
from bson import SON
@ -29,18 +30,18 @@ from pymongo.auth import MECHANISMS
from pymongo.compression_support import (validate_compressors,
validate_zlib_compression_level)
from pymongo.driver_info import DriverInfo
from pymongo.server_api import ServerApi
from pymongo.errors import ConfigurationError
from pymongo.monitoring import _validate_event_listeners
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _MONGOS_MODES, _ServerMode
from pymongo.server_api import ServerApi
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
ORDERED_TYPES = (SON, OrderedDict)
ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict)
# Defaults until we connect to a server and get updated limits.
MAX_BSON_SIZE = 16 * (1024 ** 2)
MAX_MESSAGE_SIZE = 2 * MAX_BSON_SIZE
MAX_MESSAGE_SIZE: int = 2 * MAX_BSON_SIZE
MIN_WIRE_VERSION = 0
MAX_WIRE_VERSION = 0
MAX_WRITE_BATCH_SIZE = 1000
@ -85,13 +86,13 @@ MIN_POOL_SIZE = 0
MAX_CONNECTING = 2
# Default value for maxIdleTimeMS.
MAX_IDLE_TIME_MS = None
MAX_IDLE_TIME_MS: Optional[int] = None
# Default value for maxIdleTimeMS in seconds.
MAX_IDLE_TIME_SEC = None
MAX_IDLE_TIME_SEC: Optional[int] = None
# Default value for waitQueueTimeoutMS in seconds.
WAIT_QUEUE_TIMEOUT = None
WAIT_QUEUE_TIMEOUT: Optional[int] = None
# Default value for localThresholdMS.
LOCAL_THRESHOLD_MS = 15
@ -103,10 +104,10 @@ RETRY_WRITES = True
RETRY_READS = True
# The error code returned when a command doesn't exist.
COMMAND_NOT_FOUND_CODES = (59,)
COMMAND_NOT_FOUND_CODES: Sequence[int] = (59,)
# Error codes to ignore if GridFS calls createIndex on a secondary
UNAUTHORIZED_CODES = (13, 16547, 16548)
UNAUTHORIZED_CODES: Sequence[int] = (13, 16547, 16548)
# Maximum number of sessions to send in a single endSessions command.
# From the driver sessions spec.
@ -116,7 +117,7 @@ _MAX_END_SESSIONS = 10000
SRV_SERVICE_NAME = "mongodb"
def partition_node(node):
def partition_node(node: str) -> Tuple[str, int]:
"""Split a host:port string into (host, int(port)) pair."""
host = node
port = 27017
@ -128,7 +129,7 @@ def partition_node(node):
return host, port
def clean_node(node):
def clean_node(node: str) -> Tuple[str, int]:
"""Split and normalize a node name from a hello response."""
host, port = partition_node(node)
@ -139,7 +140,7 @@ def clean_node(node):
return host.lower(), port
def raise_config_error(key, dummy):
def raise_config_error(key: str, dummy: Any) -> None:
"""Raise ConfigurationError with the given key name."""
raise ConfigurationError("Unknown option %s" % (key,))
@ -154,14 +155,14 @@ _UUID_REPRESENTATIONS = {
}
def validate_boolean(option, value):
def validate_boolean(option: str, value: Any) -> bool:
"""Validates that 'value' is True or False."""
if isinstance(value, bool):
return value
raise TypeError("%s must be True or False" % (option,))
def validate_boolean_or_string(option, value):
def validate_boolean_or_string(option: str, value: Any) -> bool:
"""Validates that value is True, False, 'true', or 'false'."""
if isinstance(value, str):
if value not in ('true', 'false'):
@ -171,7 +172,7 @@ def validate_boolean_or_string(option, value):
return validate_boolean(option, value)
def validate_integer(option, value):
def validate_integer(option: str, value: Any) -> int:
"""Validates that 'value' is an integer (or basestring representation).
"""
if isinstance(value, int):
@ -185,7 +186,7 @@ def validate_integer(option, value):
raise TypeError("Wrong type for %s, value must be an integer" % (option,))
def validate_positive_integer(option, value):
def validate_positive_integer(option: str, value: Any) -> int:
"""Validate that 'value' is a positive integer, which does not include 0.
"""
val = validate_integer(option, value)
@ -195,7 +196,7 @@ def validate_positive_integer(option, value):
return val
def validate_non_negative_integer(option, value):
def validate_non_negative_integer(option: str, value: Any) -> int:
"""Validate that 'value' is a positive integer or 0.
"""
val = validate_integer(option, value)
@ -205,7 +206,7 @@ def validate_non_negative_integer(option, value):
return val
def validate_readable(option, value):
def validate_readable(option: str, value: Any) -> Optional[str]:
"""Validates that 'value' is file-like and readable.
"""
if value is None:
@ -217,7 +218,7 @@ def validate_readable(option, value):
return value
def validate_positive_integer_or_none(option, value):
def validate_positive_integer_or_none(option: str, value: Any) -> Optional[int]:
"""Validate that 'value' is a positive integer or None.
"""
if value is None:
@ -225,7 +226,7 @@ def validate_positive_integer_or_none(option, value):
return validate_positive_integer(option, value)
def validate_non_negative_integer_or_none(option, value):
def validate_non_negative_integer_or_none(option: str, value: Any) -> Optional[int]:
"""Validate that 'value' is a positive integer or 0 or None.
"""
if value is None:
@ -233,7 +234,7 @@ def validate_non_negative_integer_or_none(option, value):
return validate_non_negative_integer(option, value)
def validate_string(option, value):
def validate_string(option: str, value: Any) -> str:
"""Validates that 'value' is an instance of `str`.
"""
if isinstance(value, str):
@ -242,7 +243,7 @@ def validate_string(option, value):
"str" % (option,))
def validate_string_or_none(option, value):
def validate_string_or_none(option: str, value: Any) -> Optional[str]:
"""Validates that 'value' is an instance of `basestring` or `None`.
"""
if value is None:
@ -250,7 +251,7 @@ def validate_string_or_none(option, value):
return validate_string(option, value)
def validate_int_or_basestring(option, value):
def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]:
"""Validates that 'value' is an integer or string.
"""
if isinstance(value, int):
@ -264,7 +265,7 @@ def validate_int_or_basestring(option, value):
"integer or a string" % (option,))
def validate_non_negative_int_or_basestring(option, value):
def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]:
"""Validates that 'value' is an integer or string.
"""
if isinstance(value, int):
@ -279,7 +280,7 @@ def validate_non_negative_int_or_basestring(option, value):
"non negative integer or a string" % (option,))
def validate_positive_float(option, value):
def validate_positive_float(option: str, value: Any) -> float:
"""Validates that 'value' is a float, or can be converted to one, and is
positive.
"""
@ -299,7 +300,7 @@ def validate_positive_float(option, value):
return value
def validate_positive_float_or_zero(option, value):
def validate_positive_float_or_zero(option: str, value: Any) -> float:
"""Validates that 'value' is 0 or a positive float, or can be converted to
0 or a positive float.
"""
@ -308,7 +309,7 @@ def validate_positive_float_or_zero(option, value):
return validate_positive_float(option, value)
def validate_timeout_or_none(option, value):
def validate_timeout_or_none(option: str, value: Any) -> Optional[float]:
"""Validates a timeout specified in milliseconds returning
a value in floating point seconds.
"""
@ -317,7 +318,7 @@ def validate_timeout_or_none(option, value):
return validate_positive_float(option, value) / 1000.0
def validate_timeout_or_zero(option, value):
def validate_timeout_or_zero(option: str, value: Any) -> float:
"""Validates a timeout specified in milliseconds returning
a value in floating point seconds for the case where None is an error
and 0 is valid. Setting the timeout to nothing in the URI string is a
@ -330,7 +331,7 @@ def validate_timeout_or_zero(option, value):
return validate_positive_float(option, value) / 1000.0
def validate_timeout_or_none_or_zero(option, value):
def validate_timeout_or_none_or_zero(option: Any, value: Any) -> Optional[float]:
"""Validates a timeout specified in milliseconds returning
a value in floating point seconds. value=0 and value="0" are treated the
same as value=None which means unlimited timeout.
@ -340,7 +341,7 @@ def validate_timeout_or_none_or_zero(option, value):
return validate_positive_float(option, value) / 1000.0
def validate_max_staleness(option, value):
def validate_max_staleness(option: str, value: Any) -> int:
"""Validates maxStalenessSeconds according to the Max Staleness Spec."""
if value == -1 or value == "-1":
# Default: No maximum staleness.
@ -348,7 +349,7 @@ def validate_max_staleness(option, value):
return validate_positive_integer(option, value)
def validate_read_preference(dummy, value):
def validate_read_preference(dummy: Any, value: Any) -> _ServerMode:
"""Validate a read preference.
"""
if not isinstance(value, _ServerMode):
@ -356,7 +357,7 @@ def validate_read_preference(dummy, value):
return value
def validate_read_preference_mode(dummy, value):
def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode:
"""Validate read preference mode for a MongoClient.
.. versionchanged:: 3.5
@ -368,7 +369,7 @@ def validate_read_preference_mode(dummy, value):
return value
def validate_auth_mechanism(option, value):
def validate_auth_mechanism(option: str, value: Any) -> str:
"""Validate the authMechanism URI option.
"""
if value not in MECHANISMS:
@ -376,7 +377,7 @@ def validate_auth_mechanism(option, value):
return value
def validate_uuid_representation(dummy, value):
def validate_uuid_representation(dummy: Any, value: Any) -> int:
"""Validate the uuid representation option selected in the URI.
"""
try:
@ -387,13 +388,13 @@ def validate_uuid_representation(dummy, value):
"%s" % (value, tuple(_UUID_REPRESENTATIONS)))
def validate_read_preference_tags(name, value):
def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]]:
"""Parse readPreferenceTags if passed as a client kwarg.
"""
if not isinstance(value, list):
value = [value]
tag_sets = []
tag_sets: List = []
for tag_set in value:
if tag_set == '':
tag_sets.append({})
@ -416,10 +417,10 @@ _MECHANISM_PROPS = frozenset(['SERVICE_NAME',
'AWS_SESSION_TOKEN'])
def validate_auth_mechanism_properties(option, value):
def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Union[bool, str]]:
"""Validate authMechanismProperties."""
value = validate_string(option, value)
props = {}
props: Dict[str, Any] = {}
for opt in value.split(','):
try:
key, val = opt.split(':')
@ -443,7 +444,7 @@ def validate_auth_mechanism_properties(option, value):
return props
def validate_document_class(option, value):
def validate_document_class(option: str, value: Any) -> Union[Type[MutableMapping], Type[RawBSONDocument]]:
"""Validate the document_class option."""
if not issubclass(value, (abc.MutableMapping, RawBSONDocument)):
raise TypeError("%s must be dict, bson.son.SON, "
@ -452,7 +453,7 @@ def validate_document_class(option, value):
return value
def validate_type_registry(option, value):
def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]:
"""Validate the type_registry option."""
if value is not None and not isinstance(value, TypeRegistry):
raise TypeError("%s must be an instance of %s" % (
@ -460,21 +461,21 @@ def validate_type_registry(option, value):
return value
def validate_list(option, value):
def validate_list(option: str, value: Any) -> List:
"""Validates that 'value' is a list."""
if not isinstance(value, list):
raise TypeError("%s must be a list" % (option,))
return value
def validate_list_or_none(option, value):
def validate_list_or_none(option: Any, value: Any) -> Optional[List]:
"""Validates that 'value' is a list or None."""
if value is None:
return value
return validate_list(option, value)
def validate_list_or_mapping(option, value):
def validate_list_or_mapping(option: Any, value: Any) -> None:
"""Validates that 'value' is a list or a document."""
if not isinstance(value, (abc.Mapping, list)):
raise TypeError("%s must either be a list or an instance of dict, "
@ -482,7 +483,7 @@ def validate_list_or_mapping(option, value):
"collections.Mapping" % (option,))
def validate_is_mapping(option, value):
def validate_is_mapping(option: str, value: Any) -> None:
"""Validate the type of method arguments that expect a document."""
if not isinstance(value, abc.Mapping):
raise TypeError("%s must be an instance of dict, bson.son.SON, or "
@ -490,7 +491,7 @@ def validate_is_mapping(option, value):
"collections.Mapping" % (option,))
def validate_is_document_type(option, value):
def validate_is_document_type(option: str, value: Any) -> None:
"""Validate the type of method arguments that expect a MongoDB document."""
if not isinstance(value, (abc.MutableMapping, RawBSONDocument)):
raise TypeError("%s must be an instance of dict, bson.son.SON, "
@ -499,7 +500,7 @@ def validate_is_document_type(option, value):
"collections.MutableMapping" % (option,))
def validate_appname_or_none(option, value):
def validate_appname_or_none(option: str, value: Any) -> Optional[str]:
"""Validate the appname option."""
if value is None:
return value
@ -510,7 +511,7 @@ def validate_appname_or_none(option, value):
return value
def validate_driver_or_none(option, value):
def validate_driver_or_none(option: Any, value: Any) -> Optional[DriverInfo]:
"""Validate the driver keyword arg."""
if value is None:
return value
@ -519,7 +520,7 @@ def validate_driver_or_none(option, value):
return value
def validate_server_api_or_none(option, value):
def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]:
"""Validate the server_api keyword arg."""
if value is None:
return value
@ -528,7 +529,7 @@ def validate_server_api_or_none(option, value):
return value
def validate_is_callable_or_none(option, value):
def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]:
"""Validates that 'value' is a callable."""
if value is None:
return value
@ -537,7 +538,7 @@ def validate_is_callable_or_none(option, value):
return value
def validate_ok_for_replace(replacement):
def validate_ok_for_replace(replacement: Mapping[str, Any]) -> None:
"""Validate a replacement document."""
validate_is_mapping("replacement", replacement)
# Replacement can be {}
@ -547,7 +548,7 @@ def validate_ok_for_replace(replacement):
raise ValueError('replacement can not include $ operators')
def validate_ok_for_update(update):
def validate_ok_for_update(update: Any) -> None:
"""Validate an update document."""
validate_list_or_mapping("update", update)
# Update cannot be {}.
@ -563,7 +564,7 @@ def validate_ok_for_update(update):
_UNICODE_DECODE_ERROR_HANDLERS = frozenset(['strict', 'replace', 'ignore'])
def validate_unicode_decode_error_handler(dummy, value):
def validate_unicode_decode_error_handler(dummy: Any, value: str) -> str:
"""Validate the Unicode decode error handler option of CodecOptions.
"""
if value not in _UNICODE_DECODE_ERROR_HANDLERS:
@ -573,7 +574,7 @@ def validate_unicode_decode_error_handler(dummy, value):
return value
def validate_tzinfo(dummy, value):
def validate_tzinfo(dummy: Any, value: Any) -> Optional[datetime.tzinfo]:
"""Validate the tzinfo option
"""
if value is not None and not isinstance(value, datetime.tzinfo):
@ -581,7 +582,7 @@ def validate_tzinfo(dummy, value):
return value
def validate_auto_encryption_opts_or_none(option, value):
def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[Any]:
"""Validate the driver keyword arg."""
if value is None:
return value
@ -595,7 +596,7 @@ def validate_auto_encryption_opts_or_none(option, value):
# Dictionary where keys are the names of public URI options, and values
# are lists of aliases for that option.
URI_OPTIONS_ALIAS_MAP = {
URI_OPTIONS_ALIAS_MAP: Dict[str, List[str]] = {
'tls': ['ssl'],
}
@ -603,7 +604,7 @@ URI_OPTIONS_ALIAS_MAP = {
# are functions that validate user-input values for that option. If an option
# alias uses a different validator than its public counterpart, it should be
# included here as a key, value pair.
URI_OPTIONS_VALIDATOR_MAP = {
URI_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = {
'appname': validate_appname_or_none,
'authmechanism': validate_auth_mechanism,
'authmechanismproperties': validate_auth_mechanism_properties,
@ -644,7 +645,7 @@ URI_OPTIONS_VALIDATOR_MAP = {
# Dictionary where keys are the names of URI options specific to pymongo,
# and values are functions that validate user-input values for those options.
NONSPEC_OPTIONS_VALIDATOR_MAP = {
NONSPEC_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = {
'connect': validate_boolean_or_string,
'driver': validate_driver_or_none,
'server_api': validate_server_api_or_none,
@ -661,7 +662,7 @@ NONSPEC_OPTIONS_VALIDATOR_MAP = {
# Dictionary where keys are the names of keyword-only options for the
# MongoClient constructor, and values are functions that validate user-input
# values for those options.
KW_VALIDATORS = {
KW_VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = {
'document_class': validate_document_class,
'type_registry': validate_type_registry,
'read_preference': validate_read_preference,
@ -677,14 +678,14 @@ KW_VALIDATORS = {
# internally-used names of that URI option. Options with only one name
# variant need not be included here. Options whose public and internal
# names are the same need not be included here.
INTERNAL_URI_OPTION_NAME_MAP = {
INTERNAL_URI_OPTION_NAME_MAP: Dict[str, str] = {
'ssl': 'tls',
}
# Map from deprecated URI option names to a tuple indicating the method of
# their deprecation and any additional information that may be needed to
# construct the warning message.
URI_OPTIONS_DEPRECATION_MAP = {
URI_OPTIONS_DEPRECATION_MAP: Dict[str, Tuple[str, str]] = {
# format: <deprecated option name>: (<mode>, <message>),
# Supported <mode> values:
# - 'renamed': <message> should be the new option name. Note that case is
@ -704,11 +705,11 @@ for optname, aliases in URI_OPTIONS_ALIAS_MAP.items():
URI_OPTIONS_VALIDATOR_MAP[optname])
# Map containing all URI option and keyword argument validators.
VALIDATORS = URI_OPTIONS_VALIDATOR_MAP.copy()
VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = URI_OPTIONS_VALIDATOR_MAP.copy()
VALIDATORS.update(KW_VALIDATORS)
# List of timeout-related options.
TIMEOUT_OPTIONS = [
TIMEOUT_OPTIONS: List[str] = [
'connecttimeoutms',
'heartbeatfrequencyms',
'maxidletimems',
@ -722,7 +723,7 @@ TIMEOUT_OPTIONS = [
_AUTH_OPTIONS = frozenset(['authmechanismproperties'])
def validate_auth_option(option, value):
def validate_auth_option(option: str, value: Any) -> Tuple[str, Any]:
"""Validate optional authentication parameters.
"""
lower, value = validate(option, value)
@ -732,7 +733,7 @@ def validate_auth_option(option, value):
return option, value
def validate(option, value):
def validate(option: str, value: Any) -> Tuple[str, Any]:
"""Generic validation function.
"""
lower = option.lower()
@ -741,7 +742,7 @@ def validate(option, value):
return option, value
def get_validated_options(options, warn=True):
def get_validated_options(options: Mapping[str, Any], warn: bool = True) -> MutableMapping[str, Any]:
"""Validate each entry in options and raise a warning if it is not valid.
Returns a copy of options with invalid entries removed.
@ -751,6 +752,7 @@ def get_validated_options(options, warn=True):
invalid options will be ignored. Otherwise, invalid options will
cause errors.
"""
validated_options: MutableMapping[str, Any]
if isinstance(options, _CaseInsensitiveDictionary):
validated_options = _CaseInsensitiveDictionary()
get_normed_key = lambda x: x
@ -794,8 +796,8 @@ class BaseObject(object):
SHOULD NOT BE USED BY DEVELOPERS EXTERNAL TO MONGODB.
"""
def __init__(self, codec_options, read_preference, write_concern,
read_concern):
def __init__(self, codec_options: CodecOptions, read_preference: _ServerMode, write_concern: WriteConcern,
read_concern: ReadConcern) -> None:
if not isinstance(codec_options, CodecOptions):
raise TypeError("codec_options must be an instance of "
@ -819,14 +821,14 @@ class BaseObject(object):
self.__read_concern = read_concern
@property
def codec_options(self):
def codec_options(self) -> CodecOptions:
"""Read only access to the :class:`~bson.codec_options.CodecOptions`
of this instance.
"""
return self.__codec_options
@property
def write_concern(self):
def write_concern(self) -> WriteConcern:
"""Read only access to the :class:`~pymongo.write_concern.WriteConcern`
of this instance.
@ -844,7 +846,7 @@ class BaseObject(object):
return self.write_concern
@property
def read_preference(self):
def read_preference(self) -> _ServerMode:
"""Read only access to the read preference of this instance.
.. versionchanged:: 3.0
@ -861,7 +863,7 @@ class BaseObject(object):
return self.__read_preference
@property
def read_concern(self):
def read_concern(self) -> ReadConcern:
"""Read only access to the :class:`~pymongo.read_concern.ReadConcern`
of this instance.

View File

@ -13,6 +13,7 @@
# limitations under the License.
import warnings
from typing import Callable
try:
import snappy
@ -99,7 +100,7 @@ class CompressionSettings(object):
return ZstdContext()
def _zlib_no_compress(data):
def _zlib_no_compress(data, level=None):
"""Compress data with zlib level 0."""
cobj = zlib.compressobj(0)
return b"".join([cobj.compress(data), cobj.flush()])
@ -117,6 +118,8 @@ class ZlibContext(object):
compressor_id = 2
def __init__(self, level):
self.compress: Callable[[bytes], bytes]
# Jython zlib.compress doesn't support -1
if level == -1:
self.compress = zlib.compress
@ -124,7 +127,7 @@ class ZlibContext(object):
elif level == 0:
self.compress = _zlib_no_compress
else:
self.compress = lambda data: zlib.compress(data, level)
self.compresss = lambda data, _: zlib.compress(data, level)
class ZstdContext(object):

View File

@ -13,29 +13,26 @@
# limitations under the License.
"""Cursor class to iterate over Mongo query results."""
import copy
import threading
import warnings
from collections import deque
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Mapping,
MutableMapping, Optional, Sequence, Tuple, Union, cast, overload)
from bson import RE_TYPE, _convert_raw_document_lists_to_streams
from bson.code import Code
from bson.son import SON
from pymongo import helpers
from pymongo.common import (validate_boolean, validate_is_mapping,
validate_is_document_type)
from pymongo.collation import validate_collation_or_none
from pymongo.errors import (ConnectionFailure,
InvalidOperation,
from pymongo.common import (validate_boolean, validate_is_document_type,
validate_is_mapping)
from pymongo.errors import (ConnectionFailure, InvalidOperation,
OperationFailure)
from pymongo.message import (_CursorAddress,
_GetMore,
_RawBatchGetMore,
_Query,
_RawBatchQuery)
from pymongo.message import (_CursorAddress, _GetMore, _Query,
_RawBatchGetMore, _RawBatchQuery)
from pymongo.response import PinnedResponse
from pymongo.typings import _CollationIn, _DocumentType
# These errors mean that the server has already killed the cursor so there is
# no need to send killCursors.
@ -126,22 +123,47 @@ class _SocketManager(object):
self.sock.unpin()
self.sock = None
_Sort = Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]]
_Hint = Union[str, _Sort]
class Cursor(object):
if TYPE_CHECKING:
from pymongo.client_session import ClientSession
from pymongo.collection import Collection
class Cursor(Generic[_DocumentType]):
"""A cursor / iterator over Mongo query results.
"""
_query_class = _Query
_getmore_class = _GetMore
def __init__(self, collection, filter=None, projection=None, skip=0,
limit=0, no_cursor_timeout=False,
cursor_type=CursorType.NON_TAILABLE,
sort=None, allow_partial_results=False, oplog_replay=False,
batch_size=0,
collation=None, hint=None, max_scan=None, max_time_ms=None,
max=None, min=None, return_key=None, show_record_id=None,
snapshot=None, comment=None, session=None,
allow_disk_use=None, let=None):
def __init__(self,
collection: "Collection[_DocumentType]",
filter: Optional[Mapping[str, Any]] = None,
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
skip: int = 0,
limit: int = 0,
no_cursor_timeout: bool = False,
cursor_type: int = CursorType.NON_TAILABLE,
sort: Optional[_Sort] = None,
allow_partial_results: bool = False,
oplog_replay: bool = False,
batch_size: int = 0,
collation: Optional[_CollationIn] = None,
hint: Optional[_Hint] = None,
max_scan: Optional[int] = None,
max_time_ms: Optional[int] = None,
max: Optional[_Sort] = None,
min: Optional[_Sort] = None,
return_key: Optional[bool] = None,
show_record_id: Optional[bool] = None,
snapshot: Optional[bool] = None,
comment: Any = None,
session: Optional["ClientSession"] = None,
allow_disk_use: Optional[bool] = None,
let: Optional[bool] = None
) -> None:
"""Create a new cursor.
Should not be called directly by application developers - see
@ -151,11 +173,12 @@ class Cursor(object):
"""
# Initialize all attributes used in __del__ before possibly raising
# an error to avoid attribute errors during garbage collection.
self.__collection = collection
self.__id = None
self.__collection: Collection[_DocumentType] = collection
self.__id: Any = None
self.__exhaust = False
self.__sock_mgr = None
self.__sock_mgr: Any = None
self.__killed = False
self.__session: Optional["ClientSession"]
if session:
self.__session = session
@ -164,10 +187,7 @@ class Cursor(object):
self.__session = None
self.__explicit_session = False
spec = filter
if spec is None:
spec = {}
spec: Mapping[str, Any] = filter or {}
validate_is_mapping("filter", spec)
if not isinstance(skip, int):
raise TypeError("skip must be an instance of int")
@ -203,6 +223,7 @@ class Cursor(object):
self.__let = let
self.__spec = spec
self.__has_filter = filter is not None
self.__projection = projection
self.__skip = skip
self.__limit = limit
@ -212,9 +233,9 @@ class Cursor(object):
self.__explain = False
self.__comment = comment
self.__max_time_ms = max_time_ms
self.__max_await_time_ms = None
self.__max = max
self.__min = min
self.__max_await_time_ms: Optional[int] = None
self.__max: Optional[Union[SON[Any, Any], _Sort]] = max
self.__min: Optional[Union[SON[Any, Any], _Sort]] = min
self.__collation = validate_collation_or_none(collation)
self.__return_key = return_key
self.__show_record_id = show_record_id
@ -239,7 +260,7 @@ class Cursor(object):
# it anytime we change __limit.
self.__empty = False
self.__data = deque()
self.__data: deque = deque()
self.__address = None
self.__retrieved = 0
@ -261,22 +282,22 @@ class Cursor(object):
self.__collname = collection.name
@property
def collection(self):
def collection(self) -> "Collection[_DocumentType]":
"""The :class:`~pymongo.collection.Collection` that this
:class:`Cursor` is iterating.
"""
return self.__collection
@property
def retrieved(self):
def retrieved(self) -> int:
"""The number of documents retrieved so far.
"""
return self.__retrieved
def __del__(self):
def __del__(self) -> None:
self.__die()
def rewind(self):
def rewind(self) -> "Cursor[_DocumentType]":
"""Rewind this cursor to its unevaluated state.
Reset this cursor if it has been partially or completely evaluated.
@ -294,7 +315,7 @@ class Cursor(object):
return self
def clone(self):
def clone(self) -> "Cursor[_DocumentType]":
"""Get a clone of this cursor.
Returns a new Cursor instance with options matching those that have
@ -318,7 +339,7 @@ class Cursor(object):
"batch_size", "max_scan",
"query_flags", "collation", "empty",
"show_record_id", "return_key", "allow_disk_use",
"snapshot", "exhaust")
"snapshot", "exhaust", "has_filter")
data = dict((k, v) for k, v in self.__dict__.items()
if k.startswith('_Cursor__') and k[9:] in values_to_clone)
if deepcopy:
@ -360,7 +381,7 @@ class Cursor(object):
self.__session = None
self.__sock_mgr = None
def close(self):
def close(self) -> None:
"""Explicitly close / kill this cursor.
"""
self.__die(True)
@ -397,7 +418,7 @@ class Cursor(object):
if operators:
# Make a shallow copy so we can cleanly rewind or clone.
spec = self.__spec.copy()
spec = copy.copy(self.__spec)
# Allow-listed commands must be wrapped in $query.
if "$query" not in spec:
@ -429,7 +450,7 @@ class Cursor(object):
if self.__retrieved or self.__id is not None:
raise InvalidOperation("cannot set options after executing query")
def add_option(self, mask):
def add_option(self, mask: int) -> "Cursor[_DocumentType]":
"""Set arbitrary query flags using a bitmask.
To set the tailable flag:
@ -450,7 +471,7 @@ class Cursor(object):
self.__query_flags |= mask
return self
def remove_option(self, mask):
def remove_option(self, mask: int) -> "Cursor[_DocumentType]":
"""Unset arbitrary query flags using a bitmask.
To unset the tailable flag:
@ -466,7 +487,7 @@ class Cursor(object):
self.__query_flags &= ~mask
return self
def allow_disk_use(self, allow_disk_use):
def allow_disk_use(self, allow_disk_use: bool) -> "Cursor[_DocumentType]":
"""Specifies whether MongoDB can use temporary disk files while
processing a blocking sort operation.
@ -488,7 +509,7 @@ class Cursor(object):
self.__allow_disk_use = allow_disk_use
return self
def limit(self, limit):
def limit(self, limit: int) -> "Cursor[_DocumentType]":
"""Limits the number of results to be returned by this cursor.
Raises :exc:`TypeError` if `limit` is not an integer. Raises
@ -511,7 +532,7 @@ class Cursor(object):
self.__limit = limit
return self
def batch_size(self, batch_size):
def batch_size(self, batch_size: int) -> "Cursor[_DocumentType]":
"""Limits the number of documents returned in one batch. Each batch
requires a round trip to the server. It can be adjusted to optimize
performance and limit data transfer.
@ -539,7 +560,7 @@ class Cursor(object):
self.__batch_size = batch_size
return self
def skip(self, skip):
def skip(self, skip: int) -> "Cursor[_DocumentType]":
"""Skips the first `skip` results of this cursor.
Raises :exc:`TypeError` if `skip` is not an integer. Raises
@ -560,7 +581,7 @@ class Cursor(object):
self.__skip = skip
return self
def max_time_ms(self, max_time_ms):
def max_time_ms(self, max_time_ms: Optional[int]) -> "Cursor[_DocumentType]":
"""Specifies a time limit for a query operation. If the specified
time is exceeded, the operation will be aborted and
:exc:`~pymongo.errors.ExecutionTimeout` is raised. If `max_time_ms`
@ -581,7 +602,7 @@ class Cursor(object):
self.__max_time_ms = max_time_ms
return self
def max_await_time_ms(self, max_await_time_ms):
def max_await_time_ms(self, max_await_time_ms: Optional[int]) -> "Cursor[_DocumentType]":
"""Specifies a time limit for a getMore operation on a
:attr:`~pymongo.cursor.CursorType.TAILABLE_AWAIT` cursor. For all other
types of cursor max_await_time_ms is ignored.
@ -609,6 +630,14 @@ class Cursor(object):
return self
@overload
def __getitem__(self, index: int) -> _DocumentType:
...
@overload
def __getitem__(self, index: slice) -> "Cursor[_DocumentType]":
...
def __getitem__(self, index):
"""Get a single document or a slice of documents from this cursor.
@ -691,7 +720,7 @@ class Cursor(object):
raise TypeError("index %r cannot be applied to Cursor "
"instances" % index)
def max_scan(self, max_scan):
def max_scan(self, max_scan: Optional[int]) -> "Cursor[_DocumentType]":
"""**DEPRECATED** - Limit the number of documents to scan when
performing the query.
@ -711,7 +740,7 @@ class Cursor(object):
self.__max_scan = max_scan
return self
def max(self, spec):
def max(self, spec: _Sort) -> "Cursor[_DocumentType]":
"""Adds ``max`` operator that specifies upper bound for specific index.
When using ``max``, :meth:`~hint` should also be configured to ensure
@ -734,7 +763,7 @@ class Cursor(object):
self.__max = SON(spec)
return self
def min(self, spec):
def min(self, spec: _Sort) -> "Cursor[_DocumentType]":
"""Adds ``min`` operator that specifies lower bound for specific index.
When using ``min``, :meth:`~hint` should also be configured to ensure
@ -757,7 +786,7 @@ class Cursor(object):
self.__min = SON(spec)
return self
def sort(self, key_or_list, direction=None):
def sort(self, key_or_list: _Hint, direction: Optional[Union[int, str]] = None) -> "Cursor[_DocumentType]":
"""Sorts this cursor's results.
Pass a field name and a direction, either
@ -803,7 +832,7 @@ class Cursor(object):
self.__ordering = helpers._index_document(keys)
return self
def distinct(self, key):
def distinct(self, key: str) -> List:
"""Get a list of distinct values for `key` among all documents
in the result set of this query.
@ -820,7 +849,7 @@ class Cursor(object):
.. seealso:: :meth:`pymongo.collection.Collection.distinct`
"""
options = {}
options: Dict[str, Any] = {}
if self.__spec:
options["query"] = self.__spec
if self.__max_time_ms is not None:
@ -833,7 +862,7 @@ class Cursor(object):
return self.__collection.distinct(
key, session=self.__session, **options)
def explain(self):
def explain(self) -> _DocumentType:
"""Returns an explain plan record for this cursor.
.. note:: This method uses the default verbosity mode of the
@ -863,7 +892,7 @@ class Cursor(object):
else:
self.__hint = helpers._index_document(index)
def hint(self, index):
def hint(self, index: Optional[_Hint]) -> "Cursor[_DocumentType]":
"""Adds a 'hint', telling Mongo the proper index to use for the query.
Judicious use of hints can greatly improve query
@ -888,7 +917,7 @@ class Cursor(object):
self.__set_hint(index)
return self
def comment(self, comment):
def comment(self, comment: Any) -> "Cursor[_DocumentType]":
"""Adds a 'comment' to the cursor.
http://docs.mongodb.org/manual/reference/operator/comment/
@ -903,7 +932,7 @@ class Cursor(object):
self.__comment = comment
return self
def where(self, code):
def where(self, code: Union[str, Code]) -> "Cursor[_DocumentType]":
"""Adds a `$where`_ clause to this query.
The `code` argument must be an instance of :class:`basestring`
@ -937,10 +966,18 @@ class Cursor(object):
if not isinstance(code, Code):
code = Code(code)
self.__spec["$where"] = code
# Avoid overwriting a filter argument that was given by the user
# when updating the spec.
spec: Dict[str, Any]
if self.__has_filter:
spec = dict(self.__spec)
else:
spec = cast(Dict, self.__spec)
spec["$where"] = code
self.__spec = spec
return self
def collation(self, collation):
def collation(self, collation: Optional[_CollationIn]) -> "Cursor[_DocumentType]":
"""Adds a :class:`~pymongo.collation.Collation` to this query.
Raises :exc:`TypeError` if `collation` is not an instance of
@ -1106,7 +1143,7 @@ class Cursor(object):
return len(self.__data)
@property
def alive(self):
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
This is mostly useful with `tailable cursors
@ -1128,7 +1165,7 @@ class Cursor(object):
return bool(len(self.__data) or (not self.__killed))
@property
def cursor_id(self):
def cursor_id(self) -> Optional[int]:
"""Returns the id of the cursor
.. versionadded:: 2.2
@ -1136,7 +1173,7 @@ class Cursor(object):
return self.__id
@property
def address(self):
def address(self) -> Optional[Tuple[str, Any]]:
"""The (host, port) of the server used, or None.
.. versionchanged:: 3.0
@ -1145,18 +1182,19 @@ class Cursor(object):
return self.__address
@property
def session(self):
def session(self) -> Optional["ClientSession"]:
"""The cursor's :class:`~pymongo.client_session.ClientSession`, or None.
.. versionadded:: 3.6
"""
if self.__explicit_session:
return self.__session
return None
def __iter__(self):
def __iter__(self) -> "Cursor[_DocumentType]":
return self
def next(self):
def next(self) -> _DocumentType:
"""Advance the cursor."""
if self.__empty:
raise StopIteration
@ -1167,20 +1205,20 @@ class Cursor(object):
__next__ = next
def __enter__(self):
def __enter__(self) -> "Cursor[_DocumentType]":
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
def __copy__(self):
def __copy__(self) -> "Cursor[_DocumentType]":
"""Support function for `copy.copy()`.
.. versionadded:: 2.4
"""
return self._clone(deepcopy=False)
def __deepcopy__(self, memo):
def __deepcopy__(self, memo: Any) -> Any:
"""Support function for `copy.deepcopy()`.
.. versionadded:: 2.4
@ -1193,6 +1231,7 @@ class Cursor(object):
Regular expressions cannot be deep copied but as they are immutable we
don't have to copy them when cloning.
"""
y: Any
if not hasattr(x, 'items'):
y, is_list, iterator = [], True, enumerate(x)
else:
@ -1220,13 +1259,13 @@ class Cursor(object):
return y
class RawBatchCursor(Cursor):
class RawBatchCursor(Cursor, Generic[_DocumentType]):
"""A cursor / iterator over raw batches of BSON data from a query result."""
_query_class = _RawBatchQuery
_getmore_class = _RawBatchGetMore
def __init__(self, *args, **kwargs):
def __init__(self, collection: "Collection[_DocumentType]", *args: Any, **kwargs: Any) -> None:
"""Create a new cursor / iterator over raw batches of BSON data.
Should not be called directly by application developers -
@ -1235,7 +1274,7 @@ class RawBatchCursor(Cursor):
.. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
"""
super(RawBatchCursor, self).__init__(*args, **kwargs)
super(RawBatchCursor, self).__init__(collection, *args, **kwargs)
def _unpack_response(self, response, cursor_id, codec_options,
user_fields=None, legacy_response=False):
@ -1247,7 +1286,7 @@ class RawBatchCursor(Cursor):
_convert_raw_document_lists_to_streams(raw_response[0])
return raw_response
def explain(self):
def explain(self) -> _DocumentType:
"""Returns an explain plan record for this cursor.
.. seealso:: The MongoDB documentation on `explain <https://dochub.mongodb.org/core/explain>`_.
@ -1255,5 +1294,5 @@ class RawBatchCursor(Cursor):
clone = self._clone(deepcopy=True, base=Cursor(self.collection))
return clone.explain()
def __getitem__(self, index):
def __getitem__(self, index: Any) -> "Cursor[_DocumentType]":
raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor")

View File

@ -13,18 +13,21 @@
# limitations under the License.
"""Database level operations."""
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Mapping, MutableMapping, Optional,
Sequence, Union)
from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
from bson.dbref import DBRef
from bson.son import SON
from bson.timestamp import Timestamp
from pymongo import common
from pymongo.aggregation import _DatabaseAggregationCommand
from pymongo.change_stream import DatabaseChangeStream
from pymongo.collection import Collection
from pymongo.command_cursor import CommandCursor
from pymongo.errors import (CollectionInvalid,
InvalidName)
from pymongo.read_preferences import ReadPreference
from pymongo.errors import CollectionInvalid, InvalidName
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
def _check_name(name):
@ -39,12 +42,24 @@ def _check_name(name):
"character %r" % invalid_char)
class Database(common.BaseObject):
if TYPE_CHECKING:
from pymongo.client_session import ClientSession
from pymongo.mongo_client import MongoClient
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern
class Database(common.BaseObject, Generic[_DocumentType]):
"""A Mongo database.
"""
def __init__(self, client, name, codec_options=None, read_preference=None,
write_concern=None, read_concern=None):
def __init__(self,
client: "MongoClient[_DocumentType]",
name: str,
codec_options: Optional[CodecOptions] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional["WriteConcern"] = None,
read_concern: Optional["ReadConcern"] = None,
) -> None:
"""Get a database by client and name.
Raises :class:`TypeError` if `name` is not an instance of
@ -104,20 +119,24 @@ class Database(common.BaseObject):
_check_name(name)
self.__name = name
self.__client = client
self.__client: MongoClient[_DocumentType] = client
@property
def client(self):
def client(self) -> "MongoClient[_DocumentType]":
"""The client instance for this :class:`Database`."""
return self.__client
@property
def name(self):
def name(self) -> str:
"""The name of this :class:`Database`."""
return self.__name
def with_options(self, codec_options=None, read_preference=None,
write_concern=None, read_concern=None):
def with_options(self,
codec_options: Optional[CodecOptions] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional["WriteConcern"] = None,
read_concern: Optional["ReadConcern"] = None,
) -> "Database[_DocumentType]":
"""Get a clone of this database changing the specified settings.
>>> db1.read_preference
@ -156,22 +175,22 @@ class Database(common.BaseObject):
write_concern or self.write_concern,
read_concern or self.read_concern)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, Database):
return (self.__client == other.client and
self.__name == other.name)
return NotImplemented
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other
def __hash__(self):
def __hash__(self) -> int:
return hash((self.__client, self.__name))
def __repr__(self):
return "Database(%r, %r)" % (self.__client, self.__name)
def __getattr__(self, name):
def __getattr__(self, name: str) -> Collection[_DocumentType]:
"""Get a collection of this database by name.
Raises InvalidName if an invalid collection name is used.
@ -185,7 +204,7 @@ class Database(common.BaseObject):
" collection, use database[%r]." % (name, name, name))
return self.__getitem__(name)
def __getitem__(self, name):
def __getitem__(self, name: str) -> "Collection[_DocumentType]":
"""Get a collection of this database by name.
Raises InvalidName if an invalid collection name is used.
@ -195,8 +214,13 @@ class Database(common.BaseObject):
"""
return Collection(self, name)
def get_collection(self, name, codec_options=None, read_preference=None,
write_concern=None, read_concern=None):
def get_collection(self,
name: str,
codec_options: Optional[CodecOptions] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional["WriteConcern"] = None,
read_concern: Optional["ReadConcern"] = None,
) -> Collection[_DocumentType]:
"""Get a :class:`~pymongo.collection.Collection` with the given name
and options.
@ -238,9 +262,15 @@ class Database(common.BaseObject):
self, name, False, codec_options, read_preference,
write_concern, read_concern)
def create_collection(self, name, codec_options=None,
read_preference=None, write_concern=None,
read_concern=None, session=None, **kwargs):
def create_collection(self,
name: str,
codec_options: Optional[CodecOptions] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional["WriteConcern"] = None,
read_concern: Optional["ReadConcern"] = None,
session: Optional["ClientSession"] = None,
**kwargs: Any,
) -> Collection[_DocumentType]:
"""Create a new :class:`~pymongo.collection.Collection` in this
database.
@ -286,20 +316,20 @@ class Database(common.BaseObject):
timeseries collections
- ``expireAfterSeconds`` (int): the number of seconds after which a
document in a timeseries collection expires
- ``validator`` (dict): a document specifying validation rules or expressions
- ``validator`` (dict): a document specifying validation rules or expressions
for the collection
- ``validationLevel`` (str): how strictly to apply the
- ``validationLevel`` (str): how strictly to apply the
validation rules to existing documents during an update. The default level
is "strict"
- ``validationAction`` (str): whether to "error" on invalid documents
(the default) or just "warn" about the violations but allow invalid
(the default) or just "warn" about the violations but allow invalid
documents to be inserted
- ``indexOptionDefaults`` (dict): a document specifying a default configuration
for indexes when creating a collection
- ``viewOn`` (str): the name of the source collection or view from which
- ``viewOn`` (str): the name of the source collection or view from which
to create the view
- ``pipeline`` (list): a list of aggregation pipeline stages
- ``comment`` (str): a user-provided comment to attach to this command.
- ``comment`` (str): a user-provided comment to attach to this command.
This option is only supported on MongoDB >= 4.4.
.. versionchanged:: 3.11
@ -330,7 +360,11 @@ class Database(common.BaseObject):
read_preference, write_concern,
read_concern, session=s, **kwargs)
def aggregate(self, pipeline, session=None, **kwargs):
def aggregate(self,
pipeline: _Pipeline,
session: Optional["ClientSession"] = None,
**kwargs: Any
) -> CommandCursor[_DocumentType]:
"""Perform a database-level aggregation.
See the `aggregation pipeline`_ documentation for a list of stages
@ -400,9 +434,17 @@ class Database(common.BaseObject):
cmd.get_cursor, cmd.get_read_preference(s), s,
retryable=not cmd._performs_write)
def watch(self, pipeline=None, full_document=None, resume_after=None,
max_await_time_ms=None, batch_size=None, collation=None,
start_at_operation_time=None, session=None, start_after=None):
def watch(self,
pipeline: Optional[_Pipeline] = None,
full_document: Optional[str] = None,
resume_after: Optional[Mapping[str, Any]] = None,
max_await_time_ms: Optional[int] = None,
batch_size: Optional[int] = None,
collation: Optional[_CollationIn] = None,
start_at_operation_time: Optional[Timestamp] = None,
session: Optional["ClientSession"] = None,
start_after: Optional[Mapping[str, Any]] = None,
) -> DatabaseChangeStream[_DocumentType]:
"""Watch changes on this database.
Performs an aggregation with an implicit initial ``$changeStream``
@ -515,9 +557,16 @@ class Database(common.BaseObject):
session=s,
client=self.__client)
def command(self, command, value=1, check=True,
allowable_errors=None, read_preference=None,
codec_options=DEFAULT_CODEC_OPTIONS, session=None, **kwargs):
def command(self,
command: Union[str, MutableMapping[str, Any]],
value: Any = 1,
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: Optional[_ServerMode] = None,
codec_options: Optional[CodecOptions] = DEFAULT_CODEC_OPTIONS,
session: Optional["ClientSession"] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Issue a MongoDB command.
Send command `command` to the database and return the
@ -648,7 +697,11 @@ class Database(common.BaseObject):
cmd_cursor._maybe_pin_connection(sock_info)
return cmd_cursor
def list_collections(self, session=None, filter=None, **kwargs):
def list_collections(self,
session: Optional["ClientSession"] = None,
filter: Optional[Mapping[str, Any]] = None,
**kwargs: Any
) -> CommandCursor[Dict[str, Any]]:
"""Get a cursor over the collections of this database.
:Parameters:
@ -680,7 +733,11 @@ class Database(common.BaseObject):
return self.__client._retryable_read(
_cmd, read_pref, session)
def list_collection_names(self, session=None, filter=None, **kwargs):
def list_collection_names(self,
session: Optional["ClientSession"] = None,
filter: Optional[Mapping[str, Any]] = None,
**kwargs: Any
) -> List[str]:
"""Get a list of all the collection names in this database.
For example, to list all non-system collections::
@ -717,7 +774,10 @@ class Database(common.BaseObject):
return [result["name"]
for result in self.list_collections(session=session, **kwargs)]
def drop_collection(self, name_or_collection, session=None):
def drop_collection(self,
name_or_collection: Union[str, Collection],
session: Optional["ClientSession"] = None
) -> Dict[str, Any]:
"""Drop a collection.
:Parameters:
@ -752,9 +812,13 @@ class Database(common.BaseObject):
parse_write_concern_error=True,
session=session)
def validate_collection(self, name_or_collection,
scandata=False, full=False, session=None,
background=None):
def validate_collection(self,
name_or_collection: Union[str, Collection],
scandata: bool = False,
full: bool = False,
session: Optional["ClientSession"] = None,
background: Optional[bool] = None,
) -> Dict[str, Any]:
"""Validate a collection.
Returns a dict of validation info. Raises CollectionInvalid if
@ -827,20 +891,23 @@ class Database(common.BaseObject):
return result
def __iter__(self):
def __iter__(self) -> "Database[_DocumentType]":
return self
def __next__(self):
def __next__(self) -> "Database[_DocumentType]":
raise TypeError("'Database' object is not iterable")
next = __next__
def __bool__(self):
def __bool__(self) -> bool:
raise NotImplementedError("Database objects do not implement truth "
"value testing or bool(). Please compare "
"with None instead: database is not None")
def dereference(self, dbref, session=None, **kwargs):
def dereference(self, dbref: DBRef,
session: Optional["ClientSession"] = None,
**kwargs: Any
) -> Optional[_DocumentType]:
"""Dereference a :class:`~bson.dbref.DBRef`, getting the
document it points to.

View File

@ -15,6 +15,7 @@
"""Advanced options for MongoDB drivers implemented on top of PyMongo."""
from collections import namedtuple
from typing import Optional
class DriverInfo(namedtuple('DriverInfo', ['name', 'version', 'platform'])):
@ -26,7 +27,7 @@ class DriverInfo(namedtuple('DriverInfo', ['name', 'version', 'platform'])):
like 'MyDriver', '1.2.3', 'some platform info'. Any of these strings may be
None to accept PyMongo's default.
"""
def __new__(cls, name, version=None, platform=None):
def __new__(cls, name: str, version: Optional[str] = None, platform: Optional[str] = None) -> "DriverInfo":
self = super(DriverInfo, cls).__new__(cls, name, version, platform)
for key, value in self._asdict().items():
if value is not None and not isinstance(value, str):

View File

@ -15,15 +15,16 @@
"""Support for explicit client-side field level encryption."""
import contextlib
import os
import subprocess
import uuid
import weakref
from typing import Any, Mapping, Optional, Sequence
try:
from pymongocrypt.auto_encrypter import AutoEncrypter
from pymongocrypt.errors import MongoCryptError
from pymongocrypt.explicit_encrypter import ExplicitEncrypter
from pymongocrypt.explicit_encrypter import (
ExplicitEncrypter
)
from pymongocrypt.mongocrypt import MongoCryptOptions
from pymongocrypt.state_machine import MongoCryptCallback
_HAVE_PYMONGOCRYPT = True
@ -32,29 +33,22 @@ except ImportError:
MongoCryptCallback = object
from bson import _dict_to_bson, decode, encode
from bson.binary import STANDARD, UUID_SUBTYPE, Binary
from bson.codec_options import CodecOptions
from bson.binary import (Binary,
STANDARD,
UUID_SUBTYPE)
from bson.errors import BSONError
from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS,
RawBSONDocument,
from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument,
_inflate_bson)
from bson.son import SON
from pymongo.errors import (ConfigurationError,
EncryptionError,
InvalidOperation,
ServerSelectionTimeoutError)
from pymongo.daemon import _spawn_daemon
from pymongo.encryption_options import AutoEncryptionOpts
from pymongo.errors import (ConfigurationError, EncryptionError,
InvalidOperation, ServerSelectionTimeoutError)
from pymongo.mongo_client import MongoClient
from pymongo.pool import _configured_socket, PoolOptions
from pymongo.pool import PoolOptions, _configured_socket
from pymongo.read_concern import ReadConcern
from pymongo.ssl_support import get_ssl_context
from pymongo.uri_parser import parse_host
from pymongo.write_concern import WriteConcern
from pymongo.daemon import _spawn_daemon
_HTTPS_PORT = 443
_KMS_CONNECT_TIMEOUT = 10 # TODO: CDRIVER-3262 will define this value.
@ -80,9 +74,10 @@ def _wrap_encryption_errors():
raise EncryptionError(exc)
class _EncryptionIO(MongoCryptCallback):
class _EncryptionIO(MongoCryptCallback): # type: ignore
def __init__(self, client, key_vault_coll, mongocryptd_client, opts):
"""Internal class to perform I/O on behalf of pymongocrypt."""
self.client_ref: Any
# Use a weak ref to break reference cycle.
if client is not None:
self.client_ref = weakref.ref(client)
@ -355,11 +350,17 @@ class Algorithm(object):
"AEAD_AES_256_CBC_HMAC_SHA_512-Random")
class ClientEncryption(object):
"""Explicit client-side field level encryption."""
def __init__(self, kms_providers, key_vault_namespace, key_vault_client,
codec_options, kms_tls_options=None):
def __init__(self,
kms_providers: Mapping[str, Any],
key_vault_namespace: str,
key_vault_client: MongoClient,
codec_options: CodecOptions,
kms_tls_options: Optional[Mapping[str, Any]] = None
) -> None:
"""Explicit client-side field level encryption.
The ClientEncryption class encapsulates explicit operations on a key
@ -449,12 +450,15 @@ class ClientEncryption(object):
opts = AutoEncryptionOpts(kms_providers, key_vault_namespace,
kms_tls_options=kms_tls_options)
self._io_callbacks = _EncryptionIO(None, key_vault_coll, None, opts)
self._io_callbacks: Optional[_EncryptionIO] = _EncryptionIO(None, key_vault_coll, None, opts)
self._encryption = ExplicitEncrypter(
self._io_callbacks, MongoCryptOptions(kms_providers, None))
def create_data_key(self, kms_provider, master_key=None,
key_alt_names=None):
def create_data_key(self,
kms_provider: str,
master_key: Optional[Mapping[str, Any]] = None,
key_alt_names: Optional[Sequence[str]] = None
) -> Binary:
"""Create and insert a new data key into the key vault collection.
:Parameters:
@ -526,7 +530,12 @@ class ClientEncryption(object):
kms_provider, master_key=master_key,
key_alt_names=key_alt_names)
def encrypt(self, value, algorithm, key_id=None, key_alt_name=None):
def encrypt(self,
value: Any,
algorithm: str,
key_id: Optional[Binary] = None,
key_alt_name: Optional[str] = None
) -> Binary:
"""Encrypt a BSON value with a given key and algorithm.
Note that exactly one of ``key_id`` or ``key_alt_name`` must be
@ -557,7 +566,7 @@ class ClientEncryption(object):
doc, algorithm, key_id=key_id, key_alt_name=key_alt_name)
return decode(encrypted_doc)['v']
def decrypt(self, value):
def decrypt(self, value: Binary) -> Any:
"""Decrypt an encrypted value.
:Parameters:
@ -578,17 +587,17 @@ class ClientEncryption(object):
return decode(decrypted_doc,
codec_options=self._codec_options)['v']
def __enter__(self):
def __enter__(self) -> "ClientEncryption":
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
def _check_closed(self):
if self._encryption is None:
raise InvalidOperation("Cannot use closed ClientEncryption")
def close(self):
def close(self) -> None:
"""Release resources.
Note that using this class in a with-statement will automatically call

View File

@ -15,6 +15,7 @@
"""Support for automatic client-side field level encryption."""
import copy
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
try:
import pymongocrypt
@ -25,19 +26,25 @@ except ImportError:
from pymongo.errors import ConfigurationError
from pymongo.uri_parser import _parse_kms_tls_options
if TYPE_CHECKING:
from pymongo.mongo_client import MongoClient
class AutoEncryptionOpts(object):
"""Options to configure automatic client-side field level encryption."""
def __init__(self, kms_providers, key_vault_namespace,
key_vault_client=None,
schema_map=None,
bypass_auto_encryption=False,
mongocryptd_uri='mongodb://localhost:27020',
mongocryptd_bypass_spawn=False,
mongocryptd_spawn_path='mongocryptd',
mongocryptd_spawn_args=None,
kms_tls_options=None):
def __init__(self,
kms_providers: Mapping[str, Any],
key_vault_namespace: str,
key_vault_client: Optional["MongoClient"] = None,
schema_map: Optional[Mapping[str, Any]] = None,
bypass_auto_encryption: Optional[bool] = False,
mongocryptd_uri: str = 'mongodb://localhost:27020',
mongocryptd_bypass_spawn: bool = False,
mongocryptd_spawn_path: str = 'mongocryptd',
mongocryptd_spawn_args: Optional[List[str]] = None,
kms_tls_options: Optional[Mapping[str, Any]] = None
) -> None:
"""Options to configure automatic client-side field level encryption.
Automatic client-side field level encryption requires MongoDB 4.2
@ -152,8 +159,9 @@ class AutoEncryptionOpts(object):
self._mongocryptd_uri = mongocryptd_uri
self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
self._mongocryptd_spawn_path = mongocryptd_spawn_path
self._mongocryptd_spawn_args = (copy.copy(mongocryptd_spawn_args) or
['--idleShutdownTimeoutSecs=60'])
if mongocryptd_spawn_args is None:
mongocryptd_spawn_args = ['--idleShutdownTimeoutSecs=60']
self._mongocryptd_spawn_args = mongocryptd_spawn_args
if not isinstance(self._mongocryptd_spawn_args, list):
raise TypeError('mongocryptd_spawn_args must be a list')
if not any('idleShutdownTimeoutSecs' in s

View File

@ -13,6 +13,8 @@
# limitations under the License.
"""Exceptions raised by PyMongo."""
from typing import (Any, Iterable, List, Mapping, Optional, Sequence, Tuple,
Union)
from bson.errors import *
@ -23,18 +25,21 @@ except ImportError:
try:
from ssl import CertificateError as _CertificateError
except ImportError:
class _CertificateError(ValueError):
class _CertificateError(ValueError): # type: ignore
pass
class PyMongoError(Exception):
"""Base class for all PyMongo exceptions."""
def __init__(self, message='', error_labels=None):
def __init__(self,
message: str = '',
error_labels: Optional[Iterable[str]] = None
) -> None:
super(PyMongoError, self).__init__(message)
self._message = message
self._error_labels = set(error_labels or [])
def has_error_label(self, label):
def has_error_label(self, label: str) -> bool:
"""Return True if this error contains the given label.
.. versionadded:: 3.7
@ -70,10 +75,17 @@ class AutoReconnect(ConnectionFailure):
Subclass of :exc:`~pymongo.errors.ConnectionFailure`.
"""
def __init__(self, message='', errors=None):
errors: Union[Mapping[str, Any], Sequence]
details: Union[Mapping[str, Any], Sequence]
def __init__(self,
message: str = '',
errors: Optional[Union[Mapping[str, Any], Sequence]] = None
) -> None:
error_labels = None
if errors is not None and isinstance(errors, dict):
error_labels = errors.get('errorLabels')
if errors is not None:
if isinstance(errors, dict):
error_labels = errors.get('errorLabels')
super(AutoReconnect, self).__init__(message, error_labels)
self.errors = self.details = errors or []
@ -109,7 +121,10 @@ class NotPrimaryError(AutoReconnect):
.. versionadded:: 3.12
"""
def __init__(self, message='', errors=None):
def __init__(self,
message: str = '',
errors: Optional[Union[Mapping[str, Any], List]] = None
) -> None:
super(NotPrimaryError, self).__init__(
_format_detailed_error(message, errors), errors=errors)
@ -139,7 +154,12 @@ class OperationFailure(PyMongoError):
The :attr:`details` attribute.
"""
def __init__(self, error, code=None, details=None, max_wire_version=None):
def __init__(self,
error: str,
code: Optional[int] = None,
details: Optional[Mapping[str, Any]] = None,
max_wire_version: Optional[int] = None,
) -> None:
error_labels = None
if details is not None:
error_labels = details.get('errorLabels')
@ -154,13 +174,13 @@ class OperationFailure(PyMongoError):
return self.__max_wire_version
@property
def code(self):
def code(self) -> Optional[int]:
"""The error code returned by the server, if any.
"""
return self.__code
@property
def details(self):
def details(self) -> Optional[Mapping[str, Any]]:
"""The complete error document returned by the server.
Depending on the error that occurred, the error document
@ -225,14 +245,17 @@ class BulkWriteError(OperationFailure):
.. versionadded:: 2.7
"""
def __init__(self, results):
details: Mapping[str, Any]
def __init__(self, results: Mapping[str, Any]) -> None:
super(BulkWriteError, self).__init__(
"batch op errors occurred", 65, results)
def __reduce__(self):
def __reduce__(self) -> Tuple[Any, Any]:
return self.__class__, (self.details,)
class InvalidOperation(PyMongoError):
"""Raised when a client attempts to perform an invalid operation."""
@ -264,12 +287,12 @@ class EncryptionError(PyMongoError):
.. versionadded:: 3.9
"""
def __init__(self, cause):
def __init__(self, cause: Exception) -> None:
super(EncryptionError, self).__init__(str(cause))
self.__cause = cause
@property
def cause(self):
def cause(self) -> Exception:
"""The exception that caused this encryption or decryption error."""
return self.__cause

View File

@ -26,8 +26,6 @@ or
``MongoClient(event_listeners=[CommandLogger()])``
"""
import logging
from pymongo import monitoring
@ -42,18 +40,18 @@ class CommandLogger(monitoring.CommandListener):
logs them at the `INFO` severity level using :mod:`logging`.
.. versionadded:: 3.11
"""
def started(self, event):
def started(self, event: monitoring.CommandStartedEvent) -> None:
logging.info("Command {0.command_name} with request id "
"{0.request_id} started on server "
"{0.connection_id}".format(event))
def succeeded(self, event):
def succeeded(self, event: monitoring.CommandSucceededEvent) -> None:
logging.info("Command {0.command_name} with request id "
"{0.request_id} on server {0.connection_id} "
"succeeded in {0.duration_micros} "
"microseconds".format(event))
def failed(self, event):
def failed(self, event: monitoring.CommandFailedEvent) -> None:
logging.info("Command {0.command_name} with request id "
"{0.request_id} on server {0.connection_id} "
"failed in {0.duration_micros} "
@ -70,11 +68,11 @@ class ServerLogger(monitoring.ServerListener):
.. versionadded:: 3.11
"""
def opened(self, event):
def opened(self, event: monitoring.ServerOpeningEvent) -> None:
logging.info("Server {0.server_address} added to topology "
"{0.topology_id}".format(event))
def description_changed(self, event):
def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None:
previous_server_type = event.previous_description.server_type
new_server_type = event.new_description.server_type
if new_server_type != previous_server_type:
@ -84,7 +82,7 @@ class ServerLogger(monitoring.ServerListener):
"{0.previous_description.server_type_name} to "
"{0.new_description.server_type_name}".format(event))
def closed(self, event):
def closed(self, event: monitoring.ServerClosedEvent) -> None:
logging.warning("Server {0.server_address} removed from topology "
"{0.topology_id}".format(event))
@ -99,17 +97,17 @@ class HeartbeatLogger(monitoring.ServerHeartbeatListener):
.. versionadded:: 3.11
"""
def started(self, event):
def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
logging.info("Heartbeat sent to server "
"{0.connection_id}".format(event))
def succeeded(self, event):
def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
# The reply.document attribute was added in PyMongo 3.4.
logging.info("Heartbeat to server {0.connection_id} "
"succeeded with reply "
"{0.reply.document}".format(event))
def failed(self, event):
def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
logging.warning("Heartbeat to server {0.connection_id} "
"failed with error {0.reply}".format(event))
@ -124,11 +122,11 @@ class TopologyLogger(monitoring.TopologyListener):
.. versionadded:: 3.11
"""
def opened(self, event):
def opened(self, event: monitoring.TopologyOpenedEvent) -> None:
logging.info("Topology with id {0.topology_id} "
"opened".format(event))
def description_changed(self, event):
def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None:
logging.info("Topology description updated for "
"topology id {0.topology_id}".format(event))
previous_topology_type = event.previous_description.topology_type
@ -146,7 +144,7 @@ class TopologyLogger(monitoring.TopologyListener):
if not event.new_description.has_readable_server():
logging.warning("No readable servers available.")
def closed(self, event):
def closed(self, event: monitoring.TopologyClosedEvent) -> None:
logging.info("Topology with id {0.topology_id} "
"closed".format(event))
@ -168,43 +166,43 @@ class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
.. versionadded:: 3.11
"""
def pool_created(self, event):
def pool_created(self, event: monitoring.PoolCreatedEvent) -> None:
logging.info("[pool {0.address}] pool created".format(event))
def pool_ready(self, event):
logging.info("[pool {0.address}] pool ready".format(event))
def pool_cleared(self, event):
def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None:
logging.info("[pool {0.address}] pool cleared".format(event))
def pool_closed(self, event):
def pool_closed(self, event: monitoring.PoolClosedEvent) -> None:
logging.info("[pool {0.address}] pool closed".format(event))
def connection_created(self, event):
def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None:
logging.info("[pool {0.address}][conn #{0.connection_id}] "
"connection created".format(event))
def connection_ready(self, event):
def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None:
logging.info("[pool {0.address}][conn #{0.connection_id}] "
"connection setup succeeded".format(event))
def connection_closed(self, event):
def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None:
logging.info("[pool {0.address}][conn #{0.connection_id}] "
"connection closed, reason: "
"{0.reason}".format(event))
def connection_check_out_started(self, event):
def connection_check_out_started(self, event: monitoring.ConnectionCheckOutStartedEvent) -> None:
logging.info("[pool {0.address}] connection check out "
"started".format(event))
def connection_check_out_failed(self, event):
def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None:
logging.info("[pool {0.address}] connection check out "
"failed, reason: {0.reason}".format(event))
def connection_checked_out(self, event):
def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None:
logging.info("[pool {0.address}][conn #{0.connection_id}] "
"connection checked out of pool".format(event))
def connection_checked_in(self, event):
def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None:
logging.info("[pool {0.address}][conn #{0.connection_id}] "
"connection checked into pool".format(event))

View File

@ -14,10 +14,15 @@
"""Helpers for the 'hello' and legacy hello commands."""
import copy
import datetime
import itertools
from typing import Any, Generic, List, Mapping, Optional, Set, Tuple
from bson.objectid import ObjectId
from pymongo import common
from pymongo.server_type import SERVER_TYPE
from pymongo.typings import _DocumentType
class HelloCompat:
@ -56,7 +61,7 @@ def _get_server_type(doc):
return SERVER_TYPE.Standalone
class Hello(object):
class Hello(Generic[_DocumentType]):
"""Parse a hello response from the server.
.. versionadded:: 3.12
@ -64,9 +69,9 @@ class Hello(object):
__slots__ = ('_doc', '_server_type', '_is_writable', '_is_readable',
'_awaitable')
def __init__(self, doc, awaitable=False):
def __init__(self, doc: _DocumentType, awaitable: bool = False) -> None:
self._server_type = _get_server_type(doc)
self._doc = doc
self._doc: _DocumentType = doc
self._is_writable = self._server_type in (
SERVER_TYPE.RSPrimary,
SERVER_TYPE.Standalone,
@ -79,19 +84,19 @@ class Hello(object):
self._awaitable = awaitable
@property
def document(self):
def document(self) -> _DocumentType:
"""The complete hello command response document.
.. versionadded:: 3.4
"""
return self._doc.copy()
return copy.copy(self._doc)
@property
def server_type(self):
def server_type(self) -> int:
return self._server_type
@property
def all_hosts(self):
def all_hosts(self) -> Set[Tuple[str, int]]:
"""List of hosts, passives, and arbiters known to this server."""
return set(map(common.clean_node, itertools.chain(
self._doc.get('hosts', []),
@ -99,12 +104,12 @@ class Hello(object):
self._doc.get('arbiters', []))))
@property
def tags(self):
def tags(self) -> Mapping[str, Any]:
"""Replica set member tags or empty dict."""
return self._doc.get('tags', {})
@property
def primary(self):
def primary(self) -> Optional[Tuple[str, int]]:
"""This server's opinion about who the primary is, or None."""
if self._doc.get('primary'):
return common.partition_node(self._doc['primary'])
@ -112,70 +117,71 @@ class Hello(object):
return None
@property
def replica_set_name(self):
def replica_set_name(self) -> Optional[str]:
"""Replica set name or None."""
return self._doc.get('setName')
@property
def max_bson_size(self):
def max_bson_size(self) -> int:
return self._doc.get('maxBsonObjectSize', common.MAX_BSON_SIZE)
@property
def max_message_size(self):
def max_message_size(self) -> int:
return self._doc.get('maxMessageSizeBytes', 2 * self.max_bson_size)
@property
def max_write_batch_size(self):
def max_write_batch_size(self) -> int:
return self._doc.get('maxWriteBatchSize', common.MAX_WRITE_BATCH_SIZE)
@property
def min_wire_version(self):
def min_wire_version(self) -> int:
return self._doc.get('minWireVersion', common.MIN_WIRE_VERSION)
@property
def max_wire_version(self):
def max_wire_version(self) -> int:
return self._doc.get('maxWireVersion', common.MAX_WIRE_VERSION)
@property
def set_version(self):
def set_version(self) -> Optional[int]:
return self._doc.get('setVersion')
@property
def election_id(self):
def election_id(self) -> Optional[ObjectId]:
return self._doc.get('electionId')
@property
def cluster_time(self):
def cluster_time(self) -> Optional[Mapping[str, Any]]:
return self._doc.get('$clusterTime')
@property
def logical_session_timeout_minutes(self):
def logical_session_timeout_minutes(self) -> Optional[int]:
return self._doc.get('logicalSessionTimeoutMinutes')
@property
def is_writable(self):
def is_writable(self) -> bool:
return self._is_writable
@property
def is_readable(self):
def is_readable(self) -> bool:
return self._is_readable
@property
def me(self):
def me(self) -> Optional[Tuple[str, int]]:
me = self._doc.get('me')
if me:
return common.clean_node(me)
return None
@property
def last_write_date(self):
def last_write_date(self) -> Optional[datetime.datetime]:
return self._doc.get('lastWrite', {}).get('lastWriteDate')
@property
def compressors(self):
def compressors(self) -> Optional[List[str]]:
return self._doc.get('compression')
@property
def sasl_supported_mechs(self):
def sasl_supported_mechs(self) -> List[str]:
"""Supported authentication mechanisms for the current user.
For example::
@ -187,22 +193,22 @@ class Hello(object):
return self._doc.get('saslSupportedMechs', [])
@property
def speculative_authenticate(self):
def speculative_authenticate(self) -> Optional[Mapping[str, Any]]:
"""The speculativeAuthenticate field."""
return self._doc.get('speculativeAuthenticate')
@property
def topology_version(self):
def topology_version(self) -> Optional[Mapping[str, Any]]:
return self._doc.get('topologyVersion')
@property
def awaitable(self):
def awaitable(self) -> bool:
return self._awaitable
@property
def service_id(self):
def service_id(self) -> Optional[ObjectId]:
return self._doc.get('serviceId')
@property
def hello_ok(self):
def hello_ok(self) -> bool:
return self._doc.get('helloOk', False)

View File

@ -16,18 +16,14 @@
import sys
import traceback
from collections import abc
from typing import Any
from bson.son import SON
from pymongo import ASCENDING
from pymongo.errors import (CursorNotFound,
DuplicateKeyError,
ExecutionTimeout,
NotPrimaryError,
OperationFailure,
WriteError,
WriteConcernError,
from pymongo.errors import (CursorNotFound, DuplicateKeyError,
ExecutionTimeout, NotPrimaryError,
OperationFailure, WriteConcernError, WriteError,
WTimeoutError)
from pymongo.hello import HelloCompat
@ -95,7 +91,7 @@ def _index_document(index_list):
if not len(index_list):
raise ValueError("key_or_list must not be the empty list")
index = SON()
index: SON[str, Any] = SON()
for (key, value) in index_list:
if not isinstance(key, str):
raise TypeError(

View File

@ -23,8 +23,8 @@ MongoDB.
import datetime
import random
import struct
from io import BytesIO as _BytesIO
from typing import Any
import bson
from bson import (CodecOptions,
@ -32,30 +32,24 @@ from bson import (CodecOptions,
_decode_selective,
_dict_to_bson,
_make_c_string)
from bson import codec_options
from bson.int64 import Int64
from bson.raw_bson import (_inflate_bson, DEFAULT_RAW_BSON_OPTIONS,
RawBSONDocument)
from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument,
_inflate_bson)
from bson.son import SON
try:
from pymongo import _cmessage
from pymongo import _cmessage # type: ignore[attr-defined]
_use_c = True
except ImportError:
_use_c = False
from pymongo.errors import (ConfigurationError,
CursorNotFound,
DocumentTooLarge,
ExecutionTimeout,
InvalidOperation,
NotPrimaryError,
OperationFailure,
ProtocolError)
from pymongo.errors import (ConfigurationError, CursorNotFound,
DocumentTooLarge, ExecutionTimeout,
InvalidOperation, NotPrimaryError,
OperationFailure, ProtocolError)
from pymongo.hello import HelloCompat
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
MAX_INT32 = 2147483647
MIN_INT32 = -2147483648
@ -457,6 +451,7 @@ class _RawBatchGetMore(_GetMore):
class _CursorAddress(tuple):
"""The server address (host, port) of a cursor, with namespace property."""
__namespace: Any
def __new__(cls, address, namespace):
self = tuple.__new__(cls, address)
@ -762,6 +757,7 @@ class _BulkWriteContext(object):
"""A proxy for SocketInfo.unack_write that handles event publishing.
"""
if self.publish:
assert self.start_time is not None
duration = datetime.datetime.now() - self.start_time
cmd = self._start(cmd, request_id, docs)
start = datetime.datetime.now()
@ -777,6 +773,7 @@ class _BulkWriteContext(object):
self._succeed(request_id, reply, duration)
except Exception as exc:
if self.publish:
assert self.start_time is not None
duration = (datetime.datetime.now() - start) + duration
if isinstance(exc, OperationFailure):
failure = _convert_write_result(
@ -795,6 +792,7 @@ class _BulkWriteContext(object):
"""A proxy for SocketInfo.write_command that handles event publishing.
"""
if self.publish:
assert self.start_time is not None
duration = datetime.datetime.now() - self.start_time
self._start(cmd, request_id, docs)
start = datetime.datetime.now()
@ -1171,7 +1169,8 @@ class _OpReply(object):
if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR):
raise NotPrimaryError(error_object["$err"], error_object)
elif error_object.get("code") == 50:
raise ExecutionTimeout(error_object.get("$err"),
default_msg = "operation exceeded time limit"
raise ExecutionTimeout(error_object.get("$err", default_msg),
error_object.get("code"),
error_object)
raise OperationFailure("database error: %s" %

View File

@ -34,46 +34,41 @@ access:
import contextlib
import threading
import weakref
from collections import defaultdict
from typing import (TYPE_CHECKING, Any, Dict, FrozenSet, Generic, List,
Mapping, Optional, Sequence, Set, Tuple, Type, Union, cast)
from bson.codec_options import DEFAULT_CODEC_OPTIONS
import bson
from bson.codec_options import (DEFAULT_CODEC_OPTIONS, CodecOptions,
TypeRegistry)
from bson.son import SON
from pymongo import (common,
database,
helpers,
message,
periodic_executor,
uri_parser,
client_session)
from pymongo.change_stream import ClusterChangeStream
from bson.timestamp import Timestamp
from pymongo import (client_session, common, database, helpers, message,
periodic_executor, uri_parser)
from pymongo.change_stream import ChangeStream, ClusterChangeStream
from pymongo.client_options import ClientOptions
from pymongo.command_cursor import CommandCursor
from pymongo.errors import (AutoReconnect,
BulkWriteError,
ConfigurationError,
ConnectionFailure,
InvalidOperation,
NotPrimaryError,
OperationFailure,
PyMongoError,
from pymongo.errors import (AutoReconnect, BulkWriteError, ConfigurationError,
ConnectionFailure, InvalidOperation,
NotPrimaryError, OperationFailure, PyMongoError,
ServerSelectionTimeoutError)
from pymongo.pool import ConnectionClosedReason
from pymongo.read_preferences import ReadPreference
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.server_selectors import writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.topology import (Topology,
_ErrorContext)
from pymongo.topology_description import TOPOLOGY_TYPE
from pymongo.settings import TopologySettings
from pymongo.uri_parser import (_handle_option_deprecations,
_handle_security_options,
_normalize_options,
_check_options)
from pymongo.write_concern import DEFAULT_WRITE_CONCERN
from pymongo.topology import Topology, _ErrorContext
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
from pymongo.uri_parser import (_check_options, _handle_option_deprecations,
_handle_security_options, _normalize_options)
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
if TYPE_CHECKING:
from pymongo.read_concern import ReadConcern
class MongoClient(common.BaseObject):
class MongoClient(common.BaseObject, Generic[_DocumentType]):
"""
A client-side representation of a MongoDB cluster.
@ -89,15 +84,15 @@ class MongoClient(common.BaseObject):
# No host/port; these are retrieved from TopologySettings.
_constructor_args = ('document_class', 'tz_aware', 'connect')
def __init__(
self,
host=None,
port=None,
document_class=dict,
tz_aware=None,
connect=None,
type_registry=None,
**kwargs):
def __init__(self,
host: Optional[Union[str, Sequence[str]]] = None,
port: Optional[int] = None,
document_class: Type[_DocumentType] = dict,
tz_aware: Optional[bool] = None,
connect: Optional[bool] = None,
type_registry: Optional[TypeRegistry] = None,
**kwargs: Any,
) -> None:
"""Client for a MongoDB instance, a replica set, or a set of mongoses.
The client object is thread-safe and has connection-pooling built in.
@ -621,7 +616,7 @@ class MongoClient(common.BaseObject):
client.__my_database__
"""
self.__init_kwargs = {'host': host,
self.__init_kwargs: Dict[str, Any] = {'host': host,
'port': port,
'document_class': document_class,
'tz_aware': tz_aware,
@ -722,7 +717,7 @@ class MongoClient(common.BaseObject):
self.__default_database_name = dbase
self.__lock = threading.Lock()
self.__kill_cursors_queue = []
self.__kill_cursors_queue: List = []
self._event_listeners = options.pool_options._event_listeners
super(MongoClient, self).__init__(options.codec_options,
@ -765,7 +760,7 @@ class MongoClient(common.BaseObject):
# We strongly reference the executor and it weakly references us via
# this closure. When the client is freed, stop the executor soon.
self_ref = weakref.ref(self, executor.close)
self_ref: Any = weakref.ref(self, executor.close)
self._kill_cursors_executor = executor
if connect:
@ -798,9 +793,17 @@ class MongoClient(common.BaseObject):
return getattr(server.description, attr_name)
def watch(self, pipeline=None, full_document=None, resume_after=None,
max_await_time_ms=None, batch_size=None, collation=None,
start_at_operation_time=None, session=None, start_after=None):
def watch(self,
pipeline: Optional[_Pipeline] = None,
full_document: Optional[str] = None,
resume_after: Optional[Mapping[str, Any]] = None,
max_await_time_ms: Optional[int] = None,
batch_size: Optional[int] = None,
collation: Optional[_CollationIn] = None,
start_at_operation_time: Optional[Timestamp] = None,
session: Optional[client_session.ClientSession] = None,
start_after: Optional[Mapping[str, Any]] = None,
) -> ChangeStream[_DocumentType]:
"""Watch changes on this cluster.
Performs an aggregation with an implicit initial ``$changeStream``
@ -891,7 +894,7 @@ class MongoClient(common.BaseObject):
start_after)
@property
def topology_description(self):
def topology_description(self) -> TopologyDescription:
"""The description of the connected MongoDB deployment.
>>> client.topology_description
@ -913,7 +916,7 @@ class MongoClient(common.BaseObject):
return self._topology.description
@property
def address(self):
def address(self) -> Optional[Tuple[str, int]]:
"""(host, port) of the current standalone, primary, or mongos, or None.
Accessing :attr:`address` raises :exc:`~.errors.InvalidOperation` if
@ -940,7 +943,7 @@ class MongoClient(common.BaseObject):
return self._server_property('address')
@property
def primary(self):
def primary(self) -> Optional[Tuple[str, int]]:
"""The (host, port) of the current primary of the replica set.
Returns ``None`` if this client is not connected to a replica set,
@ -953,7 +956,7 @@ class MongoClient(common.BaseObject):
return self._topology.get_primary()
@property
def secondaries(self):
def secondaries(self) -> Set[Tuple[str, int]]:
"""The secondary members known to this client.
A sequence of (host, port) pairs. Empty if this client is not
@ -966,7 +969,7 @@ class MongoClient(common.BaseObject):
return self._topology.get_secondaries()
@property
def arbiters(self):
def arbiters(self) -> Set[Tuple[str, int]]:
"""Arbiters in the replica set.
A sequence of (host, port) pairs. Empty if this client is not
@ -976,7 +979,7 @@ class MongoClient(common.BaseObject):
return self._topology.get_arbiters()
@property
def is_primary(self):
def is_primary(self) -> bool:
"""If this client is connected to a server that can accept writes.
True if the current server is a standalone, mongos, or the primary of
@ -987,7 +990,7 @@ class MongoClient(common.BaseObject):
return self._server_property('is_writable')
@property
def is_mongos(self):
def is_mongos(self) -> bool:
"""If this client is connected to mongos. If the client is not
connected, this will block until a connection is established or raise
ServerSelectionTimeoutError if no server is available..
@ -995,7 +998,7 @@ class MongoClient(common.BaseObject):
return self._server_property('server_type') == SERVER_TYPE.Mongos
@property
def nodes(self):
def nodes(self) -> FrozenSet[Tuple[str, Optional[int]]]:
"""Set of all currently connected servers.
.. warning:: When connected to a replica set the value of :attr:`nodes`
@ -1009,7 +1012,7 @@ class MongoClient(common.BaseObject):
return frozenset(s.address for s in description.known_servers)
@property
def options(self):
def options(self) -> ClientOptions:
"""The configuration options for this client.
:Returns:
@ -1040,7 +1043,7 @@ class MongoClient(common.BaseObject):
# command.
pass
def close(self):
def close(self) -> None:
"""Cleanup client resources and disconnect from MongoDB.
End all server sessions created by this client by sending one or more
@ -1214,7 +1217,7 @@ class MongoClient(common.BaseObject):
def _retry_internal(self, retryable, func, session, bulk):
"""Internal retryable write helper."""
max_wire_version = 0
last_error = None
last_error: Optional[Exception] = None
retrying = False
def is_retrying():
@ -1239,6 +1242,7 @@ class MongoClient(common.BaseObject):
if is_retrying():
# A retry is not possible because this server does
# not support sessions raise the last error.
assert last_error is not None
raise last_error
retryable = False
return func(session, sock_info, retryable)
@ -1247,6 +1251,7 @@ class MongoClient(common.BaseObject):
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
# attempt. Raise the original exception instead.
assert last_error is not None
raise last_error
# A ServerSelectionTimeoutError error indicates that there may
# be a persistent outage. Attempting to retry in this case will
@ -1280,7 +1285,7 @@ class MongoClient(common.BaseObject):
retryable = (retryable and
self.options.retry_reads
and not (session and session.in_transaction))
last_error = None
last_error: Optional[Exception] = None
retrying = False
while True:
@ -1292,6 +1297,7 @@ class MongoClient(common.BaseObject):
if retrying and not retryable:
# A retry is not possible because this server does
# not support retryable reads, raise the last error.
assert last_error is not None
raise last_error
return func(session, server, sock_info, read_pref)
except ServerSelectionTimeoutError:
@ -1299,6 +1305,7 @@ class MongoClient(common.BaseObject):
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
# attempt. Raise the original exception instead.
assert last_error is not None
raise last_error
# A ServerSelectionTimeoutError error indicates that there may
# be a persistent outage. Attempting to retry in this case will
@ -1322,15 +1329,15 @@ class MongoClient(common.BaseObject):
with self._tmp_session(session) as s:
return self._retry_with_session(retryable, func, s, None)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
return self._topology == other._topology
return NotImplemented
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other
def __hash__(self):
def __hash__(self) -> int:
return hash(self._topology)
def _repr_helper(self):
@ -1366,7 +1373,7 @@ class MongoClient(common.BaseObject):
def __repr__(self):
return ("MongoClient(%s)" % (self._repr_helper(),))
def __getattr__(self, name):
def __getattr__(self, name: str) -> database.Database[_DocumentType]:
"""Get a database by name.
Raises :class:`~pymongo.errors.InvalidName` if an invalid
@ -1381,7 +1388,7 @@ class MongoClient(common.BaseObject):
" database, use client[%r]." % (name, name, name))
return self.__getitem__(name)
def __getitem__(self, name):
def __getitem__(self, name: str) -> database.Database[_DocumentType]:
"""Get a database by name.
Raises :class:`~pymongo.errors.InvalidName` if an invalid
@ -1539,9 +1546,10 @@ class MongoClient(common.BaseObject):
self, server_session, opts, implicit)
def start_session(self,
causal_consistency=None,
default_transaction_options=None,
snapshot=False):
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[client_session.TransactionOptions] = None,
snapshot: Optional[bool] = False,
) -> client_session.ClientSession[_DocumentType]:
"""Start a logical session.
This method takes the same parameters as
@ -1630,7 +1638,9 @@ class MongoClient(common.BaseObject):
if session is not None:
session._process_response(reply)
def server_info(self, session=None):
def server_info(self,
session: Optional[client_session.ClientSession] = None
) -> Dict[str, Any]:
"""Get information about the MongoDB server we're connected to.
:Parameters:
@ -1644,7 +1654,10 @@ class MongoClient(common.BaseObject):
read_preference=ReadPreference.PRIMARY,
session=session)
def list_databases(self, session=None, **kwargs):
def list_databases(self,
session: Optional[client_session.ClientSession] = None,
**kwargs: Any
) -> CommandCursor[Dict[str, Any]]:
"""Get a cursor over the databases of the connected server.
:Parameters:
@ -1673,7 +1686,9 @@ class MongoClient(common.BaseObject):
}
return CommandCursor(admin["$cmd"], cursor, None)
def list_database_names(self, session=None):
def list_database_names(self,
session: Optional[client_session.ClientSession] = None
) -> List[str]:
"""Get a list of the names of all databases on the connected server.
:Parameters:
@ -1685,7 +1700,10 @@ class MongoClient(common.BaseObject):
return [doc["name"]
for doc in self.list_databases(session, nameOnly=True)]
def drop_database(self, name_or_database, session=None):
def drop_database(self,
name_or_database: Union[str, database.Database],
session: Optional[client_session.ClientSession] = None
) -> None:
"""Drop a database.
Raises :class:`TypeError` if `name_or_database` is not an instance of
@ -1727,8 +1745,13 @@ class MongoClient(common.BaseObject):
parse_write_concern_error=True,
session=session)
def get_default_database(self, default=None, codec_options=None,
read_preference=None, write_concern=None, read_concern=None):
def get_default_database(self,
default: Optional[str] = None,
codec_options: Optional[CodecOptions] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional["ReadConcern"] = None,
) -> database.Database[_DocumentType]:
"""Get the database named in the MongoDB connection URI.
>>> uri = 'mongodb://host/my_database'
@ -1773,12 +1796,18 @@ class MongoClient(common.BaseObject):
raise ConfigurationError(
'No default database name defined or provided.')
name = cast(str, self.__default_database_name or default)
return database.Database(
self, self.__default_database_name or default, codec_options,
self, name, codec_options,
read_preference, write_concern, read_concern)
def get_database(self, name=None, codec_options=None, read_preference=None,
write_concern=None, read_concern=None):
def get_database(self,
name: Optional[str] = None,
codec_options: Optional[CodecOptions] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional["ReadConcern"] = None,
) -> database.Database[_DocumentType]:
"""Get a :class:`~pymongo.database.Database` with the given name and
options.
@ -1838,16 +1867,16 @@ class MongoClient(common.BaseObject):
read_preference=ReadPreference.PRIMARY,
write_concern=DEFAULT_WRITE_CONCERN)
def __enter__(self):
def __enter__(self) -> "MongoClient[_DocumentType]":
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
def __iter__(self):
def __iter__(self) -> "MongoClient[_DocumentType]":
return self
def __next__(self):
def __next__(self) -> None:
raise TypeError("'MongoClient' object is not iterable")
next = __next__

View File

@ -18,10 +18,10 @@ import atexit
import threading
import time
import weakref
from typing import Any, Mapping, cast
from pymongo import common, periodic_executor
from pymongo.errors import (NotPrimaryError,
OperationFailure,
from pymongo.errors import (NotPrimaryError, OperationFailure,
_OperationCancelled)
from pymongo.hello import Hello
from pymongo.periodic_executor import _shutdown_executors
@ -50,7 +50,7 @@ class MonitorBase(object):
monitor = self_ref()
if monitor is None:
return False # Stop the executor.
monitor._run()
monitor._run() # type:ignore[attr-defined]
return True
executor = periodic_executor.PeriodicExecutor(
@ -214,8 +214,8 @@ class Monitor(MonitorBase):
return self._check_once()
except (OperationFailure, NotPrimaryError) as exc:
# Update max cluster time even when hello fails.
self._topology.receive_cluster_time(
exc.details.get('$clusterTime'))
details = cast(Mapping[str, Any], exc.details)
self._topology.receive_cluster_time(details.get('$clusterTime'))
raise
except ReferenceError:
raise

View File

@ -180,12 +180,21 @@ will not add that listener to existing client instances.
handler first.
"""
import datetime
from collections import abc, namedtuple
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional
from pymongo.hello import HelloCompat
from bson.objectid import ObjectId
from pymongo.hello import Hello, HelloCompat
from pymongo.helpers import _handle_exception
from pymongo.typings import _Address
_Listeners = namedtuple('Listeners',
if TYPE_CHECKING:
from pymongo.server_description import ServerDescription
from pymongo.topology_description import TopologyDescription
_Listeners = namedtuple('_Listeners',
('command_listeners', 'server_listeners',
'server_heartbeat_listeners', 'topology_listeners',
'cmap_listeners'))
@ -193,6 +202,9 @@ _Listeners = namedtuple('Listeners',
_LISTENERS = _Listeners([], [], [], [], [])
_DocumentOut = Mapping[str, Any]
class _EventListener(object):
"""Abstract base class for all event listeners."""
@ -204,7 +216,7 @@ class CommandListener(_EventListener):
and `CommandFailedEvent`.
"""
def started(self, event):
def started(self, event: "CommandStartedEvent") -> None:
"""Abstract method to handle a `CommandStartedEvent`.
:Parameters:
@ -212,7 +224,7 @@ class CommandListener(_EventListener):
"""
raise NotImplementedError
def succeeded(self, event):
def succeeded(self, event: "CommandSucceededEvent") -> None:
"""Abstract method to handle a `CommandSucceededEvent`.
:Parameters:
@ -220,7 +232,7 @@ class CommandListener(_EventListener):
"""
raise NotImplementedError
def failed(self, event):
def failed(self, event: "CommandFailedEvent") -> None:
"""Abstract method to handle a `CommandFailedEvent`.
:Parameters:
@ -245,7 +257,7 @@ class ConnectionPoolListener(_EventListener):
.. versionadded:: 3.9
"""
def pool_created(self, event):
def pool_created(self, event: "PoolCreatedEvent") -> None:
"""Abstract method to handle a :class:`PoolCreatedEvent`.
Emitted when a Connection Pool is created.
@ -255,7 +267,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def pool_ready(self, event):
def pool_ready(self, event: "PoolReadyEvent") -> None:
"""Abstract method to handle a :class:`PoolReadyEvent`.
Emitted when a Connection Pool is marked ready.
@ -267,7 +279,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def pool_cleared(self, event):
def pool_cleared(self, event: "PoolClearedEvent") -> None:
"""Abstract method to handle a `PoolClearedEvent`.
Emitted when a Connection Pool is cleared.
@ -277,7 +289,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def pool_closed(self, event):
def pool_closed(self, event: "PoolClosedEvent") -> None:
"""Abstract method to handle a `PoolClosedEvent`.
Emitted when a Connection Pool is closed.
@ -287,7 +299,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_created(self, event):
def connection_created(self, event: "ConnectionCreatedEvent") -> None:
"""Abstract method to handle a :class:`ConnectionCreatedEvent`.
Emitted when a Connection Pool creates a Connection object.
@ -297,7 +309,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_ready(self, event):
def connection_ready(self, event: "ConnectionReadyEvent") -> None:
"""Abstract method to handle a :class:`ConnectionReadyEvent`.
Emitted when a Connection has finished its setup, and is now ready to
@ -308,7 +320,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_closed(self, event):
def connection_closed(self, event: "ConnectionClosedEvent") -> None:
"""Abstract method to handle a :class:`ConnectionClosedEvent`.
Emitted when a Connection Pool closes a Connection.
@ -318,7 +330,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_check_out_started(self, event):
def connection_check_out_started(self, event: "ConnectionCheckOutStartedEvent") -> None:
"""Abstract method to handle a :class:`ConnectionCheckOutStartedEvent`.
Emitted when the driver starts attempting to check out a connection.
@ -328,7 +340,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_check_out_failed(self, event):
def connection_check_out_failed(self, event: "ConnectionCheckOutFailedEvent") -> None:
"""Abstract method to handle a :class:`ConnectionCheckOutFailedEvent`.
Emitted when the driver's attempt to check out a connection fails.
@ -338,7 +350,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_checked_out(self, event):
def connection_checked_out(self, event: "ConnectionCheckedOutEvent") -> None:
"""Abstract method to handle a :class:`ConnectionCheckedOutEvent`.
Emitted when the driver successfully checks out a Connection.
@ -348,7 +360,7 @@ class ConnectionPoolListener(_EventListener):
"""
raise NotImplementedError
def connection_checked_in(self, event):
def connection_checked_in(self, event: "ConnectionCheckedInEvent") -> None:
"""Abstract method to handle a :class:`ConnectionCheckedInEvent`.
Emitted when the driver checks in a Connection back to the Connection
@ -369,7 +381,7 @@ class ServerHeartbeatListener(_EventListener):
.. versionadded:: 3.3
"""
def started(self, event):
def started(self, event: "ServerHeartbeatStartedEvent") -> None:
"""Abstract method to handle a `ServerHeartbeatStartedEvent`.
:Parameters:
@ -377,7 +389,7 @@ class ServerHeartbeatListener(_EventListener):
"""
raise NotImplementedError
def succeeded(self, event):
def succeeded(self, event: "ServerHeartbeatSucceededEvent") -> None:
"""Abstract method to handle a `ServerHeartbeatSucceededEvent`.
:Parameters:
@ -385,7 +397,7 @@ class ServerHeartbeatListener(_EventListener):
"""
raise NotImplementedError
def failed(self, event):
def failed(self, event: "ServerHeartbeatFailedEvent") -> None:
"""Abstract method to handle a `ServerHeartbeatFailedEvent`.
:Parameters:
@ -402,7 +414,7 @@ class TopologyListener(_EventListener):
.. versionadded:: 3.3
"""
def opened(self, event):
def opened(self, event: "TopologyOpenedEvent") -> None:
"""Abstract method to handle a `TopologyOpenedEvent`.
:Parameters:
@ -410,7 +422,7 @@ class TopologyListener(_EventListener):
"""
raise NotImplementedError
def description_changed(self, event):
def description_changed(self, event: "TopologyDescriptionChangedEvent") -> None:
"""Abstract method to handle a `TopologyDescriptionChangedEvent`.
:Parameters:
@ -418,7 +430,7 @@ class TopologyListener(_EventListener):
"""
raise NotImplementedError
def closed(self, event):
def closed(self, event: "TopologyClosedEvent") -> None:
"""Abstract method to handle a `TopologyClosedEvent`.
:Parameters:
@ -435,7 +447,7 @@ class ServerListener(_EventListener):
.. versionadded:: 3.3
"""
def opened(self, event):
def opened(self, event: "ServerOpeningEvent") -> None:
"""Abstract method to handle a `ServerOpeningEvent`.
:Parameters:
@ -443,7 +455,7 @@ class ServerListener(_EventListener):
"""
raise NotImplementedError
def description_changed(self, event):
def description_changed(self, event: "ServerDescriptionChangedEvent") -> None:
"""Abstract method to handle a `ServerDescriptionChangedEvent`.
:Parameters:
@ -451,7 +463,7 @@ class ServerListener(_EventListener):
"""
raise NotImplementedError
def closed(self, event):
def closed(self, event: "ServerClosedEvent") -> None:
"""Abstract method to handle a `ServerClosedEvent`.
:Parameters:
@ -478,7 +490,7 @@ def _validate_event_listeners(option, listeners):
return listeners
def register(listener):
def register(listener: _EventListener) -> None:
"""Register a global event listener.
:Parameters:
@ -525,8 +537,14 @@ class _CommandEvent(object):
__slots__ = ("__cmd_name", "__rqst_id", "__conn_id", "__op_id",
"__service_id")
def __init__(self, command_name, request_id, connection_id, operation_id,
service_id=None):
def __init__(
self,
command_name: str,
request_id: int,
connection_id: _Address,
operation_id: Optional[int],
service_id: Optional[ObjectId] = None,
) -> None:
self.__cmd_name = command_name
self.__rqst_id = request_id
self.__conn_id = connection_id
@ -534,22 +552,22 @@ class _CommandEvent(object):
self.__service_id = service_id
@property
def command_name(self):
def command_name(self) -> str:
"""The command name."""
return self.__cmd_name
@property
def request_id(self):
def request_id(self) -> int:
"""The request id for this operation."""
return self.__rqst_id
@property
def connection_id(self):
def connection_id(self) -> _Address:
"""The address (host, port) of the server this command was sent to."""
return self.__conn_id
@property
def service_id(self):
def service_id(self) -> Optional[ObjectId]:
"""The service_id this command was sent to, or ``None``.
.. versionadded:: 3.12
@ -557,7 +575,7 @@ class _CommandEvent(object):
return self.__service_id
@property
def operation_id(self):
def operation_id(self) -> Optional[int]:
"""An id for this series of events or None."""
return self.__op_id
@ -576,28 +594,36 @@ class CommandStartedEvent(_CommandEvent):
"""
__slots__ = ("__cmd", "__db")
def __init__(self, command, database_name, *args, service_id=None):
def __init__(
self,
command: _DocumentOut,
database_name: str,
request_id: int,
connection_id: _Address,
operation_id: Optional[int],
service_id: Optional[ObjectId] = None,
) -> None:
if not command:
raise ValueError("%r is not a valid command" % (command,))
# Command name must be first key.
command_name = next(iter(command))
super(CommandStartedEvent, self).__init__(
command_name, *args, service_id=service_id)
command_name, request_id, connection_id, operation_id, service_id=service_id)
cmd_name, cmd_doc = command_name.lower(), command[command_name]
if (cmd_name in _SENSITIVE_COMMANDS or
_is_speculative_authenticate(cmd_name, command)):
self.__cmd = {}
self.__cmd: Mapping[str, Any] = {}
else:
self.__cmd = command
self.__db = database_name
@property
def command(self):
def command(self) -> _DocumentOut:
"""The command document."""
return self.__cmd
@property
def database_name(self):
def database_name(self) -> str:
"""The name of the database this command was run against."""
return self.__db
@ -625,8 +651,16 @@ class CommandSucceededEvent(_CommandEvent):
"""
__slots__ = ("__duration_micros", "__reply")
def __init__(self, duration, reply, command_name,
request_id, connection_id, operation_id, service_id=None):
def __init__(
self,
duration: datetime.timedelta,
reply: _DocumentOut,
command_name: str,
request_id: int,
connection_id: _Address,
operation_id: Optional[int],
service_id: Optional[ObjectId] = None,
) -> None:
super(CommandSucceededEvent, self).__init__(
command_name, request_id, connection_id, operation_id,
service_id=service_id)
@ -634,17 +668,17 @@ class CommandSucceededEvent(_CommandEvent):
cmd_name = command_name.lower()
if (cmd_name in _SENSITIVE_COMMANDS or
_is_speculative_authenticate(cmd_name, reply)):
self.__reply = {}
self.__reply: Mapping[str, Any] = {}
else:
self.__reply = reply
@property
def duration_micros(self):
def duration_micros(self) -> int:
"""The duration of this operation in microseconds."""
return self.__duration_micros
@property
def reply(self):
def reply(self) -> _DocumentOut:
"""The server failure document for this operation."""
return self.__reply
@ -672,18 +706,27 @@ class CommandFailedEvent(_CommandEvent):
"""
__slots__ = ("__duration_micros", "__failure")
def __init__(self, duration, failure, *args, service_id=None):
super(CommandFailedEvent, self).__init__(*args, service_id=service_id)
def __init__(
self,
duration: datetime.timedelta,
failure: _DocumentOut,
command_name: str,
request_id: int,
connection_id: _Address,
operation_id: Optional[int],
service_id: Optional[ObjectId] = None,
) -> None:
super(CommandFailedEvent, self).__init__(command_name, request_id, connection_id, operation_id, service_id=service_id)
self.__duration_micros = _to_micros(duration)
self.__failure = failure
@property
def duration_micros(self):
def duration_micros(self) -> int:
"""The duration of this operation in microseconds."""
return self.__duration_micros
@property
def failure(self):
def failure(self) -> _DocumentOut:
"""The server failure document for this operation."""
return self.__failure
@ -700,11 +743,11 @@ class _PoolEvent(object):
"""Base class for pool events."""
__slots__ = ("__address",)
def __init__(self, address):
def __init__(self, address: _Address) -> None:
self.__address = address
@property
def address(self):
def address(self) -> _Address:
"""The address (host, port) pair of the server the pool is attempting
to connect to.
"""
@ -725,12 +768,12 @@ class PoolCreatedEvent(_PoolEvent):
"""
__slots__ = ("__options",)
def __init__(self, address, options):
def __init__(self, address: _Address, options: Dict[str, Any]) -> None:
super(PoolCreatedEvent, self).__init__(address)
self.__options = options
@property
def options(self):
def options(self) -> Dict[str, Any]:
"""Any non-default pool options that were set on this Connection Pool.
"""
return self.__options
@ -764,12 +807,12 @@ class PoolClearedEvent(_PoolEvent):
"""
__slots__ = ("__service_id",)
def __init__(self, address, service_id=None):
def __init__(self, address: _Address, service_id: Optional[ObjectId] = None) -> None:
super(PoolClearedEvent, self).__init__(address)
self.__service_id = service_id
@property
def service_id(self):
def service_id(self) -> Optional[ObjectId]:
"""Connections with this service_id are cleared.
When service_id is ``None``, all connections in the pool are cleared.
@ -839,19 +882,19 @@ class _ConnectionEvent(object):
"""Private base class for some connection events."""
__slots__ = ("__address", "__connection_id")
def __init__(self, address, connection_id):
def __init__(self, address: _Address, connection_id: int) -> None:
self.__address = address
self.__connection_id = connection_id
@property
def address(self):
def address(self) -> _Address:
"""The address (host, port) pair of the server this connection is
attempting to connect to.
"""
return self.__address
@property
def connection_id(self):
def connection_id(self) -> int:
"""The ID of the Connection."""
return self.__connection_id
@ -958,19 +1001,19 @@ class ConnectionCheckOutFailedEvent(object):
"""
__slots__ = ("__address", "__reason")
def __init__(self, address, reason):
def __init__(self, address: _Address, reason: str) -> None:
self.__address = address
self.__reason = reason
@property
def address(self):
def address(self) -> _Address:
"""The address (host, port) pair of the server this connection is
attempting to connect to.
"""
return self.__address
@property
def reason(self):
def reason(self) -> str:
"""A reason explaining why connection check out failed.
The reason must be one of the strings from the
@ -1014,17 +1057,17 @@ class _ServerEvent(object):
__slots__ = ("__server_address", "__topology_id")
def __init__(self, server_address, topology_id):
def __init__(self, server_address: _Address, topology_id: ObjectId) -> None:
self.__server_address = server_address
self.__topology_id = topology_id
@property
def server_address(self):
def server_address(self) -> _Address:
"""The address (host, port) pair of the server"""
return self.__server_address
@property
def topology_id(self):
def topology_id(self) -> ObjectId:
"""A unique identifier for the topology this server is a part of."""
return self.__topology_id
@ -1041,19 +1084,19 @@ class ServerDescriptionChangedEvent(_ServerEvent):
__slots__ = ('__previous_description', '__new_description')
def __init__(self, previous_description, new_description, *args):
def __init__(self, previous_description: "ServerDescription", new_description: "ServerDescription", *args: Any) -> None:
super(ServerDescriptionChangedEvent, self).__init__(*args)
self.__previous_description = previous_description
self.__new_description = new_description
@property
def previous_description(self):
def previous_description(self) -> "ServerDescription":
"""The previous
:class:`~pymongo.server_description.ServerDescription`."""
return self.__previous_description
@property
def new_description(self):
def new_description(self) -> "ServerDescription":
"""The new
:class:`~pymongo.server_description.ServerDescription`."""
return self.__new_description
@ -1087,11 +1130,11 @@ class TopologyEvent(object):
__slots__ = ('__topology_id')
def __init__(self, topology_id):
def __init__(self, topology_id: ObjectId) -> None:
self.__topology_id = topology_id
@property
def topology_id(self):
def topology_id(self) -> ObjectId:
"""A unique identifier for the topology this server is a part of."""
return self.__topology_id
@ -1108,19 +1151,19 @@ class TopologyDescriptionChangedEvent(TopologyEvent):
__slots__ = ('__previous_description', '__new_description')
def __init__(self, previous_description, new_description, *args):
def __init__(self, previous_description: "TopologyDescription", new_description: "TopologyDescription", *args: Any) -> None:
super(TopologyDescriptionChangedEvent, self).__init__(*args)
self.__previous_description = previous_description
self.__new_description = new_description
@property
def previous_description(self):
def previous_description(self) -> "TopologyDescription":
"""The previous
:class:`~pymongo.topology_description.TopologyDescription`."""
return self.__previous_description
@property
def new_description(self):
def new_description(self) -> "TopologyDescription":
"""The new
:class:`~pymongo.topology_description.TopologyDescription`."""
return self.__new_description
@ -1154,11 +1197,11 @@ class _ServerHeartbeatEvent(object):
__slots__ = ('__connection_id')
def __init__(self, connection_id):
def __init__(self, connection_id: _Address) -> None:
self.__connection_id = connection_id
@property
def connection_id(self):
def connection_id(self) -> _Address:
"""The address (host, port) of the server this heartbeat was sent
to."""
return self.__connection_id
@ -1184,24 +1227,24 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent):
__slots__ = ('__duration', '__reply', '__awaited')
def __init__(self, duration, reply, connection_id, awaited=False):
def __init__(self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False) -> None:
super(ServerHeartbeatSucceededEvent, self).__init__(connection_id)
self.__duration = duration
self.__reply = reply
self.__awaited = awaited
@property
def duration(self):
def duration(self) -> float:
"""The duration of this heartbeat in microseconds."""
return self.__duration
@property
def reply(self):
def reply(self) -> Hello:
"""An instance of :class:`~pymongo.hello.Hello`."""
return self.__reply
@property
def awaited(self):
def awaited(self) -> bool:
"""Whether the heartbeat was awaited.
If true, then :meth:`duration` reflects the sum of the round trip time
@ -1225,24 +1268,24 @@ class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent):
__slots__ = ('__duration', '__reply', '__awaited')
def __init__(self, duration, reply, connection_id, awaited=False):
def __init__(self, duration: float, reply: Exception, connection_id: _Address, awaited: bool = False) -> None:
super(ServerHeartbeatFailedEvent, self).__init__(connection_id)
self.__duration = duration
self.__reply = reply
self.__awaited = awaited
@property
def duration(self):
def duration(self) -> float:
"""The duration of this heartbeat in microseconds."""
return self.__duration
@property
def reply(self):
def reply(self) -> Exception:
"""A subclass of :exc:`Exception`."""
return self.__reply
@property
def awaited(self):
def awaited(self) -> bool:
"""Whether the heartbeat was awaited.
If true, then :meth:`duration` reflects the sum of the round trip time

View File

@ -20,21 +20,16 @@ import socket
import struct
import time
from bson import _decode_all_selective
from pymongo import helpers, message
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.compression_support import decompress, _NO_COMPRESSION
from pymongo.errors import (NotPrimaryError,
OperationFailure,
ProtocolError,
from pymongo.compression_support import _NO_COMPRESSION, decompress
from pymongo.errors import (NotPrimaryError, OperationFailure, ProtocolError,
_OperationCancelled)
from pymongo.message import _UNPACK_REPLY, _OpMsg
from pymongo.monitoring import _is_speculative_authenticate
from pymongo.socket_checker import _errno_from_exception
_UNPACK_HEADER = struct.Struct("<iiii").unpack

View File

@ -21,7 +21,7 @@ from threading import Lock
class _OCSPCache(object):
"""A cache for OCSP responses."""
CACHE_KEY_TYPE = namedtuple('OcspResponseCacheKey',
CACHE_KEY_TYPE = namedtuple('OcspResponseCacheKey', # type: ignore
['hash_algorithm', 'issuer_name_hash',
'issuer_key_hash', 'serial_number'])

View File

@ -16,41 +16,40 @@
import logging as _logging
import re as _re
from datetime import datetime as _datetime
from cryptography.exceptions import InvalidSignature as _InvalidSignature
from cryptography.hazmat.backends import default_backend as _default_backend
from cryptography.hazmat.primitives.asymmetric.dsa import (
DSAPublicKey as _DSAPublicKey)
from cryptography.hazmat.primitives.asymmetric.ec import (
ECDSA as _ECDSA,
EllipticCurvePublicKey as _EllipticCurvePublicKey)
from cryptography.hazmat.primitives.asymmetric.padding import (
PKCS1v15 as _PKCS1v15)
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPublicKey as _RSAPublicKey)
from cryptography.hazmat.primitives.hashes import (
Hash as _Hash,
SHA1 as _SHA1)
from cryptography.hazmat.primitives.serialization import (
Encoding as _Encoding,
PublicFormat as _PublicFormat)
from cryptography.x509 import (
AuthorityInformationAccess as _AuthorityInformationAccess,
ExtendedKeyUsage as _ExtendedKeyUsage,
ExtensionNotFound as _ExtensionNotFound,
load_pem_x509_certificate as _load_pem_x509_certificate,
TLSFeature as _TLSFeature,
TLSFeatureType as _TLSFeatureType)
from cryptography.x509.oid import (
AuthorityInformationAccessOID as _AuthorityInformationAccessOID,
ExtendedKeyUsageOID as _ExtendedKeyUsageOID)
from cryptography.x509.ocsp import (
load_der_ocsp_response as _load_der_ocsp_response,
OCSPCertStatus as _OCSPCertStatus,
OCSPRequestBuilder as _OCSPRequestBuilder,
OCSPResponseStatus as _OCSPResponseStatus)
from cryptography.hazmat.primitives.asymmetric.dsa import \
DSAPublicKey as _DSAPublicKey
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA as _ECDSA
from cryptography.hazmat.primitives.asymmetric.ec import \
EllipticCurvePublicKey as _EllipticCurvePublicKey
from cryptography.hazmat.primitives.asymmetric.padding import \
PKCS1v15 as _PKCS1v15
from cryptography.hazmat.primitives.asymmetric.rsa import \
RSAPublicKey as _RSAPublicKey
from cryptography.hazmat.primitives.hashes import SHA1 as _SHA1
from cryptography.hazmat.primitives.hashes import Hash as _Hash
from cryptography.hazmat.primitives.serialization import Encoding as _Encoding
from cryptography.hazmat.primitives.serialization import \
PublicFormat as _PublicFormat
from cryptography.x509 import \
AuthorityInformationAccess as _AuthorityInformationAccess
from cryptography.x509 import ExtendedKeyUsage as _ExtendedKeyUsage
from cryptography.x509 import ExtensionNotFound as _ExtensionNotFound
from cryptography.x509 import TLSFeature as _TLSFeature
from cryptography.x509 import TLSFeatureType as _TLSFeatureType
from cryptography.x509 import \
load_pem_x509_certificate as _load_pem_x509_certificate
from cryptography.x509.ocsp import OCSPCertStatus as _OCSPCertStatus
from cryptography.x509.ocsp import OCSPRequestBuilder as _OCSPRequestBuilder
from cryptography.x509.ocsp import OCSPResponseStatus as _OCSPResponseStatus
from cryptography.x509.ocsp import \
load_der_ocsp_response as _load_der_ocsp_response
from cryptography.x509.oid import \
AuthorityInformationAccessOID as _AuthorityInformationAccessOID
from cryptography.x509.oid import ExtendedKeyUsageOID as _ExtendedKeyUsageOID
from requests import post as _post
from requests.exceptions import RequestException as _RequestException

View File

@ -13,11 +13,13 @@
# limitations under the License.
"""Operation class definitions."""
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from pymongo import helpers
from pymongo.common import validate_boolean, validate_is_mapping, validate_list
from pymongo.collation import validate_collation_or_none
from pymongo.common import validate_boolean, validate_is_mapping, validate_list
from pymongo.helpers import _gen_index_name, _index_document, _index_list
from pymongo.typings import _CollationIn, _DocumentIn, _Pipeline
class InsertOne(object):
@ -25,7 +27,7 @@ class InsertOne(object):
__slots__ = ("_doc",)
def __init__(self, document):
def __init__(self, document: _DocumentIn) -> None:
"""Create an InsertOne instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
@ -43,21 +45,25 @@ class InsertOne(object):
def __repr__(self):
return "InsertOne(%r)" % (self._doc,)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return other._doc == self._doc
return NotImplemented
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other
_IndexList = Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]]
_IndexKeyHint = Union[str, _IndexList]
class DeleteOne(object):
"""Represents a delete_one operation."""
__slots__ = ("_filter", "_collation", "_hint")
def __init__(self, filter, collation=None, hint=None):
def __init__(self, filter: Mapping[str, Any], collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None) -> None:
"""Create a DeleteOne instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
@ -95,13 +101,13 @@ class DeleteOne(object):
def __repr__(self):
return "DeleteOne(%r, %r)" % (self._filter, self._collation)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return ((other._filter, other._collation) ==
(self._filter, self._collation))
return NotImplemented
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other
@ -110,7 +116,7 @@ class DeleteMany(object):
__slots__ = ("_filter", "_collation", "_hint")
def __init__(self, filter, collation=None, hint=None):
def __init__(self, filter: Mapping[str, Any], collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None) -> None:
"""Create a DeleteMany instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
@ -148,13 +154,13 @@ class DeleteMany(object):
def __repr__(self):
return "DeleteMany(%r, %r)" % (self._filter, self._collation)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return ((other._filter, other._collation) ==
(self._filter, self._collation))
return NotImplemented
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other
@ -163,8 +169,8 @@ class ReplaceOne(object):
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint")
def __init__(self, filter, replacement, upsert=False, collation=None,
hint=None):
def __init__(self, filter: Mapping[str, Any], replacement: Mapping[str, Any], upsert: bool = False, collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None) -> None:
"""Create a ReplaceOne instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
@ -207,7 +213,7 @@ class ReplaceOne(object):
bulkobj.add_replace(self._filter, self._doc, self._upsert,
collation=self._collation, hint=self._hint)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if type(other) == type(self):
return (
(other._filter, other._doc, other._upsert, other._collation,
@ -215,7 +221,7 @@ class ReplaceOne(object):
self._collation, other._hint))
return NotImplemented
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other
def __repr__(self):
@ -241,7 +247,6 @@ class _UpdateOp(object):
if not isinstance(hint, str):
hint = helpers._index_document(hint)
self._filter = filter
self._doc = doc
self._upsert = upsert
@ -272,8 +277,8 @@ class UpdateOne(_UpdateOp):
__slots__ = ()
def __init__(self, filter, update, upsert=False, collation=None,
array_filters=None, hint=None):
def __init__(self, filter: Mapping[str, Any], update: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, collation: Optional[_CollationIn] = None,
array_filters: Optional[List[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None) -> None:
"""Represents an update_one operation.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
@ -319,8 +324,8 @@ class UpdateMany(_UpdateOp):
__slots__ = ()
def __init__(self, filter, update, upsert=False, collation=None,
array_filters=None, hint=None):
def __init__(self, filter: Mapping[str, Any], update: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, collation: Optional[_CollationIn] = None,
array_filters: Optional[List[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None) -> None:
"""Create an UpdateMany instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
@ -366,7 +371,7 @@ class IndexModel(object):
__slots__ = ("__document",)
def __init__(self, keys, **kwargs):
def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None:
"""Create an Index instance.
For use with :meth:`~pymongo.collection.Collection.create_indexes`.
@ -437,7 +442,7 @@ class IndexModel(object):
self.__document['collation'] = collation
@property
def document(self):
def document(self) -> Dict[str, Any]:
"""An index document suitable for passing to the createIndexes
command.
"""

View File

@ -17,6 +17,7 @@
import threading
import time
import weakref
from typing import Any, Optional
class PeriodicExecutor(object):
@ -41,7 +42,7 @@ class PeriodicExecutor(object):
self._min_interval = min_interval
self._target = target
self._stopped = False
self._thread = None
self._thread: Optional[threading.Thread] = None
self._name = name
self._skip_sleep = False
@ -52,7 +53,7 @@ class PeriodicExecutor(object):
return '<%s(name=%s) object at 0x%x>' % (
self.__class__.__name__, self._name, id(self))
def open(self):
def open(self) -> None:
"""Start. Multiple calls have no effect.
Not safe to call from multiple threads at once.
@ -64,13 +65,14 @@ class PeriodicExecutor(object):
# join should not block indefinitely because there is no
# other work done outside the while loop in self._run.
try:
assert self._thread is not None
self._thread.join()
except ReferenceError:
# Thread terminated.
pass
self._thread_will_exit = False
self._stopped = False
started = False
started: Any = False
try:
started = self._thread and self._thread.is_alive()
except ReferenceError:
@ -84,7 +86,7 @@ class PeriodicExecutor(object):
_register_executor(self)
thread.start()
def close(self, dummy=None):
def close(self, dummy: Any = None) -> None:
"""Stop. To restart, call open().
The dummy parameter allows an executor's close method to be a weakref
@ -92,7 +94,7 @@ class PeriodicExecutor(object):
"""
self._stopped = True
def join(self, timeout=None):
def join(self, timeout: Optional[int] = None) -> None:
if self._thread is not None:
try:
self._thread.join(timeout)
@ -100,14 +102,14 @@ class PeriodicExecutor(object):
# Thread already terminated, or not yet started.
pass
def wake(self):
def wake(self) -> None:
"""Execute the target function soon."""
self._event = True
def update_interval(self, new_interval):
def update_interval(self, new_interval: int) -> None:
self._interval = new_interval
def skip_sleep(self):
def skip_sleep(self) -> None:
self._skip_sleep = True
def __should_stop(self):

View File

@ -18,50 +18,38 @@ import copy
import ipaddress
import os
import platform
import ssl
import socket
import ssl
import sys
import threading
import time
import weakref
from typing import Any
from bson import DEFAULT_CODEC_OPTIONS
from bson.son import SON
from pymongo import auth, helpers, __version__
from pymongo import __version__, auth, helpers
from pymongo.client_session import _validate_session_write_concern
from pymongo.common import (MAX_BSON_SIZE,
MAX_CONNECTING,
MAX_IDLE_TIME_SEC,
MAX_MESSAGE_SIZE,
MAX_POOL_SIZE,
MAX_WIRE_VERSION,
MAX_WRITE_BATCH_SIZE,
MIN_POOL_SIZE,
ORDERED_TYPES,
from pymongo.common import (MAX_BSON_SIZE, MAX_CONNECTING, MAX_IDLE_TIME_SEC,
MAX_MESSAGE_SIZE, MAX_POOL_SIZE, MAX_WIRE_VERSION,
MAX_WRITE_BATCH_SIZE, MIN_POOL_SIZE, ORDERED_TYPES,
WAIT_QUEUE_TIMEOUT)
from pymongo.errors import (AutoReconnect,
_CertificateError,
ConnectionFailure,
ConfigurationError,
InvalidOperation,
DocumentTooLarge,
NetworkTimeout,
NotPrimaryError,
OperationFailure,
PyMongoError)
from pymongo.hello import HelloCompat, Hello
from pymongo.errors import (AutoReconnect, ConfigurationError,
ConnectionFailure, DocumentTooLarge,
InvalidOperation, NetworkTimeout, NotPrimaryError,
OperationFailure, PyMongoError, _CertificateError)
from pymongo.hello import Hello, HelloCompat
from pymongo.monitoring import (ConnectionCheckOutFailedReason,
ConnectionClosedReason)
from pymongo.network import (command,
receive_message)
from pymongo.network import command, receive_message
from pymongo.read_preferences import ReadPreference
from pymongo.server_api import _add_to_command
from pymongo.server_type import SERVER_TYPE
from pymongo.socket_checker import SocketChecker
from pymongo.ssl_support import (
SSLError as _SSLError,
HAS_SNI as _HAVE_SNI,
IPADDR_SAFE as _IPADDR_SAFE)
from pymongo.ssl_support import HAS_SNI as _HAVE_SNI
from pymongo.ssl_support import IPADDR_SAFE as _IPADDR_SAFE
from pymongo.ssl_support import SSLError as _SSLError
# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
# not permitted for SNI hostname.
@ -73,7 +61,7 @@ def is_ip_address(address):
return False
try:
from fcntl import fcntl, F_GETFD, F_SETFD, FD_CLOEXEC
from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl
def _set_non_inheritable_non_atomic(fd):
"""Set the close-on-exec flag on the given file descriptor."""
flags = fcntl(fd, F_GETFD)
@ -82,7 +70,7 @@ except ImportError:
# Windows, various platforms we don't claim to support
# (Jython, IronPython, ...), systems that don't provide
# everything we need from fcntl, etc.
def _set_non_inheritable_non_atomic(dummy):
def _set_non_inheritable_non_atomic(fd):
"""Dummy function for platforms that don't provide fcntl."""
pass
@ -145,7 +133,7 @@ else:
_set_tcp_option(sock, 'TCP_KEEPINTVL', _MAX_TCP_KEEPINTVL)
_set_tcp_option(sock, 'TCP_KEEPCNT', _MAX_TCP_KEEPCNT)
_METADATA = SON([
_METADATA: SON[str, Any] = SON([
('driver', SON([('name', 'PyMongo'), ('version', __version__)])),
])
@ -205,7 +193,7 @@ else:
if platform.python_implementation().startswith('PyPy'):
_METADATA['platform'] = ' '.join(
(platform.python_implementation(),
'.'.join(map(str, sys.pypy_version_info)),
'.'.join(map(str, sys.pypy_version_info)), # type: ignore
'(Python %s)' % '.'.join(map(str, sys.version_info))))
elif sys.platform.startswith('java'):
_METADATA['platform'] = ' '.join(
@ -688,7 +676,7 @@ class SocketInfo(object):
session = _validate_session_write_concern(session, write_concern)
# Ensure command name remains in first place.
if not isinstance(spec, ORDERED_TYPES):
if not isinstance(spec, ORDERED_TYPES): # type:ignore[arg-type]
spec = SON(spec)
if not (write_concern is None or write_concern.acknowledged or
@ -1088,7 +1076,7 @@ class Pool:
# LIFO pool. Sockets are ordered on idle time. Sockets claimed
# and returned to pool from the left side. Stale sockets removed
# from the right side.
self.sockets = collections.deque()
self.sockets: collections.deque = collections.deque()
self.lock = threading.Lock()
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
@ -1165,8 +1153,8 @@ class Pool:
if service_id is None:
sockets, self.sockets = self.sockets, collections.deque()
else:
discard = collections.deque()
keep = collections.deque()
discard: collections.deque = collections.deque()
keep: collections.deque = collections.deque()
for sock_info in self.sockets:
if sock_info.service_id == service_id:
discard.append(sock_info)

View File

@ -20,29 +20,28 @@ import socket as _socket
import ssl as _stdlibssl
import sys as _sys
import time as _time
from errno import EINTR as _EINTR
from ipaddress import ip_address as _ip_address
from cryptography.x509 import load_der_x509_certificate as _load_der_x509_certificate
from OpenSSL import crypto as _crypto, SSL as _SSL
from service_identity.pyopenssl import (
verify_hostname as _verify_hostname,
verify_ip_address as _verify_ip_address)
from cryptography.x509 import \
load_der_x509_certificate as _load_der_x509_certificate
from OpenSSL import SSL as _SSL
from OpenSSL import crypto as _crypto
from service_identity import (
CertificateError as _SICertificateError,
VerificationError as _SIVerificationError)
CertificateError as _SICertificateError
)
from service_identity import VerificationError as _SIVerificationError
from service_identity.pyopenssl import ( #
verify_hostname as _verify_hostname
)
from service_identity.pyopenssl import verify_ip_address as _verify_ip_address
from pymongo.errors import (
_CertificateError,
ConfigurationError as _ConfigurationError)
from pymongo.ocsp_support import (
_load_trusted_ca_certs,
_ocsp_callback)
from pymongo.errors import ConfigurationError as _ConfigurationError
from pymongo.errors import _CertificateError
from pymongo.ocsp_cache import _OCSPCache
from pymongo.socket_checker import (
_errno_from_exception, SocketChecker as _SocketChecker)
from pymongo.ocsp_support import _load_trusted_ca_certs, _ocsp_callback
from pymongo.socket_checker import SocketChecker as _SocketChecker
from pymongo.socket_checker import _errno_from_exception
try:
import certifi
@ -132,7 +131,7 @@ class _sslConn(_SSL.Connection):
def recv_into(self, *args, **kwargs):
try:
return self._call(super(_sslConn, self).recv_into, *args, **kwargs)
return self._call(super(_sslConn, self).recv_into, *args, **kwargs) # type: ignore
except _SSL.SysCallError as exc:
# Suppress ragged EOFs to match the stdlib.
if self.suppress_ragged_eofs and _ragged_eof(exc):
@ -147,7 +146,7 @@ class _sslConn(_SSL.Connection):
while total_sent < total_length:
try:
sent = self._call(
super(_sslConn, self).send, view[total_sent:], flags)
super(_sslConn, self).send, view[total_sent:], flags) # type: ignore
# XXX: It's not clear if this can actually happen. PyOpenSSL
# doesn't appear to have any interrupt handling, nor any interrupt
# errors for OpenSSL connections.
@ -296,7 +295,7 @@ class SSLContext(object):
"""Attempt to load CA certs from Windows trust store."""
cert_store = self._ctx.get_cert_store()
oid = _stdlibssl.Purpose.SERVER_AUTH.oid
for cert, encoding, trust in _stdlibssl.enum_certificates(store):
for cert, encoding, trust in _stdlibssl.enum_certificates(store): # type: ignore
if encoding == "x509_asn":
if trust is True or oid in trust:
cert_store.add_cert(

View File

@ -14,6 +14,8 @@
"""Tools for working with read concerns."""
from typing import Any, Dict, Optional
class ReadConcern(object):
"""ReadConcern
@ -29,7 +31,7 @@ class ReadConcern(object):
"""
def __init__(self, level=None):
def __init__(self, level: Optional[str] = None) -> None:
if level is None or isinstance(level, str):
self.__level = level
else:
@ -37,18 +39,18 @@ class ReadConcern(object):
'level must be a string or None.')
@property
def level(self):
def level(self) -> Optional[str]:
"""The read concern level."""
return self.__level
@property
def ok_for_legacy(self):
def ok_for_legacy(self) -> bool:
"""Return ``True`` if this read concern is compatible with
old wire protocol versions."""
return self.level is None or self.level == 'local'
@property
def document(self):
def document(self) -> Dict[str, Any]:
"""The document representation of this read concern.
.. note::
@ -60,7 +62,7 @@ class ReadConcern(object):
doc['level'] = self.level
return doc
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, ReadConcern):
return self.document == other.document
return NotImplemented

View File

@ -15,13 +15,13 @@
"""Utilities for choosing which member of a replica set to read from."""
from collections import abc
from typing import Any, Dict, Mapping, Optional, Sequence
from pymongo import max_staleness_selectors
from pymongo.errors import ConfigurationError
from pymongo.server_selectors import (member_with_tags_server_selector,
secondary_with_tags_server_selector)
_PRIMARY = 0
_PRIMARY_PREFERRED = 1
_SECONDARY = 2
@ -44,9 +44,9 @@ def _validate_tag_sets(tag_sets):
if tag_sets is None:
return tag_sets
if not isinstance(tag_sets, list):
if not isinstance(tag_sets, (list, tuple)):
raise TypeError((
"Tag sets %r invalid, must be a list") % (tag_sets,))
"Tag sets %r invalid, must be a sequence") % (tag_sets,))
if len(tag_sets) == 0:
raise ValueError((
"Tag sets %r invalid, must be None or contain at least one set of"
@ -59,7 +59,7 @@ def _validate_tag_sets(tag_sets):
"bson.son.SON or other type that inherits from "
"collection.Mapping" % (tags,))
return tag_sets
return list(tag_sets)
def _invalid_max_staleness_msg(max_staleness):
@ -93,6 +93,10 @@ def _validate_hedge(hedge):
return hedge
_Hedge = Mapping[str, Any]
_TagSets = Sequence[Mapping[str, Any]]
class _ServerMode(object):
"""Base class for all read preferences.
"""
@ -100,7 +104,7 @@ class _ServerMode(object):
__slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness",
"__hedge")
def __init__(self, mode, tag_sets=None, max_staleness=-1, hedge=None):
def __init__(self, mode: int, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None:
self.__mongos_mode = _MONGOS_MODES[mode]
self.__mode = mode
self.__tag_sets = _validate_tag_sets(tag_sets)
@ -108,22 +112,22 @@ class _ServerMode(object):
self.__hedge = _validate_hedge(hedge)
@property
def name(self):
def name(self) -> str:
"""The name of this read preference.
"""
return self.__class__.__name__
@property
def mongos_mode(self):
def mongos_mode(self) -> str:
"""The mongos mode of this read preference.
"""
return self.__mongos_mode
@property
def document(self):
def document(self) -> Dict[str, Any]:
"""Read preference as a document.
"""
doc = {'mode': self.__mongos_mode}
doc: Dict[str, Any] = {'mode': self.__mongos_mode}
if self.__tag_sets not in (None, [{}]):
doc['tags'] = self.__tag_sets
if self.__max_staleness != -1:
@ -133,13 +137,13 @@ class _ServerMode(object):
return doc
@property
def mode(self):
def mode(self) -> int:
"""The mode of this read preference instance.
"""
return self.__mode
@property
def tag_sets(self):
def tag_sets(self) -> _TagSets:
"""Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to
read only from members whose ``dc`` tag has the value ``"ny"``.
To specify a priority-order for tag sets, provide a list of
@ -154,14 +158,14 @@ class _ServerMode(object):
return list(self.__tag_sets) if self.__tag_sets else [{}]
@property
def max_staleness(self):
def max_staleness(self) -> int:
"""The maximum estimated length of time (in seconds) a replica set
secondary can fall behind the primary in replication before it will
no longer be selected for operations, or -1 for no maximum."""
return self.__max_staleness
@property
def hedge(self):
def hedge(self) -> Optional[_Hedge]:
"""The read preference ``hedge`` parameter.
A dictionary that configures how the server will perform hedged reads.
@ -185,7 +189,7 @@ class _ServerMode(object):
return self.__hedge
@property
def min_wire_version(self):
def min_wire_version(self) -> int:
"""The wire protocol version the server must support.
Some read preferences impose version requirements on all servers (e.g.
@ -201,7 +205,7 @@ class _ServerMode(object):
return "%s(tag_sets=%r, max_staleness=%r, hedge=%r)" % (
self.name, self.__tag_sets, self.__max_staleness, self.__hedge)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, _ServerMode):
return (self.mode == other.mode and
self.tag_sets == other.tag_sets and
@ -209,7 +213,7 @@ class _ServerMode(object):
self.hedge == other.hedge)
return NotImplemented
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other
def __getstate__(self):
@ -243,17 +247,17 @@ class Primary(_ServerMode):
__slots__ = ()
def __init__(self):
def __init__(self) -> None:
super(Primary, self).__init__(_PRIMARY)
def __call__(self, selection):
def __call__(self, selection: Any) -> Any:
"""Apply this read preference to a Selection."""
return selection.primary_selection
def __repr__(self):
return "Primary()"
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, _ServerMode):
return other.mode == _PRIMARY
return NotImplemented
@ -289,11 +293,11 @@ class PrimaryPreferred(_ServerMode):
__slots__ = ()
def __init__(self, tag_sets=None, max_staleness=-1, hedge=None):
def __init__(self, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None:
super(PrimaryPreferred, self).__init__(
_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge)
def __call__(self, selection):
def __call__(self, selection: Any) -> Any:
"""Apply this read preference to Selection."""
if selection.primary:
return selection.primary_selection
@ -329,11 +333,11 @@ class Secondary(_ServerMode):
__slots__ = ()
def __init__(self, tag_sets=None, max_staleness=-1, hedge=None):
def __init__(self, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None:
super(Secondary, self).__init__(
_SECONDARY, tag_sets, max_staleness, hedge)
def __call__(self, selection):
def __call__(self, selection: Any) -> Any:
"""Apply this read preference to Selection."""
return secondary_with_tags_server_selector(
self.tag_sets,
@ -370,11 +374,11 @@ class SecondaryPreferred(_ServerMode):
__slots__ = ()
def __init__(self, tag_sets=None, max_staleness=-1, hedge=None):
def __init__(self, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None:
super(SecondaryPreferred, self).__init__(
_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge)
def __call__(self, selection):
def __call__(self, selection: Any) -> Any:
"""Apply this read preference to Selection."""
secondaries = secondary_with_tags_server_selector(
self.tag_sets,
@ -412,11 +416,11 @@ class Nearest(_ServerMode):
__slots__ = ()
def __init__(self, tag_sets=None, max_staleness=-1, hedge=None):
def __init__(self, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None:
super(Nearest, self).__init__(
_NEAREST, tag_sets, max_staleness, hedge)
def __call__(self, selection):
def __call__(self, selection: Any) -> Any:
"""Apply this read preference to Selection."""
return member_with_tags_server_selector(
self.tag_sets,
@ -467,7 +471,7 @@ _ALL_READ_PREFERENCES = (Primary, PrimaryPreferred,
Secondary, SecondaryPreferred, Nearest)
def make_read_preference(mode, tag_sets, max_staleness=-1):
def make_read_preference(mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1) -> _ServerMode:
if mode == _PRIMARY:
if tag_sets not in (None, [{}]):
raise ConfigurationError("Read preference primary "
@ -476,7 +480,7 @@ def make_read_preference(mode, tag_sets, max_staleness=-1):
raise ConfigurationError("Read preference primary cannot be "
"combined with maxStalenessSeconds")
return Primary()
return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness)
return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore
_MODES = (
@ -545,7 +549,7 @@ class ReadPreference(object):
NEAREST = Nearest()
def read_pref_mode_from_name(name):
def read_pref_mode_from_name(name: str) -> int:
"""Get the read preference mode from mongos/uri name.
"""
return _MONGOS_MODES.index(name)
@ -553,10 +557,12 @@ def read_pref_mode_from_name(name):
class MovingAverage(object):
"""Tracks an exponentially-weighted moving average."""
def __init__(self):
average: Optional[float]
def __init__(self) -> None:
self.average = None
def add_sample(self, sample):
def add_sample(self, sample: float) -> None:
if sample < 0:
# Likely system time change while waiting for hello response
# and not using time.monotonic. Ignore it, the next one will
@ -569,9 +575,9 @@ class MovingAverage(object):
# average with alpha = 0.2.
self.average = 0.8 * self.average + 0.2 * sample
def get(self):
def get(self) -> Optional[float]:
"""Get the calculated average, or None if no samples yet."""
return self.average
def reset(self):
def reset(self) -> None:
self.average = None

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Result class definitions."""
from typing import Any, Dict, List, Mapping, Optional, Sequence, cast
from pymongo.errors import InvalidOperation
@ -22,7 +23,7 @@ class _WriteResult(object):
__slots__ = ("__acknowledged",)
def __init__(self, acknowledged):
def __init__(self, acknowledged: bool) -> None:
self.__acknowledged = acknowledged
def _raise_if_unacknowledged(self, property_name):
@ -34,7 +35,7 @@ class _WriteResult(object):
"error." % (property_name,))
@property
def acknowledged(self):
def acknowledged(self) -> bool:
"""Is this the result of an acknowledged write operation?
The :attr:`acknowledged` attribute will be ``False`` when using
@ -59,12 +60,12 @@ class InsertOneResult(_WriteResult):
__slots__ = ("__inserted_id", "__acknowledged")
def __init__(self, inserted_id, acknowledged):
def __init__(self, inserted_id: Any, acknowledged: bool) -> None:
self.__inserted_id = inserted_id
super(InsertOneResult, self).__init__(acknowledged)
@property
def inserted_id(self):
def inserted_id(self) -> Any:
"""The inserted document's _id."""
return self.__inserted_id
@ -75,12 +76,12 @@ class InsertManyResult(_WriteResult):
__slots__ = ("__inserted_ids", "__acknowledged")
def __init__(self, inserted_ids, acknowledged):
def __init__(self, inserted_ids: List[Any], acknowledged: bool) -> None:
self.__inserted_ids = inserted_ids
super(InsertManyResult, self).__init__(acknowledged)
@property
def inserted_ids(self):
def inserted_ids(self) -> List:
"""A list of _ids of the inserted documents, in the order provided.
.. note:: If ``False`` is passed for the `ordered` parameter to
@ -99,17 +100,17 @@ class UpdateResult(_WriteResult):
__slots__ = ("__raw_result", "__acknowledged")
def __init__(self, raw_result, acknowledged):
def __init__(self, raw_result: Dict[str, Any], acknowledged: bool) -> None:
self.__raw_result = raw_result
super(UpdateResult, self).__init__(acknowledged)
@property
def raw_result(self):
def raw_result(self) -> Dict[str, Any]:
"""The raw result document returned by the server."""
return self.__raw_result
@property
def matched_count(self):
def matched_count(self) -> int:
"""The number of documents matched for this update."""
self._raise_if_unacknowledged("matched_count")
if self.upserted_id is not None:
@ -117,13 +118,13 @@ class UpdateResult(_WriteResult):
return self.__raw_result.get("n", 0)
@property
def modified_count(self):
def modified_count(self) -> int:
"""The number of documents modified. """
self._raise_if_unacknowledged("modified_count")
return self.__raw_result.get("nModified")
return cast(int, self.__raw_result.get("nModified"))
@property
def upserted_id(self):
def upserted_id(self) -> Any:
"""The _id of the inserted document if an upsert took place. Otherwise
``None``.
"""
@ -137,17 +138,17 @@ class DeleteResult(_WriteResult):
__slots__ = ("__raw_result", "__acknowledged")
def __init__(self, raw_result, acknowledged):
def __init__(self, raw_result: Dict[str, Any], acknowledged: bool) -> None:
self.__raw_result = raw_result
super(DeleteResult, self).__init__(acknowledged)
@property
def raw_result(self):
def raw_result(self) -> Dict[str, Any]:
"""The raw result document returned by the server."""
return self.__raw_result
@property
def deleted_count(self):
def deleted_count(self) -> int:
"""The number of documents deleted."""
self._raise_if_unacknowledged("deleted_count")
return self.__raw_result.get("n", 0)
@ -158,7 +159,7 @@ class BulkWriteResult(_WriteResult):
__slots__ = ("__bulk_api_result", "__acknowledged")
def __init__(self, bulk_api_result, acknowledged):
def __init__(self, bulk_api_result: Dict[str, Any], acknowledged: bool) -> None:
"""Create a BulkWriteResult instance.
:Parameters:
@ -171,44 +172,45 @@ class BulkWriteResult(_WriteResult):
super(BulkWriteResult, self).__init__(acknowledged)
@property
def bulk_api_result(self):
def bulk_api_result(self) -> Dict[str, Any]:
"""The raw bulk API result."""
return self.__bulk_api_result
@property
def inserted_count(self):
def inserted_count(self) -> int:
"""The number of documents inserted."""
self._raise_if_unacknowledged("inserted_count")
return self.__bulk_api_result.get("nInserted")
return cast(int, self.__bulk_api_result.get("nInserted"))
@property
def matched_count(self):
def matched_count(self) -> int:
"""The number of documents matched for an update."""
self._raise_if_unacknowledged("matched_count")
return self.__bulk_api_result.get("nMatched")
return cast(int, self.__bulk_api_result.get("nMatched"))
@property
def modified_count(self):
def modified_count(self) -> int:
"""The number of documents modified."""
self._raise_if_unacknowledged("modified_count")
return self.__bulk_api_result.get("nModified")
return cast(int, self.__bulk_api_result.get("nModified"))
@property
def deleted_count(self):
def deleted_count(self) -> int:
"""The number of documents deleted."""
self._raise_if_unacknowledged("deleted_count")
return self.__bulk_api_result.get("nRemoved")
return cast(int, self.__bulk_api_result.get("nRemoved"))
@property
def upserted_count(self):
def upserted_count(self) -> int:
"""The number of documents upserted."""
self._raise_if_unacknowledged("upserted_count")
return self.__bulk_api_result.get("nUpserted")
return cast(int, self.__bulk_api_result.get("nUpserted"))
@property
def upserted_ids(self):
def upserted_ids(self) -> Optional[Dict[int, Any]]:
"""A map of operation index to the _id of the upserted document."""
self._raise_if_unacknowledged("upserted_ids")
if self.__bulk_api_result:
return dict((upsert["index"], upsert["_id"])
for upsert in self.bulk_api_result["upserted"])
return None

View File

@ -13,13 +13,13 @@
# limitations under the License.
"""An implementation of RFC4013 SASLprep."""
from typing import Any, Optional
try:
import stringprep
except ImportError:
HAVE_STRINGPREP = False
def saslprep(data):
def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) -> str:
"""SASLprep dummy"""
if isinstance(data, str):
raise TypeError(
@ -29,6 +29,7 @@ except ImportError:
else:
HAVE_STRINGPREP = True
import unicodedata
# RFC4013 section 2.3 prohibited output.
_PROHIBITED = (
# A strict reading of RFC 4013 requires table c12 here, but
@ -44,7 +45,7 @@ else:
stringprep.in_table_c8,
stringprep.in_table_c9)
def saslprep(data, prohibit_unassigned_code_points=True):
def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) -> str:
"""An implementation of RFC4013 SASLprep.
:Parameters:
@ -60,6 +61,8 @@ else:
:Returns:
The SASLprep'ed version of `data`.
"""
prohibited: Any
if not isinstance(data, str):
return data

View File

@ -17,11 +17,10 @@
from datetime import datetime
from bson import _decode_all_selective
from pymongo.errors import NotPrimaryError, OperationFailure
from pymongo.helpers import _check_command_response
from pymongo.message import _convert_exception, _OpMsg
from pymongo.response import Response, PinnedResponse
from pymongo.response import PinnedResponse, Response
from pymongo.server_type import SERVER_TYPE
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}}
@ -59,6 +58,8 @@ class Server(object):
Reconnect with open().
"""
if self._publish:
assert self._listener is not None
assert self._events is not None
self._events.put((self._listener.publish_server_closed,
(self._description.address, self._topology_id)))
self._monitor.close()
@ -169,6 +170,8 @@ class Server(object):
docs = _decode_all_selective(
decrypted, operation.codec_options, user_fields)
response: Response
if client._should_pin_cursor(operation.session) or operation.exhaust:
sock_info.pin_cursor()
if isinstance(reply, _OpMsg):

View File

@ -15,10 +15,13 @@
"""Represent one server the driver is connected to."""
import time
from typing import Any, Dict, Mapping, Optional, Set, Tuple, cast
from bson import EPOCH_NAIVE
from pymongo.server_type import SERVER_TYPE
from bson.objectid import ObjectId
from pymongo.hello import Hello
from pymongo.server_type import SERVER_TYPE
from pymongo.typings import _Address
class ServerDescription(object):
@ -41,11 +44,12 @@ class ServerDescription(object):
'_topology_version')
def __init__(
self,
address,
hello=None,
round_trip_time=None,
error=None):
self,
address: _Address,
hello: Optional[Hello] = None,
round_trip_time: Optional[float] = None,
error: Optional[Exception] = None,
) -> None:
self._address = address
if not hello:
hello = Hello({})
@ -72,9 +76,11 @@ class ServerDescription(object):
self._error = error
self._topology_version = hello.topology_version
if error:
if hasattr(error, 'details') and isinstance(error.details, dict):
self._topology_version = error.details.get('topologyVersion')
details = getattr(error, 'details', None)
if isinstance(details, dict):
self._topology_version = details.get('topologyVersion')
self._last_write_date: Optional[float]
if hello.last_write_date:
# Convert from datetime to seconds.
delta = hello.last_write_date - EPOCH_NAIVE
@ -83,17 +89,17 @@ class ServerDescription(object):
self._last_write_date = None
@property
def address(self):
def address(self) -> _Address:
"""The address (host, port) of this server."""
return self._address
@property
def server_type(self):
def server_type(self) -> int:
"""The type of this server."""
return self._server_type
@property
def server_type_name(self):
def server_type_name(self) -> str:
"""The server type as a human readable string.
.. versionadded:: 3.4
@ -101,78 +107,78 @@ class ServerDescription(object):
return SERVER_TYPE._fields[self._server_type]
@property
def all_hosts(self):
def all_hosts(self) -> Set[Tuple[str, int]]:
"""List of hosts, passives, and arbiters known to this server."""
return self._all_hosts
@property
def tags(self):
def tags(self) -> Mapping[str, Any]:
return self._tags
@property
def replica_set_name(self):
def replica_set_name(self) -> Optional[str]:
"""Replica set name or None."""
return self._replica_set_name
@property
def primary(self):
def primary(self) -> Optional[Tuple[str, int]]:
"""This server's opinion about who the primary is, or None."""
return self._primary
@property
def max_bson_size(self):
def max_bson_size(self) -> int:
return self._max_bson_size
@property
def max_message_size(self):
def max_message_size(self) -> int:
return self._max_message_size
@property
def max_write_batch_size(self):
def max_write_batch_size(self) -> int:
return self._max_write_batch_size
@property
def min_wire_version(self):
def min_wire_version(self) -> int:
return self._min_wire_version
@property
def max_wire_version(self):
def max_wire_version(self) -> int:
return self._max_wire_version
@property
def set_version(self):
def set_version(self) -> Optional[int]:
return self._set_version
@property
def election_id(self):
def election_id(self) -> Optional[ObjectId]:
return self._election_id
@property
def cluster_time(self):
def cluster_time(self)-> Optional[Mapping[str, Any]]:
return self._cluster_time
@property
def election_tuple(self):
def election_tuple(self) -> Tuple[Optional[int], Optional[ObjectId]]:
return self._set_version, self._election_id
@property
def me(self):
def me(self) -> Optional[Tuple[str, int]]:
return self._me
@property
def logical_session_timeout_minutes(self):
def logical_session_timeout_minutes(self) -> Optional[int]:
return self._ls_timeout_minutes
@property
def last_write_date(self):
def last_write_date(self) -> Optional[float]:
return self._last_write_date
@property
def last_update_time(self):
def last_update_time(self) -> float:
return self._last_update_time
@property
def round_trip_time(self):
def round_trip_time(self) -> Optional[float]:
"""The current average latency or None."""
# This override is for unittesting only!
if self._address in self._host_to_round_trip_time:
@ -181,28 +187,28 @@ class ServerDescription(object):
return self._round_trip_time
@property
def error(self):
def error(self) -> Optional[Exception]:
"""The last error attempting to connect to the server, or None."""
return self._error
@property
def is_writable(self):
def is_writable(self) -> bool:
return self._is_writable
@property
def is_readable(self):
def is_readable(self) -> bool:
return self._is_readable
@property
def mongos(self):
def mongos(self) -> bool:
return self._server_type == SERVER_TYPE.Mongos
@property
def is_server_type_known(self):
def is_server_type_known(self) -> bool:
return self.server_type != SERVER_TYPE.Unknown
@property
def retryable_writes_supported(self):
def retryable_writes_supported(self) -> bool:
"""Checks if this server supports retryable writes."""
return ((
self._ls_timeout_minutes is not None and
@ -210,20 +216,20 @@ class ServerDescription(object):
or self._server_type == SERVER_TYPE.LoadBalancer)
@property
def retryable_reads_supported(self):
def retryable_reads_supported(self) -> bool:
"""Checks if this server supports retryable writes."""
return self._max_wire_version >= 6
@property
def topology_version(self):
def topology_version(self) -> Optional[Mapping[str, Any]]:
return self._topology_version
def to_unknown(self, error=None):
def to_unknown(self, error: Optional[Exception] = None) -> "ServerDescription":
unknown = ServerDescription(self.address, error=error)
unknown._topology_version = self.topology_version
return unknown
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, ServerDescription):
return ((self._address == other.address) and
(self._server_type == other.server_type) and
@ -242,7 +248,7 @@ class ServerDescription(object):
return NotImplemented
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other
def __repr__(self):
@ -254,4 +260,4 @@ class ServerDescription(object):
self.round_trip_time, errmsg)
# For unittesting only. Use under no circumstances!
_host_to_round_trip_time = {}
_host_to_round_trip_time: Dict = {}

View File

@ -14,10 +14,19 @@
"""Type codes for MongoDB servers."""
from collections import namedtuple
from typing import NamedTuple
SERVER_TYPE = namedtuple('ServerType',
['Unknown', 'Mongos', 'RSPrimary', 'RSSecondary',
'RSArbiter', 'RSOther', 'RSGhost',
'Standalone', 'LoadBalancer'])(*range(9))
class _ServerType(NamedTuple):
Unknown: int
Mongos: int
RSPrimary: int
RSSecondary: int
RSArbiter: int
RSOther: int
RSGhost: int
Standalone: int
LoadBalancer: int
SERVER_TYPE = _ServerType(*range(9))

View File

@ -16,7 +16,9 @@
import errno
import select
import socket
import sys
from typing import Any, Optional
# PYTHON-2320: Jython does not fully support poll on SSL sockets,
# https://bugs.jython.org/issue2900
@ -34,17 +36,19 @@ def _errno_from_exception(exc):
class SocketChecker(object):
def __init__(self):
def __init__(self) -> None:
self._poller: Optional[select.poll]
if _HAVE_POLL:
self._poller = select.poll()
else:
self._poller = None
def select(self, sock, read=False, write=False, timeout=0):
def select(self, sock: Any, read: bool = False, write: bool = False, timeout: int = 0) -> bool:
"""Select for reads or writes with a timeout in seconds (or None).
Returns True if the socket is readable/writable, False on timeout.
"""
res: Any
while True:
try:
if self._poller:
@ -74,12 +78,12 @@ class SocketChecker(object):
# ready: subsets of the first three arguments. Return
# True if any of the lists are not empty.
return any(res)
except (_SelectError, IOError) as exc:
except (_SelectError, IOError) as exc: # type: ignore
if _errno_from_exception(exc) in (errno.EINTR, errno.EAGAIN):
continue
raise
def socket_closed(self, sock):
def socket_closed(self, sock: Any) -> bool:
"""Return True if we know socket has been closed, False otherwise.
"""
try:

View File

@ -26,6 +26,7 @@ except ImportError:
from pymongo.common import CONNECT_TIMEOUT
from pymongo.errors import ConfigurationError
# dnspython can return bytes or str from various parts
# of its API depending on version. We always want str.
def maybe_decode(text):
@ -38,7 +39,7 @@ def maybe_decode(text):
def _resolve(*args, **kwargs):
if hasattr(resolver, 'resolve'):
# dnspython >= 2
return resolver.resolve(*args, **kwargs)
return resolver.resolve(*args, **kwargs) # type: ignore
# dnspython 1.X
return resolver.query(*args, **kwargs)

View File

@ -32,6 +32,7 @@ IS_PYOPENSSL = False
SSLError = _ssl.SSLError
from ssl import SSLContext
if hasattr(_ssl, "VERIFY_CRL_CHECK_LEAF"):
from ssl import VERIFY_CRL_CHECK_LEAF
# Python 3.7 uses OpenSSL's hostname matching implementation

View File

@ -24,7 +24,7 @@ try:
import pymongo.pyopenssl_context as _ssl
except ImportError:
try:
import pymongo.ssl_context as _ssl
import pymongo.ssl_context as _ssl # type: ignore[no-redef]
except ImportError:
HAVE_SSL = False
@ -74,7 +74,7 @@ if HAVE_SSL:
raise ConfigurationError(
"tlsCRLFile cannot be used with PyOpenSSL")
# Match the server's behavior.
ctx.verify_flags = getattr(_ssl, "VERIFY_CRL_CHECK_LEAF", 0)
setattr(ctx, 'verify_flags', getattr(_ssl, "VERIFY_CRL_CHECK_LEAF", 0))
ctx.load_verify_locations(crlfile)
if ca_certs is not None:
ctx.load_verify_locations(ca_certs)
@ -83,11 +83,11 @@ if HAVE_SSL:
ctx.verify_mode = verify_mode
return ctx
else:
class SSLError(Exception):
class SSLError(Exception): # type: ignore
pass
HAS_SNI = False
IPADDR_SAFE = False
def get_ssl_context(*dummy):
def get_ssl_context(*dummy): # type: ignore
"""No ssl module, raise ConfigurationError."""
raise ConfigurationError("The ssl module is not available.")

View File

@ -21,35 +21,27 @@ import threading
import time
import warnings
import weakref
from typing import Any
from pymongo import (common,
helpers,
periodic_executor)
from pymongo import common, helpers, periodic_executor
from pymongo.client_session import _ServerSessionPool
from pymongo.errors import (ConnectionFailure,
ConfigurationError,
NetworkTimeout,
NotPrimaryError,
OperationFailure,
PyMongoError,
ServerSelectionTimeoutError,
WriteError,
InvalidOperation)
from pymongo.errors import (ConfigurationError, ConnectionFailure,
InvalidOperation, NetworkTimeout, NotPrimaryError,
OperationFailure, PyMongoError,
ServerSelectionTimeoutError, WriteError)
from pymongo.hello import Hello
from pymongo.monitor import SrvMonitor
from pymongo.pool import PoolOptions
from pymongo.server import Server
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import (any_server_selector,
from pymongo.server_selectors import (Selection, any_server_selector,
arbiter_server_selector,
secondary_server_selector,
readable_server_selector,
writable_server_selector,
Selection)
from pymongo.topology_description import (updated_topology_description,
_updated_topology_description_srv_polling,
TopologyDescription,
SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE)
secondary_server_selector,
writable_server_selector)
from pymongo.topology_description import (
SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE, TopologyDescription,
_updated_topology_description_srv_polling, updated_topology_description)
def process_events_queue(queue_ref):
@ -80,12 +72,13 @@ class Topology(object):
# Create events queue if there are publishers.
self._events = None
self.__events_executor = None
self.__events_executor: Any = None
if self._publish_server or self._publish_tp:
self._events = queue.Queue(maxsize=100)
if self._publish_tp:
assert self._events is not None
self._events.put((self._listeners.publish_topology_opened,
(self._topology_id,)))
self._settings = topology_settings
@ -99,6 +92,7 @@ class Topology(object):
self._description = topology_description
if self._publish_tp:
assert self._events is not None
initial_td = TopologyDescription(TOPOLOGY_TYPE.Unknown, {}, None,
None, None, self._settings)
self._events.put((
@ -107,6 +101,7 @@ class Topology(object):
for seed in topology_settings.seeds:
if self._publish_server:
assert self._events is not None
self._events.put((self._listeners.publish_server_opened,
(seed, self._topology_id)))
@ -296,6 +291,7 @@ class Topology(object):
suppress_event = ((self._publish_server or self._publish_tp)
and sd_old == server_description)
if self._publish_server and not suppress_event:
assert self._events is not None
self._events.put((
self._listeners.publish_server_description_changed,
(sd_old, server_description,
@ -306,6 +302,7 @@ class Topology(object):
self._receive_cluster_time_no_lock(server_description.cluster_time)
if self._publish_tp and not suppress_event:
assert self._events is not None
self._events.put((
self._listeners.publish_topology_description_changed,
(td_old, self._description, self._topology_id)))
@ -354,6 +351,7 @@ class Topology(object):
self._update_servers()
if self._publish_tp:
assert self._events is not None
self._events.put((
self._listeners.publish_topology_description_changed,
(td_old, self._description, self._topology_id)))
@ -485,6 +483,7 @@ class Topology(object):
# Publish only after releasing the lock.
if self._publish_tp:
assert self._events is not None
self._events.put((self._listeners.publish_topology_closed,
(self._topology_id,)))
if self._publish_server or self._publish_tp:

View File

@ -14,34 +14,48 @@
"""Represent a deployment of MongoDB servers."""
from collections import namedtuple
from random import sample
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
from bson.objectid import ObjectId
from pymongo import common
from pymongo.errors import ConfigurationError
from pymongo.read_preferences import ReadPreference, _AggWritePref
from pymongo.read_preferences import ReadPreference, _AggWritePref, _ServerMode
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import Selection
from pymongo.server_type import SERVER_TYPE
from pymongo.typings import _Address
# Enumeration for various kinds of MongoDB cluster topologies.
TOPOLOGY_TYPE = namedtuple('TopologyType', [
'Single', 'ReplicaSetNoPrimary', 'ReplicaSetWithPrimary', 'Sharded',
'Unknown', 'LoadBalanced'])(*range(6))
class _TopologyType(NamedTuple):
Single: int
ReplicaSetNoPrimary: int
ReplicaSetWithPrimary: int
Sharded: int
Unknown: int
LoadBalanced: int
TOPOLOGY_TYPE = _TopologyType(*range(6))
# Topologies compatible with SRV record polling.
SRV_POLLING_TOPOLOGIES = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
SRV_POLLING_TOPOLOGIES: Tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]]
class TopologyDescription(object):
def __init__(self,
topology_type,
server_descriptions,
replica_set_name,
max_set_version,
max_election_id,
topology_settings):
def __init__(
self,
topology_type: int,
server_descriptions: Dict[_Address, ServerDescription],
replica_set_name: Optional[str],
max_set_version: Optional[int],
max_election_id: Optional[ObjectId],
topology_settings: Any,
) -> None:
"""Representation of a deployment of MongoDB servers.
:Parameters:
@ -81,7 +95,7 @@ class TopologyDescription(object):
for s in readable_servers):
self._ls_timeout_minutes = None
else:
self._ls_timeout_minutes = min(s.logical_session_timeout_minutes
self._ls_timeout_minutes = min(s.logical_session_timeout_minutes # type: ignore
for s in readable_servers)
def _init_incompatible_err(self):
@ -104,23 +118,23 @@ class TopologyDescription(object):
if server_too_new:
self._incompatible_err = (
"Server at %s:%d requires wire version %d, but this "
"Server at %s:%d requires wire version %d, but this " # type: ignore
"version of PyMongo only supports up to %d."
% (s.address[0], s.address[1],
% (s.address[0], s.address[1] or 0,
s.min_wire_version, common.MAX_SUPPORTED_WIRE_VERSION))
elif server_too_old:
self._incompatible_err = (
"Server at %s:%d reports wire version %d, but this "
"Server at %s:%d reports wire version %d, but this " # type: ignore
"version of PyMongo requires at least %d (MongoDB %s)."
% (s.address[0], s.address[1],
% (s.address[0], s.address[1] or 0,
s.max_wire_version,
common.MIN_SUPPORTED_WIRE_VERSION,
common.MIN_SUPPORTED_SERVER_VERSION))
break
def check_compatible(self):
def check_compatible(self) -> None:
"""Raise ConfigurationError if any server is incompatible.
A server is incompatible if its wire protocol version range does not
@ -129,15 +143,15 @@ class TopologyDescription(object):
if self._incompatible_err:
raise ConfigurationError(self._incompatible_err)
def has_server(self, address):
def has_server(self, address: _Address) -> bool:
return address in self._server_descriptions
def reset_server(self, address):
def reset_server(self, address: _Address) -> "TopologyDescription":
"""A copy of this description, with one server marked Unknown."""
unknown_sd = self._server_descriptions[address].to_unknown()
return updated_topology_description(self, unknown_sd)
def reset(self):
def reset(self) -> "TopologyDescription":
"""A copy of this description, with all servers marked Unknown."""
if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary
@ -156,18 +170,18 @@ class TopologyDescription(object):
self._max_election_id,
self._topology_settings)
def server_descriptions(self):
def server_descriptions(self) -> Dict[_Address, ServerDescription]:
"""Dict of (address,
:class:`~pymongo.server_description.ServerDescription`)."""
return self._server_descriptions.copy()
@property
def topology_type(self):
def topology_type(self) -> int:
"""The type of this topology."""
return self._topology_type
@property
def topology_type_name(self):
def topology_type_name(self) -> str:
"""The topology type as a human readable string.
.. versionadded:: 3.4
@ -175,44 +189,44 @@ class TopologyDescription(object):
return TOPOLOGY_TYPE._fields[self._topology_type]
@property
def replica_set_name(self):
def replica_set_name(self) -> Optional[str]:
"""The replica set name."""
return self._replica_set_name
@property
def max_set_version(self):
def max_set_version(self) -> Optional[int]:
"""Greatest setVersion seen from a primary, or None."""
return self._max_set_version
@property
def max_election_id(self):
def max_election_id(self) -> Optional[ObjectId]:
"""Greatest electionId seen from a primary, or None."""
return self._max_election_id
@property
def logical_session_timeout_minutes(self):
def logical_session_timeout_minutes(self) -> Optional[int]:
"""Minimum logical session timeout, or None."""
return self._ls_timeout_minutes
@property
def known_servers(self):
def known_servers(self) -> List[ServerDescription]:
"""List of Servers of types besides Unknown."""
return [s for s in self._server_descriptions.values()
if s.is_server_type_known]
@property
def has_known_servers(self):
def has_known_servers(self) -> bool:
"""Whether there are any Servers of types besides Unknown."""
return any(s for s in self._server_descriptions.values()
if s.is_server_type_known)
@property
def readable_servers(self):
def readable_servers(self) -> List[ServerDescription]:
"""List of readable Servers."""
return [s for s in self._server_descriptions.values() if s.is_readable]
@property
def common_wire_version(self):
def common_wire_version(self) -> Optional[int]:
"""Minimum of all servers' max wire versions, or None."""
servers = self.known_servers
if servers:
@ -221,11 +235,11 @@ class TopologyDescription(object):
return None
@property
def heartbeat_frequency(self):
def heartbeat_frequency(self) -> int:
return self._topology_settings.heartbeat_frequency
@property
def srv_max_hosts(self):
def srv_max_hosts(self) -> int:
return self._topology_settings._srv_max_hosts
def _apply_local_threshold(self, selection):
@ -238,7 +252,12 @@ class TopologyDescription(object):
return [s for s in selection.server_descriptions
if (s.round_trip_time - fastest) <= threshold]
def apply_selector(self, selector, address=None, custom_selector=None):
def apply_selector(
self,
selector: Any,
address: Optional[_Address] = None,
custom_selector: Optional[_ServerSelector] = None
) -> List[ServerDescription]:
"""List of servers matching the provided selector(s).
:Parameters:
@ -288,7 +307,7 @@ class TopologyDescription(object):
custom_selector(selection.server_descriptions))
return self._apply_local_threshold(selection)
def has_readable_server(self, read_preference=ReadPreference.PRIMARY):
def has_readable_server(self, read_preference: _ServerMode =ReadPreference.PRIMARY) -> bool:
"""Does this topology have any readable servers available matching the
given read preference?
@ -305,7 +324,7 @@ class TopologyDescription(object):
common.validate_read_preference("read_preference", read_preference)
return any(self.apply_selector(read_preference))
def has_writable_server(self):
def has_writable_server(self) -> bool:
"""Does this topology have a writable server available?
.. note:: When connected directly to a single server this method
@ -336,7 +355,9 @@ _SERVER_TYPE_TO_TOPOLOGY_TYPE = {
}
def updated_topology_description(topology_description, server_description):
def updated_topology_description(
topology_description: TopologyDescription, server_description: ServerDescription
) -> "TopologyDescription":
"""Return an updated copy of a TopologyDescription.
:Parameters:

29
pymongo/typings.py Normal file
View File

@ -0,0 +1,29 @@
# Copyright 2022-Present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Type aliases used by PyMongo"""
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional,
Tuple, Type, TypeVar, Union)
if TYPE_CHECKING:
from bson.raw_bson import RawBSONDocument
from pymongo.collation import Collation
# Common Shared Types.
_Address = Tuple[str, Optional[int]]
_CollationIn = Union[Mapping[str, Any], "Collation"]
_DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"]
_Pipeline = List[Mapping[str, Any]]
_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any])

View File

@ -15,20 +15,19 @@
"""Tools to parse and validate a MongoDB URI."""
import re
import warnings
import sys
import warnings
from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple,
Union, cast)
from urllib.parse import unquote_plus
from pymongo.client_options import _parse_ssl_options
from pymongo.common import (
SRV_SERVICE_NAME,
get_validated_options, INTERNAL_URI_OPTION_NAME_MAP,
URI_OPTIONS_DEPRECATION_MAP, _CaseInsensitiveDictionary)
from pymongo.common import (INTERNAL_URI_OPTION_NAME_MAP, SRV_SERVICE_NAME,
URI_OPTIONS_DEPRECATION_MAP,
_CaseInsensitiveDictionary, get_validated_options)
from pymongo.errors import ConfigurationError, InvalidURI
from pymongo.srv_resolver import _HAVE_DNSPYTHON, _SrvResolver
SCHEME = 'mongodb://'
SCHEME_LEN = len(SCHEME)
SRV_SCHEME = 'mongodb+srv://'
@ -52,7 +51,7 @@ def _unquoted_percent(s):
return True
return False
def parse_userinfo(userinfo):
def parse_userinfo(userinfo: str) -> Tuple[str, str]:
"""Validates the format of user information in a MongoDB URI.
Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
"]", "@") as per RFC 3986 must be escaped.
@ -76,7 +75,7 @@ def parse_userinfo(userinfo):
return unquote_plus(user), unquote_plus(passwd)
def parse_ipv6_literal_host(entity, default_port):
def parse_ipv6_literal_host(entity: str, default_port: Optional[int]) -> Tuple[str, Optional[Union[str, int]]]:
"""Validates an IPv6 literal host:port string.
Returns a 2-tuple of IPv6 literal followed by port where
@ -98,7 +97,7 @@ def parse_ipv6_literal_host(entity, default_port):
return entity[1: i], entity[i + 2:]
def parse_host(entity, default_port=DEFAULT_PORT):
def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> Tuple[str, Optional[int]]:
"""Validates a host string
Returns a 2-tuple of host followed by port where port is default_port
@ -111,7 +110,7 @@ def parse_host(entity, default_port=DEFAULT_PORT):
specified in entity.
"""
host = entity
port = default_port
port: Optional[Union[str, int]] = default_port
if entity[0] == '[':
host, port = parse_ipv6_literal_host(entity, default_port)
elif entity.endswith(".sock"):
@ -279,7 +278,7 @@ def _normalize_options(options):
return options
def validate_options(opts, warn=False):
def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
"""Validates and normalizes options passed in a MongoDB URI.
Returns a new dictionary of validated and normalized options. If warn is
@ -295,7 +294,7 @@ def validate_options(opts, warn=False):
return get_validated_options(opts, warn)
def split_options(opts, validate=True, warn=False, normalize=True):
def split_options(opts: str, validate: bool = True, warn: bool = False, normalize: bool = True) -> MutableMapping[str, Any]:
"""Takes the options portion of a MongoDB URI, validates each option
and returns the options in a dictionary.
@ -340,7 +339,7 @@ def split_options(opts, validate=True, warn=False, normalize=True):
return options
def split_hosts(hosts, default_port=DEFAULT_PORT):
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> List[Tuple[str, Optional[int]]]:
"""Takes a string of the form host1[:port],host2[:port]... and
splits it into (host, port) tuples. If [:port] isn't present the
default_port is used.
@ -393,9 +392,16 @@ def _check_options(nodes, options):
'Cannot specify replicaSet with loadBalanced=true')
def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
normalize=True, connect_timeout=None, srv_service_name=None,
srv_max_hosts=None):
def parse_uri(
uri: str,
default_port: Optional[int] = DEFAULT_PORT,
validate: bool = True,
warn: bool = False,
normalize: bool = True,
connect_timeout: Optional[float] = None,
srv_service_name: Optional[str] = None,
srv_max_hosts: Optional[int] = None
) -> Dict[str, Any]:
"""Parse and validate a MongoDB URI.
Returns a dict of the form::

View File

@ -14,6 +14,8 @@
"""Tools for working with write concerns."""
from typing import Any, Dict, Optional, Union
from pymongo.errors import ConfigurationError
@ -45,8 +47,8 @@ class WriteConcern(object):
__slots__ = ("__document", "__acknowledged", "__server_default")
def __init__(self, w=None, wtimeout=None, j=None, fsync=None):
self.__document = {}
def __init__(self, w: Optional[Union[int, str]] = None, wtimeout: Optional[int] = None, j: Optional[bool] = None, fsync: Optional[bool] = None) -> None:
self.__document: Dict[str, Any] = {}
self.__acknowledged = True
if wtimeout is not None:
@ -84,12 +86,12 @@ class WriteConcern(object):
self.__server_default = not self.__document
@property
def is_server_default(self):
def is_server_default(self) -> bool:
"""Does this WriteConcern match the server default."""
return self.__server_default
@property
def document(self):
def document(self) -> Dict[str, Any]:
"""The document representation of this write concern.
.. note::
@ -99,7 +101,7 @@ class WriteConcern(object):
return self.__document.copy()
@property
def acknowledged(self):
def acknowledged(self) -> bool:
"""If ``True`` write operations will wait for acknowledgement before
returning.
"""
@ -109,12 +111,12 @@ class WriteConcern(object):
return ("WriteConcern(%s)" % (
", ".join("%s=%s" % kvt for kvt in self.__document.items()),))
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, WriteConcern):
return self.__document == other.document
return NotImplemented
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
if isinstance(other, WriteConcern):
return self.__document != other.document
return NotImplemented

View File

@ -25,7 +25,7 @@ import warnings
try:
import simplejson as json
except ImportError:
import json # type: ignore
import json # type: ignore[no-redef]
sys.path[0:0] = [""]

View File

@ -889,7 +889,7 @@ class TestCursor(IntegrationTest):
# Every attribute should be the same.
cursor2 = cursor.clone()
self.assertEqual(cursor.__dict__, cursor2.__dict__)
self.assertDictEqual(cursor.__dict__, cursor2.__dict__)
# Shallow copies can so can mutate
cursor2 = copy.copy(cursor)

View File

@ -238,7 +238,7 @@ class TestGridFile(IntegrationTest):
cursor_dict.pop('_Cursor__session')
cursor_clone_dict = cursor_clone.__dict__.copy()
cursor_clone_dict.pop('_Cursor__session')
self.assertEqual(cursor_dict, cursor_clone_dict)
self.assertDictEqual(cursor_dict, cursor_clone_dict)
self.assertRaises(NotImplementedError, cursor.add_option, 0)
self.assertRaises(NotImplementedError, cursor.remove_option, 0)

View File

@ -33,7 +33,7 @@ except:
pass
try:
from pymongo import _cmessage # type: ignore
from pymongo import _cmessage # type: ignore[attr-defined]
sys.exit("could still import _cmessage")
except ImportError:
pass