PYTHON-3060 Add typings to pymongo package (#831)
This commit is contained in:
parent
abfa0d35bc
commit
dd6c140d43
@ -61,10 +61,10 @@ import re
|
||||
import struct
|
||||
import sys
|
||||
import uuid
|
||||
from codecs import utf_8_decode as _utf_8_decode # type: ignore
|
||||
from codecs import utf_8_encode as _utf_8_encode # type: ignore
|
||||
from codecs import utf_8_decode as _utf_8_decode # type: ignore[attr-defined]
|
||||
from codecs import utf_8_encode as _utf_8_encode # type: ignore[attr-defined]
|
||||
from collections import abc as _abc
|
||||
from typing import (TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Generator,
|
||||
from typing import (IO, TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Generator,
|
||||
Iterator, List, Mapping, MutableMapping, NoReturn,
|
||||
Sequence, Tuple, Type, TypeVar, Union, cast)
|
||||
|
||||
@ -88,11 +88,13 @@ from bson.tz_util import utc
|
||||
|
||||
# Import RawBSONDocument for type-checking only to avoid circular dependency.
|
||||
if TYPE_CHECKING:
|
||||
from array import array
|
||||
from mmap import mmap
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
|
||||
|
||||
try:
|
||||
from bson import _cbson # type: ignore
|
||||
from bson import _cbson # type: ignore[attr-defined]
|
||||
_USE_C = True
|
||||
except ImportError:
|
||||
_USE_C = False
|
||||
@ -851,6 +853,7 @@ _CODEC_OPTIONS_TYPE_ERROR = TypeError(
|
||||
|
||||
_DocumentIn = Mapping[str, Any]
|
||||
_DocumentOut = Union[MutableMapping[str, Any], "RawBSONDocument"]
|
||||
_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"]
|
||||
|
||||
|
||||
def encode(document: _DocumentIn, check_keys: bool = False, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> bytes:
|
||||
@ -880,7 +883,7 @@ def encode(document: _DocumentIn, check_keys: bool = False, codec_options: Codec
|
||||
return _dict_to_bson(document, check_keys, codec_options)
|
||||
|
||||
|
||||
def decode(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> _DocumentOut:
|
||||
def decode(data: _ReadableBuffer, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Dict[str, Any]:
|
||||
"""Decode BSON to a document.
|
||||
|
||||
By default, returns a BSON document represented as a Python
|
||||
@ -912,7 +915,7 @@ def decode(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) ->
|
||||
return _bson_to_dict(data, codec_options)
|
||||
|
||||
|
||||
def decode_all(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> List[_DocumentOut]:
|
||||
def decode_all(data: _ReadableBuffer, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> List[Dict[str, Any]]:
|
||||
"""Decode BSON data to multiple documents.
|
||||
|
||||
`data` must be a bytes-like object implementing the buffer protocol that
|
||||
@ -1075,7 +1078,7 @@ def decode_iter(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS
|
||||
yield _bson_to_dict(elements, codec_options)
|
||||
|
||||
|
||||
def decode_file_iter(file_obj: BinaryIO, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Iterator[_DocumentOut]:
|
||||
def decode_file_iter(file_obj: Union[BinaryIO, IO], codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Iterator[_DocumentOut]:
|
||||
"""Decode bson data from a file to multiple documents as a generator.
|
||||
|
||||
Works similarly to the decode_all function, but reads from the file object
|
||||
@ -1158,7 +1161,7 @@ class BSON(bytes):
|
||||
"""
|
||||
return cls(encode(document, check_keys, codec_options))
|
||||
|
||||
def decode(self, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> _DocumentOut: # type: ignore[override]
|
||||
def decode(self, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Dict[str, Any]: # type: ignore[override]
|
||||
"""Decode this BSON data.
|
||||
|
||||
By default, returns a BSON document represented as a Python
|
||||
|
||||
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Tuple, Type
|
||||
from typing import Any, Tuple, Type, Union, TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
"""Tools for representing BSON binary data.
|
||||
@ -57,6 +57,11 @@ by :mod:`bson` using this subtype when using
|
||||
"""
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from array import array as _array
|
||||
from mmap import mmap as _mmap
|
||||
|
||||
|
||||
class UuidRepresentation:
|
||||
UNSPECIFIED = 0
|
||||
"""An unspecified UUID representation.
|
||||
@ -211,7 +216,7 @@ class Binary(bytes):
|
||||
_type_marker = 5
|
||||
__subtype: int
|
||||
|
||||
def __new__(cls: Type["Binary"], data: bytes, subtype: int = BINARY_SUBTYPE) -> "Binary":
|
||||
def __new__(cls: Type["Binary"], data: Union[memoryview, bytes, "_mmap", "_array"], subtype: int = BINARY_SUBTYPE) -> "Binary":
|
||||
if not isinstance(subtype, int):
|
||||
raise TypeError("subtype must be an instance of int")
|
||||
if subtype >= 256 or subtype < 0:
|
||||
|
||||
@ -874,10 +874,10 @@ class GridOutCursor(Cursor):
|
||||
|
||||
__next__ = next
|
||||
|
||||
def add_option(self, *args: Any, **kwargs: Any) -> None:
|
||||
def add_option(self, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
|
||||
raise NotImplementedError("Method does not exist for GridOutCursor")
|
||||
|
||||
def remove_option(self, *args: Any, **kwargs: Any) -> None:
|
||||
def remove_option(self, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
|
||||
raise NotImplementedError("Method does not exist for GridOutCursor")
|
||||
|
||||
def _clone_base(self, session: ClientSession) -> "GridOutCursor":
|
||||
|
||||
22
mypy.ini
22
mypy.ini
@ -1,11 +1,33 @@
|
||||
[mypy]
|
||||
check_untyped_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_incomplete_defs = true
|
||||
no_implicit_optional = true
|
||||
pretty = true
|
||||
show_error_context = true
|
||||
show_error_codes = true
|
||||
strict_equality = true
|
||||
warn_unused_configs = true
|
||||
warn_unused_ignores = true
|
||||
warn_redundant_casts = true
|
||||
|
||||
[mypy-kerberos.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-mockupdb]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-pymongo_auth_aws.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-pymongocrypt.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-service_identity.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-snappy.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-winkerberos.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
@ -14,6 +14,8 @@
|
||||
|
||||
"""Python driver for MongoDB."""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
ASCENDING = 1
|
||||
"""Ascending sort order."""
|
||||
DESCENDING = -1
|
||||
@ -53,35 +55,33 @@ TEXT = "text"
|
||||
.. _text index: http://docs.mongodb.org/manual/core/index-text/
|
||||
"""
|
||||
|
||||
version_tuple = (4, 1, 0, '.dev0')
|
||||
version_tuple: Tuple[Union[int, str], ...] = (4, 1, 0, '.dev0')
|
||||
|
||||
def get_version_string():
|
||||
def get_version_string() -> str:
|
||||
if isinstance(version_tuple[-1], str):
|
||||
return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1]
|
||||
return '.'.join(map(str, version_tuple))
|
||||
|
||||
__version__ = version = get_version_string()
|
||||
__version__: str = get_version_string()
|
||||
version = __version__
|
||||
|
||||
"""Current version of PyMongo."""
|
||||
|
||||
from pymongo.collection import ReturnDocument
|
||||
from pymongo.common import (MIN_SUPPORTED_WIRE_VERSION,
|
||||
MAX_SUPPORTED_WIRE_VERSION)
|
||||
from pymongo.common import (MAX_SUPPORTED_WIRE_VERSION,
|
||||
MIN_SUPPORTED_WIRE_VERSION)
|
||||
from pymongo.cursor import CursorType
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.operations import (IndexModel,
|
||||
InsertOne,
|
||||
DeleteOne,
|
||||
DeleteMany,
|
||||
UpdateOne,
|
||||
UpdateMany,
|
||||
ReplaceOne)
|
||||
from pymongo.operations import (DeleteMany, DeleteOne, IndexModel, InsertOne,
|
||||
ReplaceOne, UpdateMany, UpdateOne)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
def has_c():
|
||||
|
||||
def has_c() -> bool:
|
||||
"""Is the C extension installed?"""
|
||||
try:
|
||||
from pymongo import _cmessage
|
||||
from pymongo import _cmessage # type: ignore[attr-defined]
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
@ -15,11 +15,10 @@
|
||||
"""Perform aggregation operations on a collection or database."""
|
||||
|
||||
from bson.son import SON
|
||||
|
||||
from pymongo import common
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.read_preferences import _AggWritePref, ReadPreference
|
||||
from pymongo.read_preferences import ReadPreference, _AggWritePref
|
||||
|
||||
|
||||
class _AggregationCommand(object):
|
||||
@ -37,7 +36,7 @@ class _AggregationCommand(object):
|
||||
|
||||
self._target = target
|
||||
|
||||
common.validate_list('pipeline', pipeline)
|
||||
pipeline = common.validate_list('pipeline', pipeline)
|
||||
self._pipeline = pipeline
|
||||
self._performs_write = False
|
||||
if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
|
||||
@ -82,7 +81,6 @@ class _AggregationCommand(object):
|
||||
"""The namespace in which the aggregate command is run."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _cursor_collection(self, cursor_doc):
|
||||
"""The Collection used for the aggregate command cursor."""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -19,9 +19,9 @@ import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import socket
|
||||
|
||||
from base64 import standard_b64decode, standard_b64encode
|
||||
from collections import namedtuple
|
||||
from typing import Callable, Mapping
|
||||
from urllib.parse import quote
|
||||
|
||||
from bson.binary import Binary
|
||||
@ -97,7 +97,7 @@ GSSAPIProperties = namedtuple('GSSAPIProperties',
|
||||
"""Mechanism properties for GSSAPI authentication."""
|
||||
|
||||
|
||||
_AWSProperties = namedtuple('AWSProperties', ['aws_session_token'])
|
||||
_AWSProperties = namedtuple('_AWSProperties', ['aws_session_token'])
|
||||
"""Mechanism properties for MONGODB-AWS authentication."""
|
||||
|
||||
|
||||
@ -140,9 +140,9 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database):
|
||||
|
||||
properties = extra.get('authmechanismproperties', {})
|
||||
aws_session_token = properties.get('AWS_SESSION_TOKEN')
|
||||
props = _AWSProperties(aws_session_token=aws_session_token)
|
||||
aws_props = _AWSProperties(aws_session_token=aws_session_token)
|
||||
# user can be None for temporary link-local EC2 credentials.
|
||||
return MongoCredential(mech, '$external', user, passwd, props, None)
|
||||
return MongoCredential(mech, '$external', user, passwd, aws_props, None)
|
||||
elif mech == 'PLAIN':
|
||||
source_database = source or database or '$external'
|
||||
return MongoCredential(mech, source_database, user, passwd, None, None)
|
||||
@ -471,7 +471,7 @@ def _authenticate_default(credentials, sock_info):
|
||||
return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-1')
|
||||
|
||||
|
||||
_AUTH_MAP = {
|
||||
_AUTH_MAP: Mapping[str, Callable] = {
|
||||
'GSSAPI': _authenticate_gssapi,
|
||||
'MONGODB-CR': _authenticate_mongo_cr,
|
||||
'MONGODB-X509': _authenticate_x509,
|
||||
@ -532,7 +532,7 @@ class _X509Context(_AuthContext):
|
||||
return cmd
|
||||
|
||||
|
||||
_SPECULATIVE_AUTH_MAP = {
|
||||
_SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = {
|
||||
'MONGODB-X509': _X509Context,
|
||||
'SCRAM-SHA-1': functools.partial(_ScramContext, mechanism='SCRAM-SHA-1'),
|
||||
'SCRAM-SHA-256': functools.partial(_ScramContext,
|
||||
@ -544,6 +544,6 @@ _SPECULATIVE_AUTH_MAP = {
|
||||
def authenticate(credentials, sock_info):
|
||||
"""Authenticate sock_info."""
|
||||
mechanism = credentials.mechanism
|
||||
auth_func = _AUTH_MAP.get(mechanism)
|
||||
auth_func = _AUTH_MAP[mechanism]
|
||||
auth_func(credentials, sock_info)
|
||||
|
||||
|
||||
@ -16,12 +16,11 @@
|
||||
|
||||
try:
|
||||
import pymongo_auth_aws
|
||||
from pymongo_auth_aws import (AwsCredential,
|
||||
AwsSaslContext,
|
||||
from pymongo_auth_aws import (AwsCredential, AwsSaslContext,
|
||||
PyMongoAuthAwsError)
|
||||
_HAVE_MONGODB_AWS = True
|
||||
except ImportError:
|
||||
class AwsSaslContext(object):
|
||||
class AwsSaslContext(object): # type: ignore
|
||||
def __init__(self, credentials):
|
||||
pass
|
||||
_HAVE_MONGODB_AWS = False
|
||||
@ -32,7 +31,7 @@ from bson.son import SON
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
|
||||
|
||||
class _AwsSaslContext(AwsSaslContext):
|
||||
class _AwsSaslContext(AwsSaslContext): # type: ignore
|
||||
# Dependency injection:
|
||||
def binary_type(self):
|
||||
"""Return the bson.binary.Binary type."""
|
||||
|
||||
@ -17,31 +17,23 @@
|
||||
.. versionadded:: 2.7
|
||||
"""
|
||||
import copy
|
||||
|
||||
from itertools import islice
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.son import SON
|
||||
from pymongo.client_session import _validate_session_write_concern
|
||||
from pymongo.common import (validate_is_mapping,
|
||||
validate_is_document_type,
|
||||
validate_ok_for_replace,
|
||||
validate_ok_for_update)
|
||||
from pymongo.helpers import _RETRYABLE_ERROR_CODES, _get_wce_doc
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.errors import (BulkWriteError,
|
||||
ConfigurationError,
|
||||
InvalidOperation,
|
||||
OperationFailure)
|
||||
from pymongo.message import (_INSERT, _UPDATE, _DELETE,
|
||||
_randint,
|
||||
_BulkWriteContext,
|
||||
_EncryptedBulkWriteContext)
|
||||
from pymongo.common import (validate_is_document_type, validate_is_mapping,
|
||||
validate_ok_for_replace, validate_ok_for_update)
|
||||
from pymongo.errors import (BulkWriteError, ConfigurationError,
|
||||
InvalidOperation, OperationFailure)
|
||||
from pymongo.helpers import _RETRYABLE_ERROR_CODES, _get_wce_doc
|
||||
from pymongo.message import (_DELETE, _INSERT, _UPDATE, _BulkWriteContext,
|
||||
_EncryptedBulkWriteContext, _randint)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
|
||||
_DELETE_ALL = 0
|
||||
_DELETE_ONE = 1
|
||||
|
||||
|
||||
@ -15,21 +15,20 @@
|
||||
"""Watch changes on a collection, a database, or the entire cluster."""
|
||||
|
||||
import copy
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterator, Mapping,
|
||||
Optional, Union)
|
||||
|
||||
from bson import _bson_to_dict
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import common
|
||||
from pymongo.aggregation import (_CollectionAggregationCommand,
|
||||
_DatabaseAggregationCommand)
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
from pymongo.errors import (ConnectionFailure,
|
||||
CursorNotFound,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
PyMongoError)
|
||||
|
||||
from pymongo.errors import (ConnectionFailure, CursorNotFound,
|
||||
InvalidOperation, OperationFailure, PyMongoError)
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
|
||||
# The change streams spec considers the following server errors from the
|
||||
# getMore command non-resumable. All other getMore errors are resumable.
|
||||
@ -55,7 +54,14 @@ _RESUMABLE_GETMORE_ERRORS = frozenset([
|
||||
])
|
||||
|
||||
|
||||
class ChangeStream(object):
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
from pymongo.mongo_client import MongoClient
|
||||
|
||||
|
||||
class ChangeStream(Generic[_DocumentType]):
|
||||
"""The internal abstract base class for change stream cursors.
|
||||
|
||||
Should not be called directly by application developers. Use
|
||||
@ -66,14 +72,22 @@ class ChangeStream(object):
|
||||
.. versionadded:: 3.6
|
||||
.. seealso:: The MongoDB documentation on `changeStreams <https://dochub.mongodb.org/core/changeStreams>`_.
|
||||
"""
|
||||
def __init__(self, target, pipeline, full_document, resume_after,
|
||||
max_await_time_ms, batch_size, collation,
|
||||
start_at_operation_time, session, start_after):
|
||||
def __init__(
|
||||
self,
|
||||
target: Union["MongoClient[_DocumentType]", "Database[_DocumentType]", "Collection[_DocumentType]"],
|
||||
pipeline: Optional[_Pipeline],
|
||||
full_document: Optional[str],
|
||||
resume_after: Optional[Mapping[str, Any]],
|
||||
max_await_time_ms: Optional[int],
|
||||
batch_size: Optional[int],
|
||||
collation: Optional[_CollationIn],
|
||||
start_at_operation_time: Optional[Timestamp],
|
||||
session: Optional["ClientSession"],
|
||||
start_after: Optional[Mapping[str, Any]],
|
||||
) -> None:
|
||||
if pipeline is None:
|
||||
pipeline = []
|
||||
elif not isinstance(pipeline, list):
|
||||
raise TypeError("pipeline must be a list")
|
||||
|
||||
pipeline = common.validate_list('pipeline', pipeline)
|
||||
common.validate_string_or_none('full_document', full_document)
|
||||
validate_collation_or_none(collation)
|
||||
common.validate_non_negative_integer_or_none("batchSize", batch_size)
|
||||
@ -84,7 +98,7 @@ class ChangeStream(object):
|
||||
self._decode_custom = True
|
||||
# Keep the type registry so that we support encoding custom types
|
||||
# in the pipeline.
|
||||
self._target = target.with_options(
|
||||
self._target = target.with_options( # type: ignore
|
||||
codec_options=target.codec_options.with_options(
|
||||
document_class=RawBSONDocument))
|
||||
else:
|
||||
@ -117,7 +131,7 @@ class ChangeStream(object):
|
||||
|
||||
def _change_stream_options(self):
|
||||
"""Return the options dict for the $changeStream pipeline stage."""
|
||||
options = {}
|
||||
options: Dict[str, Any] = {}
|
||||
if self._full_document is not None:
|
||||
options['fullDocument'] = self._full_document
|
||||
|
||||
@ -144,7 +158,7 @@ class ChangeStream(object):
|
||||
def _aggregation_pipeline(self):
|
||||
"""Return the full aggregation pipeline for this ChangeStream."""
|
||||
options = self._change_stream_options()
|
||||
full_pipeline = [{'$changeStream': options}]
|
||||
full_pipeline: list = [{'$changeStream': options}]
|
||||
full_pipeline.extend(self._pipeline)
|
||||
return full_pipeline
|
||||
|
||||
@ -197,15 +211,15 @@ class ChangeStream(object):
|
||||
pass
|
||||
self._cursor = self._create_cursor()
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Close this ChangeStream."""
|
||||
self._cursor.close()
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> "ChangeStream[_DocumentType]":
|
||||
return self
|
||||
|
||||
@property
|
||||
def resume_token(self):
|
||||
def resume_token(self) -> Optional[Mapping[str, Any]]:
|
||||
"""The cached resume token that will be used to resume after the most
|
||||
recently returned change.
|
||||
|
||||
@ -213,7 +227,7 @@ class ChangeStream(object):
|
||||
"""
|
||||
return copy.deepcopy(self._resume_token)
|
||||
|
||||
def next(self):
|
||||
def next(self) -> _DocumentType:
|
||||
"""Advance the cursor.
|
||||
|
||||
This method blocks until the next change document is returned or an
|
||||
@ -255,7 +269,7 @@ class ChangeStream(object):
|
||||
__next__ = next
|
||||
|
||||
@property
|
||||
def alive(self):
|
||||
def alive(self) -> bool:
|
||||
"""Does this cursor have the potential to return more data?
|
||||
|
||||
.. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
|
||||
@ -265,7 +279,7 @@ class ChangeStream(object):
|
||||
"""
|
||||
return self._cursor.alive
|
||||
|
||||
def try_next(self):
|
||||
def try_next(self) -> Optional[_DocumentType]:
|
||||
"""Advance the cursor without blocking indefinitely.
|
||||
|
||||
This method returns the next change document without waiting
|
||||
@ -354,14 +368,14 @@ class ChangeStream(object):
|
||||
return _bson_to_dict(change.raw, self._orig_codec_options)
|
||||
return change
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "ChangeStream":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
class CollectionChangeStream(ChangeStream):
|
||||
class CollectionChangeStream(ChangeStream, Generic[_DocumentType]):
|
||||
"""A change stream that watches changes on a single collection.
|
||||
|
||||
Should not be called directly by application developers. Use
|
||||
@ -378,7 +392,7 @@ class CollectionChangeStream(ChangeStream):
|
||||
return self._target.database.client
|
||||
|
||||
|
||||
class DatabaseChangeStream(ChangeStream):
|
||||
class DatabaseChangeStream(ChangeStream, Generic[_DocumentType]):
|
||||
"""A change stream that watches changes on all collections in a database.
|
||||
|
||||
Should not be called directly by application developers. Use
|
||||
@ -395,7 +409,7 @@ class DatabaseChangeStream(ChangeStream):
|
||||
return self._target.client
|
||||
|
||||
|
||||
class ClusterChangeStream(DatabaseChangeStream):
|
||||
class ClusterChangeStream(DatabaseChangeStream, Generic[_DocumentType]):
|
||||
"""A change stream that watches changes on all collections in the cluster.
|
||||
|
||||
Should not be called directly by application developers. Use
|
||||
|
||||
@ -15,9 +15,9 @@
|
||||
"""Tools to parse mongo client options."""
|
||||
|
||||
from bson.codec_options import _parse_codec_options
|
||||
from pymongo import common
|
||||
from pymongo.auth import _build_credentials_tuple
|
||||
from pymongo.common import validate_boolean
|
||||
from pymongo import common
|
||||
from pymongo.compression_support import CompressionSettings
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.monitoring import _EventListeners
|
||||
|
||||
@ -134,25 +134,23 @@ Classes
|
||||
import collections
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from collections.abc import Mapping as _Mapping
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ContextManager, Generic,
|
||||
Mapping, Optional, TypeVar)
|
||||
|
||||
from bson.binary import Binary
|
||||
from bson.int64 import Int64
|
||||
from bson.son import SON
|
||||
from bson.timestamp import Timestamp
|
||||
|
||||
from pymongo.cursor import _SocketManager
|
||||
from pymongo.errors import (ConfigurationError,
|
||||
ConnectionFailure,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
from pymongo.errors import (ConfigurationError, ConnectionFailure,
|
||||
InvalidOperation, OperationFailure, PyMongoError,
|
||||
WTimeoutError)
|
||||
from pymongo.helpers import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.typings import _DocumentType
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
|
||||
@ -172,10 +170,12 @@ class SessionOptions(object):
|
||||
.. versionchanged:: 3.12
|
||||
Added the ``snapshot`` parameter.
|
||||
"""
|
||||
def __init__(self,
|
||||
causal_consistency=None,
|
||||
default_transaction_options=None,
|
||||
snapshot=False):
|
||||
def __init__(
|
||||
self,
|
||||
causal_consistency: Optional[bool] = None,
|
||||
default_transaction_options: Optional["TransactionOptions"] = None,
|
||||
snapshot: Optional[bool] = False,
|
||||
) -> None:
|
||||
if snapshot:
|
||||
if causal_consistency:
|
||||
raise ConfigurationError('snapshot reads do not support '
|
||||
@ -194,12 +194,12 @@ class SessionOptions(object):
|
||||
self._snapshot = snapshot
|
||||
|
||||
@property
|
||||
def causal_consistency(self):
|
||||
def causal_consistency(self) -> bool:
|
||||
"""Whether causal consistency is configured."""
|
||||
return self._causal_consistency
|
||||
|
||||
@property
|
||||
def default_transaction_options(self):
|
||||
def default_transaction_options(self) -> Optional["TransactionOptions"]:
|
||||
"""The default TransactionOptions to use for transactions started on
|
||||
this session.
|
||||
|
||||
@ -208,7 +208,7 @@ class SessionOptions(object):
|
||||
return self._default_transaction_options
|
||||
|
||||
@property
|
||||
def snapshot(self):
|
||||
def snapshot(self) -> Optional[bool]:
|
||||
"""Whether snapshot reads are configured.
|
||||
|
||||
.. versionadded:: 3.12
|
||||
@ -243,8 +243,13 @@ class TransactionOptions(object):
|
||||
|
||||
.. versionadded:: 3.7
|
||||
"""
|
||||
def __init__(self, read_concern=None, write_concern=None,
|
||||
read_preference=None, max_commit_time_ms=None):
|
||||
def __init__(
|
||||
self,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
max_commit_time_ms: Optional[int] = None
|
||||
) -> None:
|
||||
self._read_concern = read_concern
|
||||
self._write_concern = write_concern
|
||||
self._read_preference = read_preference
|
||||
@ -274,23 +279,23 @@ class TransactionOptions(object):
|
||||
"max_commit_time_ms must be an integer or None")
|
||||
|
||||
@property
|
||||
def read_concern(self):
|
||||
def read_concern(self) -> Optional[ReadConcern]:
|
||||
"""This transaction's :class:`~pymongo.read_concern.ReadConcern`."""
|
||||
return self._read_concern
|
||||
|
||||
@property
|
||||
def write_concern(self):
|
||||
def write_concern(self) -> Optional[WriteConcern]:
|
||||
"""This transaction's :class:`~pymongo.write_concern.WriteConcern`."""
|
||||
return self._write_concern
|
||||
|
||||
@property
|
||||
def read_preference(self):
|
||||
def read_preference(self) -> Optional[_ServerMode]:
|
||||
"""This transaction's :class:`~pymongo.read_preferences.ReadPreference`.
|
||||
"""
|
||||
return self._read_preference
|
||||
|
||||
@property
|
||||
def max_commit_time_ms(self):
|
||||
def max_commit_time_ms(self) -> Optional[int]:
|
||||
"""The maxTimeMS to use when running a commitTransaction command.
|
||||
|
||||
.. versionadded:: 3.9
|
||||
@ -427,7 +432,13 @@ def _within_time_limit(start_time):
|
||||
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
|
||||
|
||||
|
||||
class ClientSession(object):
|
||||
_T = TypeVar("_T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.mongo_client import MongoClient
|
||||
|
||||
|
||||
class ClientSession(Generic[_DocumentType]):
|
||||
"""A session for ordering sequential operations.
|
||||
|
||||
:class:`ClientSession` instances are **not thread-safe or fork-safe**.
|
||||
@ -439,9 +450,11 @@ class ClientSession(object):
|
||||
:class:`ClientSession`, call
|
||||
:meth:`~pymongo.mongo_client.MongoClient.start_session`.
|
||||
"""
|
||||
def __init__(self, client, server_session, options, implicit):
|
||||
def __init__(
|
||||
self, client: "MongoClient[_DocumentType]", server_session: Any, options: SessionOptions, implicit: bool
|
||||
) -> None:
|
||||
# A MongoClient, a _ServerSession, a SessionOptions, and a set.
|
||||
self._client = client
|
||||
self._client: MongoClient[_DocumentType] = client
|
||||
self._server_session = server_session
|
||||
self._options = options
|
||||
self._cluster_time = None
|
||||
@ -451,7 +464,7 @@ class ClientSession(object):
|
||||
self._implicit = implicit
|
||||
self._transaction = _Transaction(None, client)
|
||||
|
||||
def end_session(self):
|
||||
def end_session(self) -> None:
|
||||
"""Finish this session. If a transaction has started, abort it.
|
||||
|
||||
It is an error to use the session after the session has ended.
|
||||
@ -474,39 +487,39 @@ class ClientSession(object):
|
||||
if self._server_session is None:
|
||||
raise InvalidOperation("Cannot use ended session")
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "ClientSession[_DocumentType]":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self._end_session(lock=True)
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
def client(self) -> "MongoClient[_DocumentType]":
|
||||
"""The :class:`~pymongo.mongo_client.MongoClient` this session was
|
||||
created from.
|
||||
"""
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def options(self):
|
||||
def options(self) -> SessionOptions:
|
||||
"""The :class:`SessionOptions` this session was created with."""
|
||||
return self._options
|
||||
|
||||
@property
|
||||
def session_id(self):
|
||||
def session_id(self) -> Mapping[str, Any]:
|
||||
"""A BSON document, the opaque server session identifier."""
|
||||
self._check_ended()
|
||||
return self._server_session.session_id
|
||||
|
||||
@property
|
||||
def cluster_time(self):
|
||||
def cluster_time(self) -> Optional[Mapping[str, Any]]:
|
||||
"""The cluster time returned by the last operation executed
|
||||
in this session.
|
||||
"""
|
||||
return self._cluster_time
|
||||
|
||||
@property
|
||||
def operation_time(self):
|
||||
def operation_time(self) -> Optional[Timestamp]:
|
||||
"""The operation time returned by the last operation executed
|
||||
in this session.
|
||||
"""
|
||||
@ -522,8 +535,14 @@ class ClientSession(object):
|
||||
return val
|
||||
return getattr(self.client, name)
|
||||
|
||||
def with_transaction(self, callback, read_concern=None, write_concern=None,
|
||||
read_preference=None, max_commit_time_ms=None):
|
||||
def with_transaction(
|
||||
self,
|
||||
callback: Callable[["ClientSession"], _T],
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
max_commit_time_ms: Optional[int] = None,
|
||||
) -> _T:
|
||||
"""Execute a callback in a transaction.
|
||||
|
||||
This method starts a transaction on this session, executes ``callback``
|
||||
@ -649,8 +668,13 @@ class ClientSession(object):
|
||||
# Commit succeeded.
|
||||
return ret
|
||||
|
||||
def start_transaction(self, read_concern=None, write_concern=None,
|
||||
read_preference=None, max_commit_time_ms=None):
|
||||
def start_transaction(
|
||||
self,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
max_commit_time_ms: Optional[int] = None,
|
||||
) -> ContextManager:
|
||||
"""Start a multi-statement transaction.
|
||||
|
||||
Takes the same arguments as :class:`TransactionOptions`.
|
||||
@ -685,7 +709,7 @@ class ClientSession(object):
|
||||
self._start_retryable_write()
|
||||
return _TransactionContext(self)
|
||||
|
||||
def commit_transaction(self):
|
||||
def commit_transaction(self) -> None:
|
||||
"""Commit a multi-statement transaction.
|
||||
|
||||
.. versionadded:: 3.7
|
||||
@ -729,7 +753,7 @@ class ClientSession(object):
|
||||
finally:
|
||||
self._transaction.state = _TxnState.COMMITTED
|
||||
|
||||
def abort_transaction(self):
|
||||
def abort_transaction(self) -> None:
|
||||
"""Abort a multi-statement transaction.
|
||||
|
||||
.. versionadded:: 3.7
|
||||
@ -804,7 +828,7 @@ class ClientSession(object):
|
||||
if cluster_time["clusterTime"] > self._cluster_time["clusterTime"]:
|
||||
self._cluster_time = cluster_time
|
||||
|
||||
def advance_cluster_time(self, cluster_time):
|
||||
def advance_cluster_time(self, cluster_time: Mapping[str, Any]) -> None:
|
||||
"""Update the cluster time for this session.
|
||||
|
||||
:Parameters:
|
||||
@ -827,7 +851,7 @@ class ClientSession(object):
|
||||
if operation_time > self._operation_time:
|
||||
self._operation_time = operation_time
|
||||
|
||||
def advance_operation_time(self, operation_time):
|
||||
def advance_operation_time(self, operation_time: Timestamp) -> None:
|
||||
"""Update the operation time for this session.
|
||||
|
||||
:Parameters:
|
||||
@ -856,12 +880,12 @@ class ClientSession(object):
|
||||
self._transaction.recovery_token = recovery_token
|
||||
|
||||
@property
|
||||
def has_ended(self):
|
||||
def has_ended(self) -> bool:
|
||||
"""True if this session is finished."""
|
||||
return self._server_session is None
|
||||
|
||||
@property
|
||||
def in_transaction(self):
|
||||
def in_transaction(self) -> bool:
|
||||
"""True if this session has an active multi-statement transaction.
|
||||
|
||||
.. versionadded:: 3.10
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
.. _collations: http://userguide.icu-project.org/collation/concepts
|
||||
"""
|
||||
from typing import Any, Dict, Mapping, Optional, Union
|
||||
|
||||
from pymongo import common
|
||||
|
||||
@ -151,18 +152,18 @@ class Collation(object):
|
||||
|
||||
__slots__ = ("__document",)
|
||||
|
||||
def __init__(self, locale,
|
||||
caseLevel=None,
|
||||
caseFirst=None,
|
||||
strength=None,
|
||||
numericOrdering=None,
|
||||
alternate=None,
|
||||
maxVariable=None,
|
||||
normalization=None,
|
||||
backwards=None,
|
||||
**kwargs):
|
||||
def __init__(self, locale: str,
|
||||
caseLevel: Optional[bool] = None,
|
||||
caseFirst: Optional[str] = None,
|
||||
strength: Optional[int] = None,
|
||||
numericOrdering: Optional[bool] = None,
|
||||
alternate: Optional[str] = None,
|
||||
maxVariable: Optional[str] = None,
|
||||
normalization: Optional[bool] = None,
|
||||
backwards: Optional[bool] = None,
|
||||
**kwargs: Any) -> None:
|
||||
locale = common.validate_string('locale', locale)
|
||||
self.__document = {'locale': locale}
|
||||
self.__document: Dict[str, Any] = {'locale': locale}
|
||||
if caseLevel is not None:
|
||||
self.__document['caseLevel'] = common.validate_boolean(
|
||||
'caseLevel', caseLevel)
|
||||
@ -190,7 +191,7 @@ class Collation(object):
|
||||
self.__document.update(kwargs)
|
||||
|
||||
@property
|
||||
def document(self):
|
||||
def document(self) -> Dict[str, Any]:
|
||||
"""The document representation of this collation.
|
||||
|
||||
.. note::
|
||||
@ -204,16 +205,16 @@ class Collation(object):
|
||||
return 'Collation(%s)' % (
|
||||
', '.join('%s=%r' % (key, document[key]) for key in document),)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, Collation):
|
||||
return self.document == other.document
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
def validate_collation_or_none(value):
|
||||
def validate_collation_or_none(value: Optional[Union[Mapping[str, Any], Collation]]) -> Optional[Dict[str, Any]]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, Collation):
|
||||
|
||||
@ -14,44 +14,45 @@
|
||||
|
||||
"""Collection level utilities for Mongo."""
|
||||
|
||||
import datetime
|
||||
import warnings
|
||||
|
||||
from collections import abc
|
||||
from typing import (TYPE_CHECKING, Any, Generic, Iterable, List, Mapping,
|
||||
MutableMapping, Optional, Sequence, Tuple, Union)
|
||||
|
||||
from bson.code import Code
|
||||
from bson.codec_options import CodecOptions
|
||||
from bson.objectid import ObjectId
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.codec_options import CodecOptions
|
||||
from bson.son import SON
|
||||
from pymongo import (common,
|
||||
helpers,
|
||||
message)
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import common, helpers, message
|
||||
from pymongo.aggregation import (_CollectionAggregationCommand,
|
||||
_CollectionRawAggregationCommand)
|
||||
from pymongo.bulk import _Bulk
|
||||
from pymongo.command_cursor import CommandCursor, RawBatchCommandCursor
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.change_stream import CollectionChangeStream
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.command_cursor import CommandCursor, RawBatchCommandCursor
|
||||
from pymongo.cursor import Cursor, RawBatchCursor
|
||||
from pymongo.errors import (ConfigurationError,
|
||||
InvalidName,
|
||||
InvalidOperation,
|
||||
from pymongo.errors import (ConfigurationError, InvalidName, InvalidOperation,
|
||||
OperationFailure)
|
||||
from pymongo.helpers import _check_write_command_response
|
||||
from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS
|
||||
from pymongo.operations import IndexModel
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.results import (BulkWriteResult,
|
||||
DeleteResult,
|
||||
InsertOneResult,
|
||||
InsertManyResult,
|
||||
UpdateResult)
|
||||
from pymongo.operations import (DeleteMany, DeleteOne, IndexModel, InsertOne,
|
||||
ReplaceOne, UpdateMany, UpdateOne)
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.results import (BulkWriteResult, DeleteResult, InsertManyResult,
|
||||
InsertOneResult, UpdateResult)
|
||||
from pymongo.typings import _CollationIn, _DocumentIn, _DocumentType, _Pipeline
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_FIND_AND_MODIFY_DOC_FIELDS = {'value': 1}
|
||||
|
||||
|
||||
_WriteOp = Union[InsertOne, DeleteOne, DeleteMany, ReplaceOne, UpdateOne, UpdateMany]
|
||||
# Hint supports index name, "myIndex", or list of index pairs: [('x', 1), ('y', -1)]
|
||||
_IndexList = Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]]
|
||||
_IndexKeyHint = Union[str, _IndexList]
|
||||
|
||||
|
||||
class ReturnDocument(object):
|
||||
"""An enum used with
|
||||
:meth:`~pymongo.collection.Collection.find_one_and_replace` and
|
||||
@ -65,13 +66,28 @@ class ReturnDocument(object):
|
||||
"""Return the updated/replaced or inserted document."""
|
||||
|
||||
|
||||
class Collection(common.BaseObject):
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.database import Database
|
||||
from pymongo.read_concern import ReadConcern
|
||||
|
||||
|
||||
class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
"""A Mongo collection.
|
||||
"""
|
||||
|
||||
def __init__(self, database, name, create=False, codec_options=None,
|
||||
read_preference=None, write_concern=None, read_concern=None,
|
||||
session=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
database: "Database[_DocumentType]",
|
||||
name: str,
|
||||
create: Optional[bool] = False,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_concern: Optional["ReadConcern"] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Get / create a Mongo collection.
|
||||
|
||||
Raises :class:`TypeError` if `name` is not an instance of
|
||||
@ -169,7 +185,7 @@ class Collection(common.BaseObject):
|
||||
"null character")
|
||||
collation = validate_collation_or_none(kwargs.pop('collation', None))
|
||||
|
||||
self.__database = database
|
||||
self.__database: Database[_DocumentType] = database
|
||||
self.__name = name
|
||||
self.__full_name = "%s.%s" % (self.__database.name, self.__name)
|
||||
if create or kwargs or collation:
|
||||
@ -252,7 +268,7 @@ class Collection(common.BaseObject):
|
||||
write_concern=self._write_concern_for(session),
|
||||
collation=collation, session=session)
|
||||
|
||||
def __getattr__(self, name):
|
||||
def __getattr__(self, name: str) -> "Collection[_DocumentType]":
|
||||
"""Get a sub-collection of this collection by name.
|
||||
|
||||
Raises InvalidName if an invalid collection name is used.
|
||||
@ -268,7 +284,7 @@ class Collection(common.BaseObject):
|
||||
name, full_name, full_name))
|
||||
return self.__getitem__(name)
|
||||
|
||||
def __getitem__(self, name):
|
||||
def __getitem__(self, name: str) -> "Collection[_DocumentType]":
|
||||
return Collection(self.__database,
|
||||
"%s.%s" % (self.__name, name),
|
||||
False,
|
||||
@ -280,25 +296,25 @@ class Collection(common.BaseObject):
|
||||
def __repr__(self):
|
||||
return "Collection(%r, %r)" % (self.__database, self.__name)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, Collection):
|
||||
return (self.__database == other.database and
|
||||
self.__name == other.name)
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.__database, self.__name))
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
raise NotImplementedError("Collection objects do not implement truth "
|
||||
"value testing or bool(). Please compare "
|
||||
"with None instead: collection is not None")
|
||||
|
||||
@property
|
||||
def full_name(self):
|
||||
def full_name(self) -> str:
|
||||
"""The full name of this :class:`Collection`.
|
||||
|
||||
The full name is of the form `database_name.collection_name`.
|
||||
@ -306,19 +322,24 @@ class Collection(common.BaseObject):
|
||||
return self.__full_name
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
"""The name of this :class:`Collection`."""
|
||||
return self.__name
|
||||
|
||||
@property
|
||||
def database(self):
|
||||
def database(self) -> "Database[_DocumentType]":
|
||||
"""The :class:`~pymongo.database.Database` that this
|
||||
:class:`Collection` is a part of.
|
||||
"""
|
||||
return self.__database
|
||||
|
||||
def with_options(self, codec_options=None, read_preference=None,
|
||||
write_concern=None, read_concern=None):
|
||||
def with_options(
|
||||
self,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_concern: Optional["ReadConcern"] = None,
|
||||
) -> "Collection[_DocumentType]":
|
||||
"""Get a clone of this collection changing the specified settings.
|
||||
|
||||
>>> coll1.read_preference
|
||||
@ -356,8 +377,13 @@ class Collection(common.BaseObject):
|
||||
write_concern or self.write_concern,
|
||||
read_concern or self.read_concern)
|
||||
|
||||
def bulk_write(self, requests, ordered=True,
|
||||
bypass_document_validation=False, session=None):
|
||||
def bulk_write(
|
||||
self,
|
||||
requests: Sequence[_WriteOp],
|
||||
ordered: bool = True,
|
||||
bypass_document_validation: bool = False,
|
||||
session: Optional["ClientSession"] = None
|
||||
) -> BulkWriteResult:
|
||||
"""Send a batch of write operations to the server.
|
||||
|
||||
Requests are passed as a list of write operation instances (
|
||||
@ -470,8 +496,10 @@ class Collection(common.BaseObject):
|
||||
if not isinstance(doc, RawBSONDocument):
|
||||
return doc.get('_id')
|
||||
|
||||
def insert_one(self, document, bypass_document_validation=False,
|
||||
session=None):
|
||||
def insert_one(self, document: _DocumentIn,
|
||||
bypass_document_validation: bool = False,
|
||||
session: Optional["ClientSession"] = None
|
||||
) -> InsertOneResult:
|
||||
"""Insert a single document.
|
||||
|
||||
>>> db.test.count_documents({'x': 1})
|
||||
@ -520,8 +548,12 @@ class Collection(common.BaseObject):
|
||||
bypass_doc_val=bypass_document_validation, session=session),
|
||||
write_concern.acknowledged)
|
||||
|
||||
def insert_many(self, documents, ordered=True,
|
||||
bypass_document_validation=False, session=None):
|
||||
def insert_many(self,
|
||||
documents: Iterable[_DocumentIn],
|
||||
ordered: bool = True,
|
||||
bypass_document_validation: bool = False,
|
||||
session: Optional["ClientSession"] = None
|
||||
) -> InsertManyResult:
|
||||
"""Insert an iterable of documents.
|
||||
|
||||
>>> db.test.count_documents({})
|
||||
@ -565,7 +597,7 @@ class Collection(common.BaseObject):
|
||||
or isinstance(documents, abc.Mapping)
|
||||
or not documents):
|
||||
raise TypeError("documents must be a non-empty list")
|
||||
inserted_ids = []
|
||||
inserted_ids: List[ObjectId] = []
|
||||
def gen():
|
||||
"""A generator that validates documents and handles _ids."""
|
||||
for document in documents:
|
||||
@ -671,9 +703,16 @@ class Collection(common.BaseObject):
|
||||
(write_concern or self.write_concern).acknowledged and not multi,
|
||||
_update, session)
|
||||
|
||||
def replace_one(self, filter, replacement, upsert=False,
|
||||
bypass_document_validation=False, collation=None,
|
||||
hint=None, session=None, let=None):
|
||||
def replace_one(self,
|
||||
filter: Mapping[str, Any],
|
||||
replacement: Mapping[str, Any],
|
||||
upsert: bool = False,
|
||||
bypass_document_validation: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
let: Optional[Mapping[str, Any]] = None
|
||||
) -> UpdateResult:
|
||||
"""Replace a single document matching the filter.
|
||||
|
||||
>>> for doc in db.test.find({}):
|
||||
@ -755,10 +794,17 @@ class Collection(common.BaseObject):
|
||||
collation=collation, hint=hint, session=session, let=let),
|
||||
write_concern.acknowledged)
|
||||
|
||||
def update_one(self, filter, update, upsert=False,
|
||||
bypass_document_validation=False,
|
||||
collation=None, array_filters=None, hint=None,
|
||||
session=None, let=None):
|
||||
def update_one(self,
|
||||
filter: Mapping[str, Any],
|
||||
update: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool = False,
|
||||
bypass_document_validation: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
let: Optional[Mapping[str, Any]] = None
|
||||
) -> UpdateResult:
|
||||
"""Update a single document matching the filter.
|
||||
|
||||
>>> for doc in db.test.find():
|
||||
@ -800,8 +846,8 @@ class Collection(common.BaseObject):
|
||||
- `session` (optional): a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
- `let` (optional): Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
aggregate expression context (e.g. "$$var").
|
||||
|
||||
:Returns:
|
||||
@ -836,9 +882,17 @@ class Collection(common.BaseObject):
|
||||
hint=hint, session=session, let=let),
|
||||
write_concern.acknowledged)
|
||||
|
||||
def update_many(self, filter, update, upsert=False, array_filters=None,
|
||||
bypass_document_validation=False, collation=None,
|
||||
hint=None, session=None, let=None):
|
||||
def update_many(self,
|
||||
filter: Mapping[str, Any],
|
||||
update: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool = False,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
let: Optional[Mapping[str, Any]] = None
|
||||
) -> UpdateResult:
|
||||
"""Update one or more documents that match the filter.
|
||||
|
||||
>>> for doc in db.test.find():
|
||||
@ -880,8 +934,8 @@ class Collection(common.BaseObject):
|
||||
- `session` (optional): a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
- `let` (optional): Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
aggregate expression context (e.g. "$$var").
|
||||
|
||||
:Returns:
|
||||
@ -916,7 +970,7 @@ class Collection(common.BaseObject):
|
||||
hint=hint, session=session, let=let),
|
||||
write_concern.acknowledged)
|
||||
|
||||
def drop(self, session=None):
|
||||
def drop(self, session: Optional["ClientSession"] = None) -> None:
|
||||
"""Alias for :meth:`~pymongo.database.Database.drop_collection`.
|
||||
|
||||
:Parameters:
|
||||
@ -1005,8 +1059,13 @@ class Collection(common.BaseObject):
|
||||
(write_concern or self.write_concern).acknowledged and not multi,
|
||||
_delete, session)
|
||||
|
||||
def delete_one(self, filter, collation=None, hint=None, session=None,
|
||||
let=None):
|
||||
def delete_one(self,
|
||||
filter: Mapping[str, Any],
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
let: Optional[Mapping[str, Any]] = None
|
||||
) -> DeleteResult:
|
||||
"""Delete a single document matching the filter.
|
||||
|
||||
>>> db.test.count_documents({'x': 1})
|
||||
@ -1030,8 +1089,8 @@ class Collection(common.BaseObject):
|
||||
- `session` (optional): a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
- `let` (optional): Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
aggregate expression context (e.g. "$$var").
|
||||
|
||||
:Returns:
|
||||
@ -1055,8 +1114,13 @@ class Collection(common.BaseObject):
|
||||
collation=collation, hint=hint, session=session, let=let),
|
||||
write_concern.acknowledged)
|
||||
|
||||
def delete_many(self, filter, collation=None, hint=None, session=None,
|
||||
let=None):
|
||||
def delete_many(self,
|
||||
filter: Mapping[str, Any],
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
let: Optional[Mapping[str, Any]] = None
|
||||
) -> DeleteResult:
|
||||
"""Delete one or more documents matching the filter.
|
||||
|
||||
>>> db.test.count_documents({'x': 1})
|
||||
@ -1080,8 +1144,8 @@ class Collection(common.BaseObject):
|
||||
- `session` (optional): a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
- `let` (optional): Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
aggregate expression context (e.g. "$$var").
|
||||
|
||||
:Returns:
|
||||
@ -1105,7 +1169,7 @@ class Collection(common.BaseObject):
|
||||
collation=collation, hint=hint, session=session, let=let),
|
||||
write_concern.acknowledged)
|
||||
|
||||
def find_one(self, filter=None, *args, **kwargs):
|
||||
def find_one(self, filter: Optional[Any] = None, *args: Any, **kwargs: Any) -> Optional[_DocumentType]:
|
||||
"""Get a single document from the database.
|
||||
|
||||
All arguments to :meth:`find` are also valid arguments for
|
||||
@ -1139,7 +1203,7 @@ class Collection(common.BaseObject):
|
||||
return result
|
||||
return None
|
||||
|
||||
def find(self, *args, **kwargs):
|
||||
def find(self, *args: Any, **kwargs: Any) -> Cursor[_DocumentType]:
|
||||
"""Query the database.
|
||||
|
||||
The `filter` argument is a prototype document that all results
|
||||
@ -1328,7 +1392,7 @@ class Collection(common.BaseObject):
|
||||
"""
|
||||
return Cursor(self, *args, **kwargs)
|
||||
|
||||
def find_raw_batches(self, *args, **kwargs):
|
||||
def find_raw_batches(self, *args: Any, **kwargs: Any) -> RawBatchCursor[_DocumentType]:
|
||||
"""Query the database and retrieve batches of raw BSON.
|
||||
|
||||
Similar to the :meth:`find` method but returns a
|
||||
@ -1396,7 +1460,7 @@ class Collection(common.BaseObject):
|
||||
batch = result['cursor']['firstBatch']
|
||||
return batch[0] if batch else None
|
||||
|
||||
def estimated_document_count(self, **kwargs):
|
||||
def estimated_document_count(self, **kwargs: Any) -> int:
|
||||
"""Get an estimate of the number of documents in this collection using
|
||||
collection metadata.
|
||||
|
||||
@ -1445,7 +1509,7 @@ class Collection(common.BaseObject):
|
||||
return self.__database.client._retryable_read(
|
||||
_cmd, self.read_preference, None)
|
||||
|
||||
def count_documents(self, filter, session=None, **kwargs):
|
||||
def count_documents(self, filter: Mapping[str, Any], session: Optional["ClientSession"] = None, **kwargs: Any) -> int:
|
||||
"""Count the number of documents in this collection.
|
||||
|
||||
.. note:: For a fast count of the total documents in a collection see
|
||||
@ -1523,7 +1587,7 @@ class Collection(common.BaseObject):
|
||||
return self.__database.client._retryable_read(
|
||||
_cmd, self._read_preference_for(session), session)
|
||||
|
||||
def create_indexes(self, indexes, session=None, **kwargs):
|
||||
def create_indexes(self, indexes: Sequence[IndexModel], session: Optional["ClientSession"] = None, **kwargs: Any) -> List[str]:
|
||||
"""Create one or more indexes on this collection.
|
||||
|
||||
>>> from pymongo import IndexModel, ASCENDING, DESCENDING
|
||||
@ -1598,7 +1662,7 @@ class Collection(common.BaseObject):
|
||||
session=session)
|
||||
return names
|
||||
|
||||
def create_index(self, keys, session=None, **kwargs):
|
||||
def create_index(self, keys: _IndexKeyHint, session: Optional["ClientSession"] = None, **kwargs: Any) -> str:
|
||||
"""Creates an index on this collection.
|
||||
|
||||
Takes either a single key or a list of (key, direction) pairs.
|
||||
@ -1701,7 +1765,7 @@ class Collection(common.BaseObject):
|
||||
index = IndexModel(keys, **kwargs)
|
||||
return self.__create_indexes([index], session, **cmd_options)[0]
|
||||
|
||||
def drop_indexes(self, session=None, **kwargs):
|
||||
def drop_indexes(self, session: Optional["ClientSession"] = None, **kwargs: Any) -> None:
|
||||
"""Drops all indexes on this collection.
|
||||
|
||||
Can be used on non-existant collections or collections with no indexes.
|
||||
@ -1727,7 +1791,7 @@ class Collection(common.BaseObject):
|
||||
"""
|
||||
self.drop_index("*", session=session, **kwargs)
|
||||
|
||||
def drop_index(self, index_or_name, session=None, **kwargs):
|
||||
def drop_index(self, index_or_name: _IndexKeyHint, session: Optional["ClientSession"] = None, **kwargs: Any) -> None:
|
||||
"""Drops the specified index on this collection.
|
||||
|
||||
Can be used on non-existant collections or collections with no
|
||||
@ -1780,7 +1844,7 @@ class Collection(common.BaseObject):
|
||||
write_concern=self._write_concern_for(session),
|
||||
session=session)
|
||||
|
||||
def list_indexes(self, session=None):
|
||||
def list_indexes(self, session: Optional["ClientSession"] = None) -> CommandCursor[MutableMapping[str, Any]]:
|
||||
"""Get a cursor over the index documents for this collection.
|
||||
|
||||
>>> for index in db.test.list_indexes():
|
||||
@ -1829,7 +1893,7 @@ class Collection(common.BaseObject):
|
||||
return self.__database.client._retryable_read(
|
||||
_cmd, read_pref, session)
|
||||
|
||||
def index_information(self, session=None):
|
||||
def index_information(self, session: Optional["ClientSession"] = None) -> MutableMapping[str, Any]:
|
||||
"""Get information on this collection's indexes.
|
||||
|
||||
Returns a dictionary where the keys are index names (as
|
||||
@ -1863,7 +1927,7 @@ class Collection(common.BaseObject):
|
||||
info[index.pop("name")] = index
|
||||
return info
|
||||
|
||||
def options(self, session=None):
|
||||
def options(self, session: Optional["ClientSession"] = None) -> MutableMapping[str, Any]:
|
||||
"""Get the options set on this collection.
|
||||
|
||||
Returns a dictionary of options and their values - see
|
||||
@ -1896,6 +1960,7 @@ class Collection(common.BaseObject):
|
||||
return {}
|
||||
|
||||
options = result.get("options", {})
|
||||
assert options is not None
|
||||
if "create" in options:
|
||||
del options["create"]
|
||||
|
||||
@ -1911,7 +1976,7 @@ class Collection(common.BaseObject):
|
||||
cmd.get_cursor, cmd.get_read_preference(session), session,
|
||||
retryable=not cmd._performs_write)
|
||||
|
||||
def aggregate(self, pipeline, session=None, let=None, **kwargs):
|
||||
def aggregate(self, pipeline: _Pipeline, session: Optional["ClientSession"] = None, let: Optional[Mapping[str, Any]] = None, **kwargs: Any) -> CommandCursor[_DocumentType]:
|
||||
"""Perform an aggregation using the aggregation framework on this
|
||||
collection.
|
||||
|
||||
@ -1993,7 +2058,9 @@ class Collection(common.BaseObject):
|
||||
let=let,
|
||||
**kwargs)
|
||||
|
||||
def aggregate_raw_batches(self, pipeline, session=None, **kwargs):
|
||||
def aggregate_raw_batches(
|
||||
self, pipeline: _Pipeline, session: Optional["ClientSession"] = None, **kwargs: Any
|
||||
) -> RawBatchCursor[_DocumentType]:
|
||||
"""Perform an aggregation and retrieve batches of raw BSON.
|
||||
|
||||
Similar to the :meth:`aggregate` method but returns a
|
||||
@ -2030,9 +2097,17 @@ class Collection(common.BaseObject):
|
||||
explicit_session=session is not None,
|
||||
**kwargs)
|
||||
|
||||
def watch(self, pipeline=None, full_document=None, resume_after=None,
|
||||
max_await_time_ms=None, batch_size=None, collation=None,
|
||||
start_at_operation_time=None, session=None, start_after=None):
|
||||
def watch(self,
|
||||
pipeline: Optional[_Pipeline] = None,
|
||||
full_document: Optional[str] = None,
|
||||
resume_after: Optional[Mapping[str, Any]] = None,
|
||||
max_await_time_ms: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
start_at_operation_time: Optional[Timestamp] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
start_after: Optional[Mapping[str, Any]] = None,
|
||||
) -> CollectionChangeStream[_DocumentType]:
|
||||
"""Watch changes on this collection.
|
||||
|
||||
Performs an aggregation with an implicit initial ``$changeStream``
|
||||
@ -2132,7 +2207,7 @@ class Collection(common.BaseObject):
|
||||
batch_size, collation, start_at_operation_time, session,
|
||||
start_after)
|
||||
|
||||
def rename(self, new_name, session=None, **kwargs):
|
||||
def rename(self, new_name: str, session: Optional["ClientSession"] = None, **kwargs: Any) -> MutableMapping[str, Any]:
|
||||
"""Rename this collection.
|
||||
|
||||
If operating in auth mode, client must be authorized as an
|
||||
@ -2183,7 +2258,9 @@ class Collection(common.BaseObject):
|
||||
parse_write_concern_error=True,
|
||||
session=s, client=self.__database.client)
|
||||
|
||||
def distinct(self, key, filter=None, session=None, **kwargs):
|
||||
def distinct(
|
||||
self, key: str, filter: Optional[Mapping[str, Any]] = None, session: Optional["ClientSession"] = None, **kwargs: Any
|
||||
) -> List:
|
||||
"""Get a list of distinct values for `key` among all documents
|
||||
in this collection.
|
||||
|
||||
@ -2283,7 +2360,7 @@ class Collection(common.BaseObject):
|
||||
raise ConfigurationError(
|
||||
'arrayFilters is unsupported for unacknowledged '
|
||||
'writes.')
|
||||
cmd["arrayFilters"] = array_filters
|
||||
cmd["arrayFilters"] = list(array_filters)
|
||||
if hint is not None:
|
||||
if sock_info.max_wire_version < 8:
|
||||
raise ConfigurationError(
|
||||
@ -2307,9 +2384,15 @@ class Collection(common.BaseObject):
|
||||
return self.__database.client._retryable_write(
|
||||
write_concern.acknowledged, _find_and_modify, session)
|
||||
|
||||
def find_one_and_delete(self, filter,
|
||||
projection=None, sort=None, hint=None,
|
||||
session=None, let=None, **kwargs):
|
||||
def find_one_and_delete(self,
|
||||
filter: Mapping[str, Any],
|
||||
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
|
||||
sort: Optional[_IndexList] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> _DocumentType:
|
||||
"""Finds a single document and deletes it, returning the document.
|
||||
|
||||
>>> db.test.count_documents({'x': 1})
|
||||
@ -2357,8 +2440,8 @@ class Collection(common.BaseObject):
|
||||
as keyword arguments (for example maxTimeMS can be used with
|
||||
recent server versions).
|
||||
- `let` (optional): Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
aggregate expression context (e.g. "$$var").
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
@ -2384,10 +2467,18 @@ class Collection(common.BaseObject):
|
||||
return self.__find_and_modify(filter, projection, sort, let=let,
|
||||
hint=hint, session=session, **kwargs)
|
||||
|
||||
def find_one_and_replace(self, filter, replacement,
|
||||
projection=None, sort=None, upsert=False,
|
||||
return_document=ReturnDocument.BEFORE,
|
||||
hint=None, session=None, let=None, **kwargs):
|
||||
def find_one_and_replace(self,
|
||||
filter: Mapping[str, Any],
|
||||
replacement: Mapping[str, Any],
|
||||
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
|
||||
sort: Optional[_IndexList] = None,
|
||||
upsert: bool = False,
|
||||
return_document: bool = ReturnDocument.BEFORE,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> _DocumentType:
|
||||
"""Finds a single document and replaces it, returning either the
|
||||
original or the replaced document.
|
||||
|
||||
@ -2438,8 +2529,8 @@ class Collection(common.BaseObject):
|
||||
- `session` (optional): a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
- `let` (optional): Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
aggregate expression context (e.g. "$$var").
|
||||
- `**kwargs` (optional): additional command arguments can be passed
|
||||
as keyword arguments (for example maxTimeMS can be used with
|
||||
@ -2470,11 +2561,19 @@ class Collection(common.BaseObject):
|
||||
sort, upsert, return_document, let=let,
|
||||
hint=hint, session=session, **kwargs)
|
||||
|
||||
def find_one_and_update(self, filter, update,
|
||||
projection=None, sort=None, upsert=False,
|
||||
return_document=ReturnDocument.BEFORE,
|
||||
array_filters=None, hint=None, session=None,
|
||||
let=None, **kwargs):
|
||||
def find_one_and_update(self,
|
||||
filter: Mapping[str, Any],
|
||||
update: Union[Mapping[str, Any], _Pipeline],
|
||||
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
|
||||
sort: Optional[_IndexList] = None,
|
||||
upsert: bool = False,
|
||||
return_document: bool = ReturnDocument.BEFORE,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> _DocumentType:
|
||||
"""Finds a single document and updates it, returning either the
|
||||
original or the updated document.
|
||||
|
||||
@ -2564,8 +2663,8 @@ class Collection(common.BaseObject):
|
||||
- `session` (optional): a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
- `let` (optional): Map of parameter names and values. Values must be
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
constant or closed expressions that do not reference document
|
||||
fields. Parameters can then be accessed as variables in an
|
||||
aggregate expression context (e.g. "$$var").
|
||||
- `**kwargs` (optional): additional command arguments can be passed
|
||||
as keyword arguments (for example maxTimeMS can be used with
|
||||
@ -2600,15 +2699,15 @@ class Collection(common.BaseObject):
|
||||
array_filters, hint=hint, let=let,
|
||||
session=session, **kwargs)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> "Collection[_DocumentType]":
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
def __next__(self) -> None:
|
||||
raise TypeError("'Collection' object is not iterable")
|
||||
|
||||
next = __next__
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""This is only here so that some API misusages are easier to debug.
|
||||
"""
|
||||
if "." not in self.__name:
|
||||
|
||||
@ -15,28 +15,38 @@
|
||||
"""CommandCursor class to iterate over command results."""
|
||||
|
||||
from collections import deque
|
||||
from typing import (TYPE_CHECKING, Any, Generic, Iterator, Mapping, Optional,
|
||||
Tuple)
|
||||
|
||||
from bson import _convert_raw_document_lists_to_streams
|
||||
from pymongo.cursor import _SocketManager, _CURSOR_CLOSED_ERRORS
|
||||
from pymongo.errors import (ConnectionFailure,
|
||||
InvalidOperation,
|
||||
from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _SocketManager
|
||||
from pymongo.errors import (ConnectionFailure, InvalidOperation,
|
||||
OperationFailure)
|
||||
from pymongo.message import (_CursorAddress,
|
||||
_GetMore,
|
||||
_RawBatchGetMore)
|
||||
from pymongo.message import _CursorAddress, _GetMore, _RawBatchGetMore
|
||||
from pymongo.response import PinnedResponse
|
||||
from pymongo.typings import _DocumentType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collection import Collection
|
||||
|
||||
|
||||
class CommandCursor(object):
|
||||
class CommandCursor(Generic[_DocumentType]):
|
||||
"""A cursor / iterator over command cursors."""
|
||||
_getmore_class = _GetMore
|
||||
|
||||
def __init__(self, collection, cursor_info, address,
|
||||
batch_size=0, max_await_time_ms=None, session=None,
|
||||
explicit_session=False):
|
||||
def __init__(self,
|
||||
collection: "Collection[_DocumentType]",
|
||||
cursor_info: Mapping[str, Any],
|
||||
address: Optional[Tuple[str, Optional[int]]],
|
||||
batch_size: int = 0,
|
||||
max_await_time_ms: Optional[int] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
explicit_session: bool = False,
|
||||
) -> None:
|
||||
"""Create a new command cursor."""
|
||||
self.__sock_mgr = None
|
||||
self.__collection = collection
|
||||
self.__sock_mgr: Any = None
|
||||
self.__collection: Collection[_DocumentType] = collection
|
||||
self.__id = cursor_info['id']
|
||||
self.__data = deque(cursor_info['firstBatch'])
|
||||
self.__postbatchresumetoken = cursor_info.get('postBatchResumeToken')
|
||||
@ -60,7 +70,7 @@ class CommandCursor(object):
|
||||
and max_await_time_ms is not None):
|
||||
raise TypeError("max_await_time_ms must be an integer or None")
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
self.__die()
|
||||
|
||||
def __die(self, synchronous=False):
|
||||
@ -92,12 +102,12 @@ class CommandCursor(object):
|
||||
self.__session._end_session(lock=synchronous)
|
||||
self.__session = None
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Explicitly close / kill this cursor.
|
||||
"""
|
||||
self.__die(True)
|
||||
|
||||
def batch_size(self, batch_size):
|
||||
def batch_size(self, batch_size: int) -> "CommandCursor[_DocumentType]":
|
||||
"""Limits the number of documents returned in one batch. Each batch
|
||||
requires a round trip to the server. It can be adjusted to optimize
|
||||
performance and limit data transfer.
|
||||
@ -222,7 +232,7 @@ class CommandCursor(object):
|
||||
return len(self.__data)
|
||||
|
||||
@property
|
||||
def alive(self):
|
||||
def alive(self) -> bool:
|
||||
"""Does this cursor have the potential to return more data?
|
||||
|
||||
Even if :attr:`alive` is ``True``, :meth:`next` can raise
|
||||
@ -239,12 +249,12 @@ class CommandCursor(object):
|
||||
return bool(len(self.__data) or (not self.__killed))
|
||||
|
||||
@property
|
||||
def cursor_id(self):
|
||||
def cursor_id(self) -> int:
|
||||
"""Returns the id of the cursor."""
|
||||
return self.__id
|
||||
|
||||
@property
|
||||
def address(self):
|
||||
def address(self) -> Optional[Tuple[str, Optional[int]]]:
|
||||
"""The (host, port) of the server used, or None.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
@ -252,18 +262,19 @@ class CommandCursor(object):
|
||||
return self.__address
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
def session(self) -> Optional["ClientSession"]:
|
||||
"""The cursor's :class:`~pymongo.client_session.ClientSession`, or None.
|
||||
|
||||
.. versionadded:: 3.6
|
||||
"""
|
||||
if self.__explicit_session:
|
||||
return self.__session
|
||||
return None
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[_DocumentType]:
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
def next(self) -> _DocumentType:
|
||||
"""Advance the cursor."""
|
||||
# Block until a document is returnable.
|
||||
while self.alive:
|
||||
@ -284,19 +295,25 @@ class CommandCursor(object):
|
||||
else:
|
||||
return None
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "CommandCursor[_DocumentType]":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
class RawBatchCommandCursor(CommandCursor):
|
||||
class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]):
|
||||
_getmore_class = _RawBatchGetMore
|
||||
|
||||
def __init__(self, collection, cursor_info, address,
|
||||
batch_size=0, max_await_time_ms=None, session=None,
|
||||
explicit_session=False):
|
||||
def __init__(self,
|
||||
collection: "Collection[_DocumentType]",
|
||||
cursor_info: Mapping[str, Any],
|
||||
address: Optional[Tuple[str, Optional[int]]],
|
||||
batch_size: int = 0,
|
||||
max_await_time_ms: Optional[int] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
explicit_session: bool = False,
|
||||
) -> None:
|
||||
"""Create a new cursor / iterator over raw batches of BSON data.
|
||||
|
||||
Should not be called directly by application developers -
|
||||
|
||||
@ -17,8 +17,9 @@
|
||||
|
||||
import datetime
|
||||
import warnings
|
||||
|
||||
from collections import abc, OrderedDict
|
||||
from collections import OrderedDict, abc
|
||||
from typing import (Any, Callable, Dict, List, Mapping, MutableMapping,
|
||||
Optional, Sequence, Tuple, Type, Union, cast)
|
||||
from urllib.parse import unquote_plus
|
||||
|
||||
from bson import SON
|
||||
@ -29,18 +30,18 @@ from pymongo.auth import MECHANISMS
|
||||
from pymongo.compression_support import (validate_compressors,
|
||||
validate_zlib_compression_level)
|
||||
from pymongo.driver_info import DriverInfo
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.monitoring import _validate_event_listeners
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _MONGOS_MODES, _ServerMode
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
|
||||
|
||||
ORDERED_TYPES = (SON, OrderedDict)
|
||||
ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict)
|
||||
|
||||
# Defaults until we connect to a server and get updated limits.
|
||||
MAX_BSON_SIZE = 16 * (1024 ** 2)
|
||||
MAX_MESSAGE_SIZE = 2 * MAX_BSON_SIZE
|
||||
MAX_MESSAGE_SIZE: int = 2 * MAX_BSON_SIZE
|
||||
MIN_WIRE_VERSION = 0
|
||||
MAX_WIRE_VERSION = 0
|
||||
MAX_WRITE_BATCH_SIZE = 1000
|
||||
@ -85,13 +86,13 @@ MIN_POOL_SIZE = 0
|
||||
MAX_CONNECTING = 2
|
||||
|
||||
# Default value for maxIdleTimeMS.
|
||||
MAX_IDLE_TIME_MS = None
|
||||
MAX_IDLE_TIME_MS: Optional[int] = None
|
||||
|
||||
# Default value for maxIdleTimeMS in seconds.
|
||||
MAX_IDLE_TIME_SEC = None
|
||||
MAX_IDLE_TIME_SEC: Optional[int] = None
|
||||
|
||||
# Default value for waitQueueTimeoutMS in seconds.
|
||||
WAIT_QUEUE_TIMEOUT = None
|
||||
WAIT_QUEUE_TIMEOUT: Optional[int] = None
|
||||
|
||||
# Default value for localThresholdMS.
|
||||
LOCAL_THRESHOLD_MS = 15
|
||||
@ -103,10 +104,10 @@ RETRY_WRITES = True
|
||||
RETRY_READS = True
|
||||
|
||||
# The error code returned when a command doesn't exist.
|
||||
COMMAND_NOT_FOUND_CODES = (59,)
|
||||
COMMAND_NOT_FOUND_CODES: Sequence[int] = (59,)
|
||||
|
||||
# Error codes to ignore if GridFS calls createIndex on a secondary
|
||||
UNAUTHORIZED_CODES = (13, 16547, 16548)
|
||||
UNAUTHORIZED_CODES: Sequence[int] = (13, 16547, 16548)
|
||||
|
||||
# Maximum number of sessions to send in a single endSessions command.
|
||||
# From the driver sessions spec.
|
||||
@ -116,7 +117,7 @@ _MAX_END_SESSIONS = 10000
|
||||
SRV_SERVICE_NAME = "mongodb"
|
||||
|
||||
|
||||
def partition_node(node):
|
||||
def partition_node(node: str) -> Tuple[str, int]:
|
||||
"""Split a host:port string into (host, int(port)) pair."""
|
||||
host = node
|
||||
port = 27017
|
||||
@ -128,7 +129,7 @@ def partition_node(node):
|
||||
return host, port
|
||||
|
||||
|
||||
def clean_node(node):
|
||||
def clean_node(node: str) -> Tuple[str, int]:
|
||||
"""Split and normalize a node name from a hello response."""
|
||||
host, port = partition_node(node)
|
||||
|
||||
@ -139,7 +140,7 @@ def clean_node(node):
|
||||
return host.lower(), port
|
||||
|
||||
|
||||
def raise_config_error(key, dummy):
|
||||
def raise_config_error(key: str, dummy: Any) -> None:
|
||||
"""Raise ConfigurationError with the given key name."""
|
||||
raise ConfigurationError("Unknown option %s" % (key,))
|
||||
|
||||
@ -154,14 +155,14 @@ _UUID_REPRESENTATIONS = {
|
||||
}
|
||||
|
||||
|
||||
def validate_boolean(option, value):
|
||||
def validate_boolean(option: str, value: Any) -> bool:
|
||||
"""Validates that 'value' is True or False."""
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
raise TypeError("%s must be True or False" % (option,))
|
||||
|
||||
|
||||
def validate_boolean_or_string(option, value):
|
||||
def validate_boolean_or_string(option: str, value: Any) -> bool:
|
||||
"""Validates that value is True, False, 'true', or 'false'."""
|
||||
if isinstance(value, str):
|
||||
if value not in ('true', 'false'):
|
||||
@ -171,7 +172,7 @@ def validate_boolean_or_string(option, value):
|
||||
return validate_boolean(option, value)
|
||||
|
||||
|
||||
def validate_integer(option, value):
|
||||
def validate_integer(option: str, value: Any) -> int:
|
||||
"""Validates that 'value' is an integer (or basestring representation).
|
||||
"""
|
||||
if isinstance(value, int):
|
||||
@ -185,7 +186,7 @@ def validate_integer(option, value):
|
||||
raise TypeError("Wrong type for %s, value must be an integer" % (option,))
|
||||
|
||||
|
||||
def validate_positive_integer(option, value):
|
||||
def validate_positive_integer(option: str, value: Any) -> int:
|
||||
"""Validate that 'value' is a positive integer, which does not include 0.
|
||||
"""
|
||||
val = validate_integer(option, value)
|
||||
@ -195,7 +196,7 @@ def validate_positive_integer(option, value):
|
||||
return val
|
||||
|
||||
|
||||
def validate_non_negative_integer(option, value):
|
||||
def validate_non_negative_integer(option: str, value: Any) -> int:
|
||||
"""Validate that 'value' is a positive integer or 0.
|
||||
"""
|
||||
val = validate_integer(option, value)
|
||||
@ -205,7 +206,7 @@ def validate_non_negative_integer(option, value):
|
||||
return val
|
||||
|
||||
|
||||
def validate_readable(option, value):
|
||||
def validate_readable(option: str, value: Any) -> Optional[str]:
|
||||
"""Validates that 'value' is file-like and readable.
|
||||
"""
|
||||
if value is None:
|
||||
@ -217,7 +218,7 @@ def validate_readable(option, value):
|
||||
return value
|
||||
|
||||
|
||||
def validate_positive_integer_or_none(option, value):
|
||||
def validate_positive_integer_or_none(option: str, value: Any) -> Optional[int]:
|
||||
"""Validate that 'value' is a positive integer or None.
|
||||
"""
|
||||
if value is None:
|
||||
@ -225,7 +226,7 @@ def validate_positive_integer_or_none(option, value):
|
||||
return validate_positive_integer(option, value)
|
||||
|
||||
|
||||
def validate_non_negative_integer_or_none(option, value):
|
||||
def validate_non_negative_integer_or_none(option: str, value: Any) -> Optional[int]:
|
||||
"""Validate that 'value' is a positive integer or 0 or None.
|
||||
"""
|
||||
if value is None:
|
||||
@ -233,7 +234,7 @@ def validate_non_negative_integer_or_none(option, value):
|
||||
return validate_non_negative_integer(option, value)
|
||||
|
||||
|
||||
def validate_string(option, value):
|
||||
def validate_string(option: str, value: Any) -> str:
|
||||
"""Validates that 'value' is an instance of `str`.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
@ -242,7 +243,7 @@ def validate_string(option, value):
|
||||
"str" % (option,))
|
||||
|
||||
|
||||
def validate_string_or_none(option, value):
|
||||
def validate_string_or_none(option: str, value: Any) -> Optional[str]:
|
||||
"""Validates that 'value' is an instance of `basestring` or `None`.
|
||||
"""
|
||||
if value is None:
|
||||
@ -250,7 +251,7 @@ def validate_string_or_none(option, value):
|
||||
return validate_string(option, value)
|
||||
|
||||
|
||||
def validate_int_or_basestring(option, value):
|
||||
def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]:
|
||||
"""Validates that 'value' is an integer or string.
|
||||
"""
|
||||
if isinstance(value, int):
|
||||
@ -264,7 +265,7 @@ def validate_int_or_basestring(option, value):
|
||||
"integer or a string" % (option,))
|
||||
|
||||
|
||||
def validate_non_negative_int_or_basestring(option, value):
|
||||
def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]:
|
||||
"""Validates that 'value' is an integer or string.
|
||||
"""
|
||||
if isinstance(value, int):
|
||||
@ -279,7 +280,7 @@ def validate_non_negative_int_or_basestring(option, value):
|
||||
"non negative integer or a string" % (option,))
|
||||
|
||||
|
||||
def validate_positive_float(option, value):
|
||||
def validate_positive_float(option: str, value: Any) -> float:
|
||||
"""Validates that 'value' is a float, or can be converted to one, and is
|
||||
positive.
|
||||
"""
|
||||
@ -299,7 +300,7 @@ def validate_positive_float(option, value):
|
||||
return value
|
||||
|
||||
|
||||
def validate_positive_float_or_zero(option, value):
|
||||
def validate_positive_float_or_zero(option: str, value: Any) -> float:
|
||||
"""Validates that 'value' is 0 or a positive float, or can be converted to
|
||||
0 or a positive float.
|
||||
"""
|
||||
@ -308,7 +309,7 @@ def validate_positive_float_or_zero(option, value):
|
||||
return validate_positive_float(option, value)
|
||||
|
||||
|
||||
def validate_timeout_or_none(option, value):
|
||||
def validate_timeout_or_none(option: str, value: Any) -> Optional[float]:
|
||||
"""Validates a timeout specified in milliseconds returning
|
||||
a value in floating point seconds.
|
||||
"""
|
||||
@ -317,7 +318,7 @@ def validate_timeout_or_none(option, value):
|
||||
return validate_positive_float(option, value) / 1000.0
|
||||
|
||||
|
||||
def validate_timeout_or_zero(option, value):
|
||||
def validate_timeout_or_zero(option: str, value: Any) -> float:
|
||||
"""Validates a timeout specified in milliseconds returning
|
||||
a value in floating point seconds for the case where None is an error
|
||||
and 0 is valid. Setting the timeout to nothing in the URI string is a
|
||||
@ -330,7 +331,7 @@ def validate_timeout_or_zero(option, value):
|
||||
return validate_positive_float(option, value) / 1000.0
|
||||
|
||||
|
||||
def validate_timeout_or_none_or_zero(option, value):
|
||||
def validate_timeout_or_none_or_zero(option: Any, value: Any) -> Optional[float]:
|
||||
"""Validates a timeout specified in milliseconds returning
|
||||
a value in floating point seconds. value=0 and value="0" are treated the
|
||||
same as value=None which means unlimited timeout.
|
||||
@ -340,7 +341,7 @@ def validate_timeout_or_none_or_zero(option, value):
|
||||
return validate_positive_float(option, value) / 1000.0
|
||||
|
||||
|
||||
def validate_max_staleness(option, value):
|
||||
def validate_max_staleness(option: str, value: Any) -> int:
|
||||
"""Validates maxStalenessSeconds according to the Max Staleness Spec."""
|
||||
if value == -1 or value == "-1":
|
||||
# Default: No maximum staleness.
|
||||
@ -348,7 +349,7 @@ def validate_max_staleness(option, value):
|
||||
return validate_positive_integer(option, value)
|
||||
|
||||
|
||||
def validate_read_preference(dummy, value):
|
||||
def validate_read_preference(dummy: Any, value: Any) -> _ServerMode:
|
||||
"""Validate a read preference.
|
||||
"""
|
||||
if not isinstance(value, _ServerMode):
|
||||
@ -356,7 +357,7 @@ def validate_read_preference(dummy, value):
|
||||
return value
|
||||
|
||||
|
||||
def validate_read_preference_mode(dummy, value):
|
||||
def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode:
|
||||
"""Validate read preference mode for a MongoClient.
|
||||
|
||||
.. versionchanged:: 3.5
|
||||
@ -368,7 +369,7 @@ def validate_read_preference_mode(dummy, value):
|
||||
return value
|
||||
|
||||
|
||||
def validate_auth_mechanism(option, value):
|
||||
def validate_auth_mechanism(option: str, value: Any) -> str:
|
||||
"""Validate the authMechanism URI option.
|
||||
"""
|
||||
if value not in MECHANISMS:
|
||||
@ -376,7 +377,7 @@ def validate_auth_mechanism(option, value):
|
||||
return value
|
||||
|
||||
|
||||
def validate_uuid_representation(dummy, value):
|
||||
def validate_uuid_representation(dummy: Any, value: Any) -> int:
|
||||
"""Validate the uuid representation option selected in the URI.
|
||||
"""
|
||||
try:
|
||||
@ -387,13 +388,13 @@ def validate_uuid_representation(dummy, value):
|
||||
"%s" % (value, tuple(_UUID_REPRESENTATIONS)))
|
||||
|
||||
|
||||
def validate_read_preference_tags(name, value):
|
||||
def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]]:
|
||||
"""Parse readPreferenceTags if passed as a client kwarg.
|
||||
"""
|
||||
if not isinstance(value, list):
|
||||
value = [value]
|
||||
|
||||
tag_sets = []
|
||||
tag_sets: List = []
|
||||
for tag_set in value:
|
||||
if tag_set == '':
|
||||
tag_sets.append({})
|
||||
@ -416,10 +417,10 @@ _MECHANISM_PROPS = frozenset(['SERVICE_NAME',
|
||||
'AWS_SESSION_TOKEN'])
|
||||
|
||||
|
||||
def validate_auth_mechanism_properties(option, value):
|
||||
def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Union[bool, str]]:
|
||||
"""Validate authMechanismProperties."""
|
||||
value = validate_string(option, value)
|
||||
props = {}
|
||||
props: Dict[str, Any] = {}
|
||||
for opt in value.split(','):
|
||||
try:
|
||||
key, val = opt.split(':')
|
||||
@ -443,7 +444,7 @@ def validate_auth_mechanism_properties(option, value):
|
||||
return props
|
||||
|
||||
|
||||
def validate_document_class(option, value):
|
||||
def validate_document_class(option: str, value: Any) -> Union[Type[MutableMapping], Type[RawBSONDocument]]:
|
||||
"""Validate the document_class option."""
|
||||
if not issubclass(value, (abc.MutableMapping, RawBSONDocument)):
|
||||
raise TypeError("%s must be dict, bson.son.SON, "
|
||||
@ -452,7 +453,7 @@ def validate_document_class(option, value):
|
||||
return value
|
||||
|
||||
|
||||
def validate_type_registry(option, value):
|
||||
def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]:
|
||||
"""Validate the type_registry option."""
|
||||
if value is not None and not isinstance(value, TypeRegistry):
|
||||
raise TypeError("%s must be an instance of %s" % (
|
||||
@ -460,21 +461,21 @@ def validate_type_registry(option, value):
|
||||
return value
|
||||
|
||||
|
||||
def validate_list(option, value):
|
||||
def validate_list(option: str, value: Any) -> List:
|
||||
"""Validates that 'value' is a list."""
|
||||
if not isinstance(value, list):
|
||||
raise TypeError("%s must be a list" % (option,))
|
||||
return value
|
||||
|
||||
|
||||
def validate_list_or_none(option, value):
|
||||
def validate_list_or_none(option: Any, value: Any) -> Optional[List]:
|
||||
"""Validates that 'value' is a list or None."""
|
||||
if value is None:
|
||||
return value
|
||||
return validate_list(option, value)
|
||||
|
||||
|
||||
def validate_list_or_mapping(option, value):
|
||||
def validate_list_or_mapping(option: Any, value: Any) -> None:
|
||||
"""Validates that 'value' is a list or a document."""
|
||||
if not isinstance(value, (abc.Mapping, list)):
|
||||
raise TypeError("%s must either be a list or an instance of dict, "
|
||||
@ -482,7 +483,7 @@ def validate_list_or_mapping(option, value):
|
||||
"collections.Mapping" % (option,))
|
||||
|
||||
|
||||
def validate_is_mapping(option, value):
|
||||
def validate_is_mapping(option: str, value: Any) -> None:
|
||||
"""Validate the type of method arguments that expect a document."""
|
||||
if not isinstance(value, abc.Mapping):
|
||||
raise TypeError("%s must be an instance of dict, bson.son.SON, or "
|
||||
@ -490,7 +491,7 @@ def validate_is_mapping(option, value):
|
||||
"collections.Mapping" % (option,))
|
||||
|
||||
|
||||
def validate_is_document_type(option, value):
|
||||
def validate_is_document_type(option: str, value: Any) -> None:
|
||||
"""Validate the type of method arguments that expect a MongoDB document."""
|
||||
if not isinstance(value, (abc.MutableMapping, RawBSONDocument)):
|
||||
raise TypeError("%s must be an instance of dict, bson.son.SON, "
|
||||
@ -499,7 +500,7 @@ def validate_is_document_type(option, value):
|
||||
"collections.MutableMapping" % (option,))
|
||||
|
||||
|
||||
def validate_appname_or_none(option, value):
|
||||
def validate_appname_or_none(option: str, value: Any) -> Optional[str]:
|
||||
"""Validate the appname option."""
|
||||
if value is None:
|
||||
return value
|
||||
@ -510,7 +511,7 @@ def validate_appname_or_none(option, value):
|
||||
return value
|
||||
|
||||
|
||||
def validate_driver_or_none(option, value):
|
||||
def validate_driver_or_none(option: Any, value: Any) -> Optional[DriverInfo]:
|
||||
"""Validate the driver keyword arg."""
|
||||
if value is None:
|
||||
return value
|
||||
@ -519,7 +520,7 @@ def validate_driver_or_none(option, value):
|
||||
return value
|
||||
|
||||
|
||||
def validate_server_api_or_none(option, value):
|
||||
def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]:
|
||||
"""Validate the server_api keyword arg."""
|
||||
if value is None:
|
||||
return value
|
||||
@ -528,7 +529,7 @@ def validate_server_api_or_none(option, value):
|
||||
return value
|
||||
|
||||
|
||||
def validate_is_callable_or_none(option, value):
|
||||
def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]:
|
||||
"""Validates that 'value' is a callable."""
|
||||
if value is None:
|
||||
return value
|
||||
@ -537,7 +538,7 @@ def validate_is_callable_or_none(option, value):
|
||||
return value
|
||||
|
||||
|
||||
def validate_ok_for_replace(replacement):
|
||||
def validate_ok_for_replace(replacement: Mapping[str, Any]) -> None:
|
||||
"""Validate a replacement document."""
|
||||
validate_is_mapping("replacement", replacement)
|
||||
# Replacement can be {}
|
||||
@ -547,7 +548,7 @@ def validate_ok_for_replace(replacement):
|
||||
raise ValueError('replacement can not include $ operators')
|
||||
|
||||
|
||||
def validate_ok_for_update(update):
|
||||
def validate_ok_for_update(update: Any) -> None:
|
||||
"""Validate an update document."""
|
||||
validate_list_or_mapping("update", update)
|
||||
# Update cannot be {}.
|
||||
@ -563,7 +564,7 @@ def validate_ok_for_update(update):
|
||||
_UNICODE_DECODE_ERROR_HANDLERS = frozenset(['strict', 'replace', 'ignore'])
|
||||
|
||||
|
||||
def validate_unicode_decode_error_handler(dummy, value):
|
||||
def validate_unicode_decode_error_handler(dummy: Any, value: str) -> str:
|
||||
"""Validate the Unicode decode error handler option of CodecOptions.
|
||||
"""
|
||||
if value not in _UNICODE_DECODE_ERROR_HANDLERS:
|
||||
@ -573,7 +574,7 @@ def validate_unicode_decode_error_handler(dummy, value):
|
||||
return value
|
||||
|
||||
|
||||
def validate_tzinfo(dummy, value):
|
||||
def validate_tzinfo(dummy: Any, value: Any) -> Optional[datetime.tzinfo]:
|
||||
"""Validate the tzinfo option
|
||||
"""
|
||||
if value is not None and not isinstance(value, datetime.tzinfo):
|
||||
@ -581,7 +582,7 @@ def validate_tzinfo(dummy, value):
|
||||
return value
|
||||
|
||||
|
||||
def validate_auto_encryption_opts_or_none(option, value):
|
||||
def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[Any]:
|
||||
"""Validate the driver keyword arg."""
|
||||
if value is None:
|
||||
return value
|
||||
@ -595,7 +596,7 @@ def validate_auto_encryption_opts_or_none(option, value):
|
||||
|
||||
# Dictionary where keys are the names of public URI options, and values
|
||||
# are lists of aliases for that option.
|
||||
URI_OPTIONS_ALIAS_MAP = {
|
||||
URI_OPTIONS_ALIAS_MAP: Dict[str, List[str]] = {
|
||||
'tls': ['ssl'],
|
||||
}
|
||||
|
||||
@ -603,7 +604,7 @@ URI_OPTIONS_ALIAS_MAP = {
|
||||
# are functions that validate user-input values for that option. If an option
|
||||
# alias uses a different validator than its public counterpart, it should be
|
||||
# included here as a key, value pair.
|
||||
URI_OPTIONS_VALIDATOR_MAP = {
|
||||
URI_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = {
|
||||
'appname': validate_appname_or_none,
|
||||
'authmechanism': validate_auth_mechanism,
|
||||
'authmechanismproperties': validate_auth_mechanism_properties,
|
||||
@ -644,7 +645,7 @@ URI_OPTIONS_VALIDATOR_MAP = {
|
||||
|
||||
# Dictionary where keys are the names of URI options specific to pymongo,
|
||||
# and values are functions that validate user-input values for those options.
|
||||
NONSPEC_OPTIONS_VALIDATOR_MAP = {
|
||||
NONSPEC_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = {
|
||||
'connect': validate_boolean_or_string,
|
||||
'driver': validate_driver_or_none,
|
||||
'server_api': validate_server_api_or_none,
|
||||
@ -661,7 +662,7 @@ NONSPEC_OPTIONS_VALIDATOR_MAP = {
|
||||
# Dictionary where keys are the names of keyword-only options for the
|
||||
# MongoClient constructor, and values are functions that validate user-input
|
||||
# values for those options.
|
||||
KW_VALIDATORS = {
|
||||
KW_VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = {
|
||||
'document_class': validate_document_class,
|
||||
'type_registry': validate_type_registry,
|
||||
'read_preference': validate_read_preference,
|
||||
@ -677,14 +678,14 @@ KW_VALIDATORS = {
|
||||
# internally-used names of that URI option. Options with only one name
|
||||
# variant need not be included here. Options whose public and internal
|
||||
# names are the same need not be included here.
|
||||
INTERNAL_URI_OPTION_NAME_MAP = {
|
||||
INTERNAL_URI_OPTION_NAME_MAP: Dict[str, str] = {
|
||||
'ssl': 'tls',
|
||||
}
|
||||
|
||||
# Map from deprecated URI option names to a tuple indicating the method of
|
||||
# their deprecation and any additional information that may be needed to
|
||||
# construct the warning message.
|
||||
URI_OPTIONS_DEPRECATION_MAP = {
|
||||
URI_OPTIONS_DEPRECATION_MAP: Dict[str, Tuple[str, str]] = {
|
||||
# format: <deprecated option name>: (<mode>, <message>),
|
||||
# Supported <mode> values:
|
||||
# - 'renamed': <message> should be the new option name. Note that case is
|
||||
@ -704,11 +705,11 @@ for optname, aliases in URI_OPTIONS_ALIAS_MAP.items():
|
||||
URI_OPTIONS_VALIDATOR_MAP[optname])
|
||||
|
||||
# Map containing all URI option and keyword argument validators.
|
||||
VALIDATORS = URI_OPTIONS_VALIDATOR_MAP.copy()
|
||||
VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = URI_OPTIONS_VALIDATOR_MAP.copy()
|
||||
VALIDATORS.update(KW_VALIDATORS)
|
||||
|
||||
# List of timeout-related options.
|
||||
TIMEOUT_OPTIONS = [
|
||||
TIMEOUT_OPTIONS: List[str] = [
|
||||
'connecttimeoutms',
|
||||
'heartbeatfrequencyms',
|
||||
'maxidletimems',
|
||||
@ -722,7 +723,7 @@ TIMEOUT_OPTIONS = [
|
||||
_AUTH_OPTIONS = frozenset(['authmechanismproperties'])
|
||||
|
||||
|
||||
def validate_auth_option(option, value):
|
||||
def validate_auth_option(option: str, value: Any) -> Tuple[str, Any]:
|
||||
"""Validate optional authentication parameters.
|
||||
"""
|
||||
lower, value = validate(option, value)
|
||||
@ -732,7 +733,7 @@ def validate_auth_option(option, value):
|
||||
return option, value
|
||||
|
||||
|
||||
def validate(option, value):
|
||||
def validate(option: str, value: Any) -> Tuple[str, Any]:
|
||||
"""Generic validation function.
|
||||
"""
|
||||
lower = option.lower()
|
||||
@ -741,7 +742,7 @@ def validate(option, value):
|
||||
return option, value
|
||||
|
||||
|
||||
def get_validated_options(options, warn=True):
|
||||
def get_validated_options(options: Mapping[str, Any], warn: bool = True) -> MutableMapping[str, Any]:
|
||||
"""Validate each entry in options and raise a warning if it is not valid.
|
||||
Returns a copy of options with invalid entries removed.
|
||||
|
||||
@ -751,6 +752,7 @@ def get_validated_options(options, warn=True):
|
||||
invalid options will be ignored. Otherwise, invalid options will
|
||||
cause errors.
|
||||
"""
|
||||
validated_options: MutableMapping[str, Any]
|
||||
if isinstance(options, _CaseInsensitiveDictionary):
|
||||
validated_options = _CaseInsensitiveDictionary()
|
||||
get_normed_key = lambda x: x
|
||||
@ -794,8 +796,8 @@ class BaseObject(object):
|
||||
SHOULD NOT BE USED BY DEVELOPERS EXTERNAL TO MONGODB.
|
||||
"""
|
||||
|
||||
def __init__(self, codec_options, read_preference, write_concern,
|
||||
read_concern):
|
||||
def __init__(self, codec_options: CodecOptions, read_preference: _ServerMode, write_concern: WriteConcern,
|
||||
read_concern: ReadConcern) -> None:
|
||||
|
||||
if not isinstance(codec_options, CodecOptions):
|
||||
raise TypeError("codec_options must be an instance of "
|
||||
@ -819,14 +821,14 @@ class BaseObject(object):
|
||||
self.__read_concern = read_concern
|
||||
|
||||
@property
|
||||
def codec_options(self):
|
||||
def codec_options(self) -> CodecOptions:
|
||||
"""Read only access to the :class:`~bson.codec_options.CodecOptions`
|
||||
of this instance.
|
||||
"""
|
||||
return self.__codec_options
|
||||
|
||||
@property
|
||||
def write_concern(self):
|
||||
def write_concern(self) -> WriteConcern:
|
||||
"""Read only access to the :class:`~pymongo.write_concern.WriteConcern`
|
||||
of this instance.
|
||||
|
||||
@ -844,7 +846,7 @@ class BaseObject(object):
|
||||
return self.write_concern
|
||||
|
||||
@property
|
||||
def read_preference(self):
|
||||
def read_preference(self) -> _ServerMode:
|
||||
"""Read only access to the read preference of this instance.
|
||||
|
||||
.. versionchanged:: 3.0
|
||||
@ -861,7 +863,7 @@ class BaseObject(object):
|
||||
return self.__read_preference
|
||||
|
||||
@property
|
||||
def read_concern(self):
|
||||
def read_concern(self) -> ReadConcern:
|
||||
"""Read only access to the :class:`~pymongo.read_concern.ReadConcern`
|
||||
of this instance.
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from typing import Callable
|
||||
|
||||
try:
|
||||
import snappy
|
||||
@ -99,7 +100,7 @@ class CompressionSettings(object):
|
||||
return ZstdContext()
|
||||
|
||||
|
||||
def _zlib_no_compress(data):
|
||||
def _zlib_no_compress(data, level=None):
|
||||
"""Compress data with zlib level 0."""
|
||||
cobj = zlib.compressobj(0)
|
||||
return b"".join([cobj.compress(data), cobj.flush()])
|
||||
@ -117,6 +118,8 @@ class ZlibContext(object):
|
||||
compressor_id = 2
|
||||
|
||||
def __init__(self, level):
|
||||
self.compress: Callable[[bytes], bytes]
|
||||
|
||||
# Jython zlib.compress doesn't support -1
|
||||
if level == -1:
|
||||
self.compress = zlib.compress
|
||||
@ -124,7 +127,7 @@ class ZlibContext(object):
|
||||
elif level == 0:
|
||||
self.compress = _zlib_no_compress
|
||||
else:
|
||||
self.compress = lambda data: zlib.compress(data, level)
|
||||
self.compresss = lambda data, _: zlib.compress(data, level)
|
||||
|
||||
|
||||
class ZstdContext(object):
|
||||
|
||||
@ -13,29 +13,26 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Cursor class to iterate over Mongo query results."""
|
||||
|
||||
import copy
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
from collections import deque
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Mapping,
|
||||
MutableMapping, Optional, Sequence, Tuple, Union, cast, overload)
|
||||
|
||||
from bson import RE_TYPE, _convert_raw_document_lists_to_streams
|
||||
from bson.code import Code
|
||||
from bson.son import SON
|
||||
from pymongo import helpers
|
||||
from pymongo.common import (validate_boolean, validate_is_mapping,
|
||||
validate_is_document_type)
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.errors import (ConnectionFailure,
|
||||
InvalidOperation,
|
||||
from pymongo.common import (validate_boolean, validate_is_document_type,
|
||||
validate_is_mapping)
|
||||
from pymongo.errors import (ConnectionFailure, InvalidOperation,
|
||||
OperationFailure)
|
||||
from pymongo.message import (_CursorAddress,
|
||||
_GetMore,
|
||||
_RawBatchGetMore,
|
||||
_Query,
|
||||
_RawBatchQuery)
|
||||
from pymongo.message import (_CursorAddress, _GetMore, _Query,
|
||||
_RawBatchGetMore, _RawBatchQuery)
|
||||
from pymongo.response import PinnedResponse
|
||||
from pymongo.typings import _CollationIn, _DocumentType
|
||||
|
||||
# These errors mean that the server has already killed the cursor so there is
|
||||
# no need to send killCursors.
|
||||
@ -126,22 +123,47 @@ class _SocketManager(object):
|
||||
self.sock.unpin()
|
||||
self.sock = None
|
||||
|
||||
_Sort = Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]]
|
||||
_Hint = Union[str, _Sort]
|
||||
|
||||
class Cursor(object):
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collection import Collection
|
||||
|
||||
|
||||
class Cursor(Generic[_DocumentType]):
|
||||
"""A cursor / iterator over Mongo query results.
|
||||
"""
|
||||
_query_class = _Query
|
||||
_getmore_class = _GetMore
|
||||
|
||||
def __init__(self, collection, filter=None, projection=None, skip=0,
|
||||
limit=0, no_cursor_timeout=False,
|
||||
cursor_type=CursorType.NON_TAILABLE,
|
||||
sort=None, allow_partial_results=False, oplog_replay=False,
|
||||
batch_size=0,
|
||||
collation=None, hint=None, max_scan=None, max_time_ms=None,
|
||||
max=None, min=None, return_key=None, show_record_id=None,
|
||||
snapshot=None, comment=None, session=None,
|
||||
allow_disk_use=None, let=None):
|
||||
def __init__(self,
|
||||
collection: "Collection[_DocumentType]",
|
||||
filter: Optional[Mapping[str, Any]] = None,
|
||||
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 0,
|
||||
no_cursor_timeout: bool = False,
|
||||
cursor_type: int = CursorType.NON_TAILABLE,
|
||||
sort: Optional[_Sort] = None,
|
||||
allow_partial_results: bool = False,
|
||||
oplog_replay: bool = False,
|
||||
batch_size: int = 0,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_Hint] = None,
|
||||
max_scan: Optional[int] = None,
|
||||
max_time_ms: Optional[int] = None,
|
||||
max: Optional[_Sort] = None,
|
||||
min: Optional[_Sort] = None,
|
||||
return_key: Optional[bool] = None,
|
||||
show_record_id: Optional[bool] = None,
|
||||
snapshot: Optional[bool] = None,
|
||||
comment: Any = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
allow_disk_use: Optional[bool] = None,
|
||||
let: Optional[bool] = None
|
||||
) -> None:
|
||||
"""Create a new cursor.
|
||||
|
||||
Should not be called directly by application developers - see
|
||||
@ -151,11 +173,12 @@ class Cursor(object):
|
||||
"""
|
||||
# Initialize all attributes used in __del__ before possibly raising
|
||||
# an error to avoid attribute errors during garbage collection.
|
||||
self.__collection = collection
|
||||
self.__id = None
|
||||
self.__collection: Collection[_DocumentType] = collection
|
||||
self.__id: Any = None
|
||||
self.__exhaust = False
|
||||
self.__sock_mgr = None
|
||||
self.__sock_mgr: Any = None
|
||||
self.__killed = False
|
||||
self.__session: Optional["ClientSession"]
|
||||
|
||||
if session:
|
||||
self.__session = session
|
||||
@ -164,10 +187,7 @@ class Cursor(object):
|
||||
self.__session = None
|
||||
self.__explicit_session = False
|
||||
|
||||
spec = filter
|
||||
if spec is None:
|
||||
spec = {}
|
||||
|
||||
spec: Mapping[str, Any] = filter or {}
|
||||
validate_is_mapping("filter", spec)
|
||||
if not isinstance(skip, int):
|
||||
raise TypeError("skip must be an instance of int")
|
||||
@ -203,6 +223,7 @@ class Cursor(object):
|
||||
|
||||
self.__let = let
|
||||
self.__spec = spec
|
||||
self.__has_filter = filter is not None
|
||||
self.__projection = projection
|
||||
self.__skip = skip
|
||||
self.__limit = limit
|
||||
@ -212,9 +233,9 @@ class Cursor(object):
|
||||
self.__explain = False
|
||||
self.__comment = comment
|
||||
self.__max_time_ms = max_time_ms
|
||||
self.__max_await_time_ms = None
|
||||
self.__max = max
|
||||
self.__min = min
|
||||
self.__max_await_time_ms: Optional[int] = None
|
||||
self.__max: Optional[Union[SON[Any, Any], _Sort]] = max
|
||||
self.__min: Optional[Union[SON[Any, Any], _Sort]] = min
|
||||
self.__collation = validate_collation_or_none(collation)
|
||||
self.__return_key = return_key
|
||||
self.__show_record_id = show_record_id
|
||||
@ -239,7 +260,7 @@ class Cursor(object):
|
||||
# it anytime we change __limit.
|
||||
self.__empty = False
|
||||
|
||||
self.__data = deque()
|
||||
self.__data: deque = deque()
|
||||
self.__address = None
|
||||
self.__retrieved = 0
|
||||
|
||||
@ -261,22 +282,22 @@ class Cursor(object):
|
||||
self.__collname = collection.name
|
||||
|
||||
@property
|
||||
def collection(self):
|
||||
def collection(self) -> "Collection[_DocumentType]":
|
||||
"""The :class:`~pymongo.collection.Collection` that this
|
||||
:class:`Cursor` is iterating.
|
||||
"""
|
||||
return self.__collection
|
||||
|
||||
@property
|
||||
def retrieved(self):
|
||||
def retrieved(self) -> int:
|
||||
"""The number of documents retrieved so far.
|
||||
"""
|
||||
return self.__retrieved
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
self.__die()
|
||||
|
||||
def rewind(self):
|
||||
def rewind(self) -> "Cursor[_DocumentType]":
|
||||
"""Rewind this cursor to its unevaluated state.
|
||||
|
||||
Reset this cursor if it has been partially or completely evaluated.
|
||||
@ -294,7 +315,7 @@ class Cursor(object):
|
||||
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
def clone(self) -> "Cursor[_DocumentType]":
|
||||
"""Get a clone of this cursor.
|
||||
|
||||
Returns a new Cursor instance with options matching those that have
|
||||
@ -318,7 +339,7 @@ class Cursor(object):
|
||||
"batch_size", "max_scan",
|
||||
"query_flags", "collation", "empty",
|
||||
"show_record_id", "return_key", "allow_disk_use",
|
||||
"snapshot", "exhaust")
|
||||
"snapshot", "exhaust", "has_filter")
|
||||
data = dict((k, v) for k, v in self.__dict__.items()
|
||||
if k.startswith('_Cursor__') and k[9:] in values_to_clone)
|
||||
if deepcopy:
|
||||
@ -360,7 +381,7 @@ class Cursor(object):
|
||||
self.__session = None
|
||||
self.__sock_mgr = None
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Explicitly close / kill this cursor.
|
||||
"""
|
||||
self.__die(True)
|
||||
@ -397,7 +418,7 @@ class Cursor(object):
|
||||
|
||||
if operators:
|
||||
# Make a shallow copy so we can cleanly rewind or clone.
|
||||
spec = self.__spec.copy()
|
||||
spec = copy.copy(self.__spec)
|
||||
|
||||
# Allow-listed commands must be wrapped in $query.
|
||||
if "$query" not in spec:
|
||||
@ -429,7 +450,7 @@ class Cursor(object):
|
||||
if self.__retrieved or self.__id is not None:
|
||||
raise InvalidOperation("cannot set options after executing query")
|
||||
|
||||
def add_option(self, mask):
|
||||
def add_option(self, mask: int) -> "Cursor[_DocumentType]":
|
||||
"""Set arbitrary query flags using a bitmask.
|
||||
|
||||
To set the tailable flag:
|
||||
@ -450,7 +471,7 @@ class Cursor(object):
|
||||
self.__query_flags |= mask
|
||||
return self
|
||||
|
||||
def remove_option(self, mask):
|
||||
def remove_option(self, mask: int) -> "Cursor[_DocumentType]":
|
||||
"""Unset arbitrary query flags using a bitmask.
|
||||
|
||||
To unset the tailable flag:
|
||||
@ -466,7 +487,7 @@ class Cursor(object):
|
||||
self.__query_flags &= ~mask
|
||||
return self
|
||||
|
||||
def allow_disk_use(self, allow_disk_use):
|
||||
def allow_disk_use(self, allow_disk_use: bool) -> "Cursor[_DocumentType]":
|
||||
"""Specifies whether MongoDB can use temporary disk files while
|
||||
processing a blocking sort operation.
|
||||
|
||||
@ -488,7 +509,7 @@ class Cursor(object):
|
||||
self.__allow_disk_use = allow_disk_use
|
||||
return self
|
||||
|
||||
def limit(self, limit):
|
||||
def limit(self, limit: int) -> "Cursor[_DocumentType]":
|
||||
"""Limits the number of results to be returned by this cursor.
|
||||
|
||||
Raises :exc:`TypeError` if `limit` is not an integer. Raises
|
||||
@ -511,7 +532,7 @@ class Cursor(object):
|
||||
self.__limit = limit
|
||||
return self
|
||||
|
||||
def batch_size(self, batch_size):
|
||||
def batch_size(self, batch_size: int) -> "Cursor[_DocumentType]":
|
||||
"""Limits the number of documents returned in one batch. Each batch
|
||||
requires a round trip to the server. It can be adjusted to optimize
|
||||
performance and limit data transfer.
|
||||
@ -539,7 +560,7 @@ class Cursor(object):
|
||||
self.__batch_size = batch_size
|
||||
return self
|
||||
|
||||
def skip(self, skip):
|
||||
def skip(self, skip: int) -> "Cursor[_DocumentType]":
|
||||
"""Skips the first `skip` results of this cursor.
|
||||
|
||||
Raises :exc:`TypeError` if `skip` is not an integer. Raises
|
||||
@ -560,7 +581,7 @@ class Cursor(object):
|
||||
self.__skip = skip
|
||||
return self
|
||||
|
||||
def max_time_ms(self, max_time_ms):
|
||||
def max_time_ms(self, max_time_ms: Optional[int]) -> "Cursor[_DocumentType]":
|
||||
"""Specifies a time limit for a query operation. If the specified
|
||||
time is exceeded, the operation will be aborted and
|
||||
:exc:`~pymongo.errors.ExecutionTimeout` is raised. If `max_time_ms`
|
||||
@ -581,7 +602,7 @@ class Cursor(object):
|
||||
self.__max_time_ms = max_time_ms
|
||||
return self
|
||||
|
||||
def max_await_time_ms(self, max_await_time_ms):
|
||||
def max_await_time_ms(self, max_await_time_ms: Optional[int]) -> "Cursor[_DocumentType]":
|
||||
"""Specifies a time limit for a getMore operation on a
|
||||
:attr:`~pymongo.cursor.CursorType.TAILABLE_AWAIT` cursor. For all other
|
||||
types of cursor max_await_time_ms is ignored.
|
||||
@ -609,6 +630,14 @@ class Cursor(object):
|
||||
|
||||
return self
|
||||
|
||||
@overload
|
||||
def __getitem__(self, index: int) -> _DocumentType:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, index: slice) -> "Cursor[_DocumentType]":
|
||||
...
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Get a single document or a slice of documents from this cursor.
|
||||
|
||||
@ -691,7 +720,7 @@ class Cursor(object):
|
||||
raise TypeError("index %r cannot be applied to Cursor "
|
||||
"instances" % index)
|
||||
|
||||
def max_scan(self, max_scan):
|
||||
def max_scan(self, max_scan: Optional[int]) -> "Cursor[_DocumentType]":
|
||||
"""**DEPRECATED** - Limit the number of documents to scan when
|
||||
performing the query.
|
||||
|
||||
@ -711,7 +740,7 @@ class Cursor(object):
|
||||
self.__max_scan = max_scan
|
||||
return self
|
||||
|
||||
def max(self, spec):
|
||||
def max(self, spec: _Sort) -> "Cursor[_DocumentType]":
|
||||
"""Adds ``max`` operator that specifies upper bound for specific index.
|
||||
|
||||
When using ``max``, :meth:`~hint` should also be configured to ensure
|
||||
@ -734,7 +763,7 @@ class Cursor(object):
|
||||
self.__max = SON(spec)
|
||||
return self
|
||||
|
||||
def min(self, spec):
|
||||
def min(self, spec: _Sort) -> "Cursor[_DocumentType]":
|
||||
"""Adds ``min`` operator that specifies lower bound for specific index.
|
||||
|
||||
When using ``min``, :meth:`~hint` should also be configured to ensure
|
||||
@ -757,7 +786,7 @@ class Cursor(object):
|
||||
self.__min = SON(spec)
|
||||
return self
|
||||
|
||||
def sort(self, key_or_list, direction=None):
|
||||
def sort(self, key_or_list: _Hint, direction: Optional[Union[int, str]] = None) -> "Cursor[_DocumentType]":
|
||||
"""Sorts this cursor's results.
|
||||
|
||||
Pass a field name and a direction, either
|
||||
@ -803,7 +832,7 @@ class Cursor(object):
|
||||
self.__ordering = helpers._index_document(keys)
|
||||
return self
|
||||
|
||||
def distinct(self, key):
|
||||
def distinct(self, key: str) -> List:
|
||||
"""Get a list of distinct values for `key` among all documents
|
||||
in the result set of this query.
|
||||
|
||||
@ -820,7 +849,7 @@ class Cursor(object):
|
||||
|
||||
.. seealso:: :meth:`pymongo.collection.Collection.distinct`
|
||||
"""
|
||||
options = {}
|
||||
options: Dict[str, Any] = {}
|
||||
if self.__spec:
|
||||
options["query"] = self.__spec
|
||||
if self.__max_time_ms is not None:
|
||||
@ -833,7 +862,7 @@ class Cursor(object):
|
||||
return self.__collection.distinct(
|
||||
key, session=self.__session, **options)
|
||||
|
||||
def explain(self):
|
||||
def explain(self) -> _DocumentType:
|
||||
"""Returns an explain plan record for this cursor.
|
||||
|
||||
.. note:: This method uses the default verbosity mode of the
|
||||
@ -863,7 +892,7 @@ class Cursor(object):
|
||||
else:
|
||||
self.__hint = helpers._index_document(index)
|
||||
|
||||
def hint(self, index):
|
||||
def hint(self, index: Optional[_Hint]) -> "Cursor[_DocumentType]":
|
||||
"""Adds a 'hint', telling Mongo the proper index to use for the query.
|
||||
|
||||
Judicious use of hints can greatly improve query
|
||||
@ -888,7 +917,7 @@ class Cursor(object):
|
||||
self.__set_hint(index)
|
||||
return self
|
||||
|
||||
def comment(self, comment):
|
||||
def comment(self, comment: Any) -> "Cursor[_DocumentType]":
|
||||
"""Adds a 'comment' to the cursor.
|
||||
|
||||
http://docs.mongodb.org/manual/reference/operator/comment/
|
||||
@ -903,7 +932,7 @@ class Cursor(object):
|
||||
self.__comment = comment
|
||||
return self
|
||||
|
||||
def where(self, code):
|
||||
def where(self, code: Union[str, Code]) -> "Cursor[_DocumentType]":
|
||||
"""Adds a `$where`_ clause to this query.
|
||||
|
||||
The `code` argument must be an instance of :class:`basestring`
|
||||
@ -937,10 +966,18 @@ class Cursor(object):
|
||||
if not isinstance(code, Code):
|
||||
code = Code(code)
|
||||
|
||||
self.__spec["$where"] = code
|
||||
# Avoid overwriting a filter argument that was given by the user
|
||||
# when updating the spec.
|
||||
spec: Dict[str, Any]
|
||||
if self.__has_filter:
|
||||
spec = dict(self.__spec)
|
||||
else:
|
||||
spec = cast(Dict, self.__spec)
|
||||
spec["$where"] = code
|
||||
self.__spec = spec
|
||||
return self
|
||||
|
||||
def collation(self, collation):
|
||||
def collation(self, collation: Optional[_CollationIn]) -> "Cursor[_DocumentType]":
|
||||
"""Adds a :class:`~pymongo.collation.Collation` to this query.
|
||||
|
||||
Raises :exc:`TypeError` if `collation` is not an instance of
|
||||
@ -1106,7 +1143,7 @@ class Cursor(object):
|
||||
return len(self.__data)
|
||||
|
||||
@property
|
||||
def alive(self):
|
||||
def alive(self) -> bool:
|
||||
"""Does this cursor have the potential to return more data?
|
||||
|
||||
This is mostly useful with `tailable cursors
|
||||
@ -1128,7 +1165,7 @@ class Cursor(object):
|
||||
return bool(len(self.__data) or (not self.__killed))
|
||||
|
||||
@property
|
||||
def cursor_id(self):
|
||||
def cursor_id(self) -> Optional[int]:
|
||||
"""Returns the id of the cursor
|
||||
|
||||
.. versionadded:: 2.2
|
||||
@ -1136,7 +1173,7 @@ class Cursor(object):
|
||||
return self.__id
|
||||
|
||||
@property
|
||||
def address(self):
|
||||
def address(self) -> Optional[Tuple[str, Any]]:
|
||||
"""The (host, port) of the server used, or None.
|
||||
|
||||
.. versionchanged:: 3.0
|
||||
@ -1145,18 +1182,19 @@ class Cursor(object):
|
||||
return self.__address
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
def session(self) -> Optional["ClientSession"]:
|
||||
"""The cursor's :class:`~pymongo.client_session.ClientSession`, or None.
|
||||
|
||||
.. versionadded:: 3.6
|
||||
"""
|
||||
if self.__explicit_session:
|
||||
return self.__session
|
||||
return None
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> "Cursor[_DocumentType]":
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
def next(self) -> _DocumentType:
|
||||
"""Advance the cursor."""
|
||||
if self.__empty:
|
||||
raise StopIteration
|
||||
@ -1167,20 +1205,20 @@ class Cursor(object):
|
||||
|
||||
__next__ = next
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "Cursor[_DocumentType]":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def __copy__(self):
|
||||
def __copy__(self) -> "Cursor[_DocumentType]":
|
||||
"""Support function for `copy.copy()`.
|
||||
|
||||
.. versionadded:: 2.4
|
||||
"""
|
||||
return self._clone(deepcopy=False)
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
def __deepcopy__(self, memo: Any) -> Any:
|
||||
"""Support function for `copy.deepcopy()`.
|
||||
|
||||
.. versionadded:: 2.4
|
||||
@ -1193,6 +1231,7 @@ class Cursor(object):
|
||||
Regular expressions cannot be deep copied but as they are immutable we
|
||||
don't have to copy them when cloning.
|
||||
"""
|
||||
y: Any
|
||||
if not hasattr(x, 'items'):
|
||||
y, is_list, iterator = [], True, enumerate(x)
|
||||
else:
|
||||
@ -1220,13 +1259,13 @@ class Cursor(object):
|
||||
return y
|
||||
|
||||
|
||||
class RawBatchCursor(Cursor):
|
||||
class RawBatchCursor(Cursor, Generic[_DocumentType]):
|
||||
"""A cursor / iterator over raw batches of BSON data from a query result."""
|
||||
|
||||
_query_class = _RawBatchQuery
|
||||
_getmore_class = _RawBatchGetMore
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, collection: "Collection[_DocumentType]", *args: Any, **kwargs: Any) -> None:
|
||||
"""Create a new cursor / iterator over raw batches of BSON data.
|
||||
|
||||
Should not be called directly by application developers -
|
||||
@ -1235,7 +1274,7 @@ class RawBatchCursor(Cursor):
|
||||
|
||||
.. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
|
||||
"""
|
||||
super(RawBatchCursor, self).__init__(*args, **kwargs)
|
||||
super(RawBatchCursor, self).__init__(collection, *args, **kwargs)
|
||||
|
||||
def _unpack_response(self, response, cursor_id, codec_options,
|
||||
user_fields=None, legacy_response=False):
|
||||
@ -1247,7 +1286,7 @@ class RawBatchCursor(Cursor):
|
||||
_convert_raw_document_lists_to_streams(raw_response[0])
|
||||
return raw_response
|
||||
|
||||
def explain(self):
|
||||
def explain(self) -> _DocumentType:
|
||||
"""Returns an explain plan record for this cursor.
|
||||
|
||||
.. seealso:: The MongoDB documentation on `explain <https://dochub.mongodb.org/core/explain>`_.
|
||||
@ -1255,5 +1294,5 @@ class RawBatchCursor(Cursor):
|
||||
clone = self._clone(deepcopy=True, base=Cursor(self.collection))
|
||||
return clone.explain()
|
||||
|
||||
def __getitem__(self, index):
|
||||
def __getitem__(self, index: Any) -> "Cursor[_DocumentType]":
|
||||
raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor")
|
||||
|
||||
@ -13,18 +13,21 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Database level operations."""
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Mapping, MutableMapping, Optional,
|
||||
Sequence, Union)
|
||||
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
|
||||
from bson.dbref import DBRef
|
||||
from bson.son import SON
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import common
|
||||
from pymongo.aggregation import _DatabaseAggregationCommand
|
||||
from pymongo.change_stream import DatabaseChangeStream
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
from pymongo.errors import (CollectionInvalid,
|
||||
InvalidName)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.errors import CollectionInvalid, InvalidName
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
|
||||
|
||||
def _check_name(name):
|
||||
@ -39,12 +42,24 @@ def _check_name(name):
|
||||
"character %r" % invalid_char)
|
||||
|
||||
|
||||
class Database(common.BaseObject):
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
|
||||
class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
"""A Mongo database.
|
||||
"""
|
||||
|
||||
def __init__(self, client, name, codec_options=None, read_preference=None,
|
||||
write_concern=None, read_concern=None):
|
||||
def __init__(self,
|
||||
client: "MongoClient[_DocumentType]",
|
||||
name: str,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional["WriteConcern"] = None,
|
||||
read_concern: Optional["ReadConcern"] = None,
|
||||
) -> None:
|
||||
"""Get a database by client and name.
|
||||
|
||||
Raises :class:`TypeError` if `name` is not an instance of
|
||||
@ -104,20 +119,24 @@ class Database(common.BaseObject):
|
||||
_check_name(name)
|
||||
|
||||
self.__name = name
|
||||
self.__client = client
|
||||
self.__client: MongoClient[_DocumentType] = client
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
def client(self) -> "MongoClient[_DocumentType]":
|
||||
"""The client instance for this :class:`Database`."""
|
||||
return self.__client
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
"""The name of this :class:`Database`."""
|
||||
return self.__name
|
||||
|
||||
def with_options(self, codec_options=None, read_preference=None,
|
||||
write_concern=None, read_concern=None):
|
||||
def with_options(self,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional["WriteConcern"] = None,
|
||||
read_concern: Optional["ReadConcern"] = None,
|
||||
) -> "Database[_DocumentType]":
|
||||
"""Get a clone of this database changing the specified settings.
|
||||
|
||||
>>> db1.read_preference
|
||||
@ -156,22 +175,22 @@ class Database(common.BaseObject):
|
||||
write_concern or self.write_concern,
|
||||
read_concern or self.read_concern)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, Database):
|
||||
return (self.__client == other.client and
|
||||
self.__name == other.name)
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.__client, self.__name))
|
||||
|
||||
def __repr__(self):
|
||||
return "Database(%r, %r)" % (self.__client, self.__name)
|
||||
|
||||
def __getattr__(self, name):
|
||||
def __getattr__(self, name: str) -> Collection[_DocumentType]:
|
||||
"""Get a collection of this database by name.
|
||||
|
||||
Raises InvalidName if an invalid collection name is used.
|
||||
@ -185,7 +204,7 @@ class Database(common.BaseObject):
|
||||
" collection, use database[%r]." % (name, name, name))
|
||||
return self.__getitem__(name)
|
||||
|
||||
def __getitem__(self, name):
|
||||
def __getitem__(self, name: str) -> "Collection[_DocumentType]":
|
||||
"""Get a collection of this database by name.
|
||||
|
||||
Raises InvalidName if an invalid collection name is used.
|
||||
@ -195,8 +214,13 @@ class Database(common.BaseObject):
|
||||
"""
|
||||
return Collection(self, name)
|
||||
|
||||
def get_collection(self, name, codec_options=None, read_preference=None,
|
||||
write_concern=None, read_concern=None):
|
||||
def get_collection(self,
|
||||
name: str,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional["WriteConcern"] = None,
|
||||
read_concern: Optional["ReadConcern"] = None,
|
||||
) -> Collection[_DocumentType]:
|
||||
"""Get a :class:`~pymongo.collection.Collection` with the given name
|
||||
and options.
|
||||
|
||||
@ -238,9 +262,15 @@ class Database(common.BaseObject):
|
||||
self, name, False, codec_options, read_preference,
|
||||
write_concern, read_concern)
|
||||
|
||||
def create_collection(self, name, codec_options=None,
|
||||
read_preference=None, write_concern=None,
|
||||
read_concern=None, session=None, **kwargs):
|
||||
def create_collection(self,
|
||||
name: str,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional["WriteConcern"] = None,
|
||||
read_concern: Optional["ReadConcern"] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
**kwargs: Any,
|
||||
) -> Collection[_DocumentType]:
|
||||
"""Create a new :class:`~pymongo.collection.Collection` in this
|
||||
database.
|
||||
|
||||
@ -286,20 +316,20 @@ class Database(common.BaseObject):
|
||||
timeseries collections
|
||||
- ``expireAfterSeconds`` (int): the number of seconds after which a
|
||||
document in a timeseries collection expires
|
||||
- ``validator`` (dict): a document specifying validation rules or expressions
|
||||
- ``validator`` (dict): a document specifying validation rules or expressions
|
||||
for the collection
|
||||
- ``validationLevel`` (str): how strictly to apply the
|
||||
- ``validationLevel`` (str): how strictly to apply the
|
||||
validation rules to existing documents during an update. The default level
|
||||
is "strict"
|
||||
- ``validationAction`` (str): whether to "error" on invalid documents
|
||||
(the default) or just "warn" about the violations but allow invalid
|
||||
(the default) or just "warn" about the violations but allow invalid
|
||||
documents to be inserted
|
||||
- ``indexOptionDefaults`` (dict): a document specifying a default configuration
|
||||
for indexes when creating a collection
|
||||
- ``viewOn`` (str): the name of the source collection or view from which
|
||||
- ``viewOn`` (str): the name of the source collection or view from which
|
||||
to create the view
|
||||
- ``pipeline`` (list): a list of aggregation pipeline stages
|
||||
- ``comment`` (str): a user-provided comment to attach to this command.
|
||||
- ``comment`` (str): a user-provided comment to attach to this command.
|
||||
This option is only supported on MongoDB >= 4.4.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
@ -330,7 +360,11 @@ class Database(common.BaseObject):
|
||||
read_preference, write_concern,
|
||||
read_concern, session=s, **kwargs)
|
||||
|
||||
def aggregate(self, pipeline, session=None, **kwargs):
|
||||
def aggregate(self,
|
||||
pipeline: _Pipeline,
|
||||
session: Optional["ClientSession"] = None,
|
||||
**kwargs: Any
|
||||
) -> CommandCursor[_DocumentType]:
|
||||
"""Perform a database-level aggregation.
|
||||
|
||||
See the `aggregation pipeline`_ documentation for a list of stages
|
||||
@ -400,9 +434,17 @@ class Database(common.BaseObject):
|
||||
cmd.get_cursor, cmd.get_read_preference(s), s,
|
||||
retryable=not cmd._performs_write)
|
||||
|
||||
def watch(self, pipeline=None, full_document=None, resume_after=None,
|
||||
max_await_time_ms=None, batch_size=None, collation=None,
|
||||
start_at_operation_time=None, session=None, start_after=None):
|
||||
def watch(self,
|
||||
pipeline: Optional[_Pipeline] = None,
|
||||
full_document: Optional[str] = None,
|
||||
resume_after: Optional[Mapping[str, Any]] = None,
|
||||
max_await_time_ms: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
start_at_operation_time: Optional[Timestamp] = None,
|
||||
session: Optional["ClientSession"] = None,
|
||||
start_after: Optional[Mapping[str, Any]] = None,
|
||||
) -> DatabaseChangeStream[_DocumentType]:
|
||||
"""Watch changes on this database.
|
||||
|
||||
Performs an aggregation with an implicit initial ``$changeStream``
|
||||
@ -515,9 +557,16 @@ class Database(common.BaseObject):
|
||||
session=s,
|
||||
client=self.__client)
|
||||
|
||||
def command(self, command, value=1, check=True,
|
||||
allowable_errors=None, read_preference=None,
|
||||
codec_options=DEFAULT_CODEC_OPTIONS, session=None, **kwargs):
|
||||
def command(self,
|
||||
command: Union[str, MutableMapping[str, Any]],
|
||||
value: Any = 1,
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
codec_options: Optional[CodecOptions] = DEFAULT_CODEC_OPTIONS,
|
||||
session: Optional["ClientSession"] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Issue a MongoDB command.
|
||||
|
||||
Send command `command` to the database and return the
|
||||
@ -648,7 +697,11 @@ class Database(common.BaseObject):
|
||||
cmd_cursor._maybe_pin_connection(sock_info)
|
||||
return cmd_cursor
|
||||
|
||||
def list_collections(self, session=None, filter=None, **kwargs):
|
||||
def list_collections(self,
|
||||
session: Optional["ClientSession"] = None,
|
||||
filter: Optional[Mapping[str, Any]] = None,
|
||||
**kwargs: Any
|
||||
) -> CommandCursor[Dict[str, Any]]:
|
||||
"""Get a cursor over the collections of this database.
|
||||
|
||||
:Parameters:
|
||||
@ -680,7 +733,11 @@ class Database(common.BaseObject):
|
||||
return self.__client._retryable_read(
|
||||
_cmd, read_pref, session)
|
||||
|
||||
def list_collection_names(self, session=None, filter=None, **kwargs):
|
||||
def list_collection_names(self,
|
||||
session: Optional["ClientSession"] = None,
|
||||
filter: Optional[Mapping[str, Any]] = None,
|
||||
**kwargs: Any
|
||||
) -> List[str]:
|
||||
"""Get a list of all the collection names in this database.
|
||||
|
||||
For example, to list all non-system collections::
|
||||
@ -717,7 +774,10 @@ class Database(common.BaseObject):
|
||||
return [result["name"]
|
||||
for result in self.list_collections(session=session, **kwargs)]
|
||||
|
||||
def drop_collection(self, name_or_collection, session=None):
|
||||
def drop_collection(self,
|
||||
name_or_collection: Union[str, Collection],
|
||||
session: Optional["ClientSession"] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Drop a collection.
|
||||
|
||||
:Parameters:
|
||||
@ -752,9 +812,13 @@ class Database(common.BaseObject):
|
||||
parse_write_concern_error=True,
|
||||
session=session)
|
||||
|
||||
def validate_collection(self, name_or_collection,
|
||||
scandata=False, full=False, session=None,
|
||||
background=None):
|
||||
def validate_collection(self,
|
||||
name_or_collection: Union[str, Collection],
|
||||
scandata: bool = False,
|
||||
full: bool = False,
|
||||
session: Optional["ClientSession"] = None,
|
||||
background: Optional[bool] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Validate a collection.
|
||||
|
||||
Returns a dict of validation info. Raises CollectionInvalid if
|
||||
@ -827,20 +891,23 @@ class Database(common.BaseObject):
|
||||
|
||||
return result
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> "Database[_DocumentType]":
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
def __next__(self) -> "Database[_DocumentType]":
|
||||
raise TypeError("'Database' object is not iterable")
|
||||
|
||||
next = __next__
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
raise NotImplementedError("Database objects do not implement truth "
|
||||
"value testing or bool(). Please compare "
|
||||
"with None instead: database is not None")
|
||||
|
||||
def dereference(self, dbref, session=None, **kwargs):
|
||||
def dereference(self, dbref: DBRef,
|
||||
session: Optional["ClientSession"] = None,
|
||||
**kwargs: Any
|
||||
) -> Optional[_DocumentType]:
|
||||
"""Dereference a :class:`~bson.dbref.DBRef`, getting the
|
||||
document it points to.
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Advanced options for MongoDB drivers implemented on top of PyMongo."""
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class DriverInfo(namedtuple('DriverInfo', ['name', 'version', 'platform'])):
|
||||
@ -26,7 +27,7 @@ class DriverInfo(namedtuple('DriverInfo', ['name', 'version', 'platform'])):
|
||||
like 'MyDriver', '1.2.3', 'some platform info'. Any of these strings may be
|
||||
None to accept PyMongo's default.
|
||||
"""
|
||||
def __new__(cls, name, version=None, platform=None):
|
||||
def __new__(cls, name: str, version: Optional[str] = None, platform: Optional[str] = None) -> "DriverInfo":
|
||||
self = super(DriverInfo, cls).__new__(cls, name, version, platform)
|
||||
for key, value in self._asdict().items():
|
||||
if value is not None and not isinstance(value, str):
|
||||
|
||||
@ -15,15 +15,16 @@
|
||||
"""Support for explicit client-side field level encryption."""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import subprocess
|
||||
import uuid
|
||||
import weakref
|
||||
from typing import Any, Mapping, Optional, Sequence
|
||||
|
||||
try:
|
||||
from pymongocrypt.auto_encrypter import AutoEncrypter
|
||||
from pymongocrypt.errors import MongoCryptError
|
||||
from pymongocrypt.explicit_encrypter import ExplicitEncrypter
|
||||
from pymongocrypt.explicit_encrypter import (
|
||||
ExplicitEncrypter
|
||||
)
|
||||
from pymongocrypt.mongocrypt import MongoCryptOptions
|
||||
from pymongocrypt.state_machine import MongoCryptCallback
|
||||
_HAVE_PYMONGOCRYPT = True
|
||||
@ -32,29 +33,22 @@ except ImportError:
|
||||
MongoCryptCallback = object
|
||||
|
||||
from bson import _dict_to_bson, decode, encode
|
||||
from bson.binary import STANDARD, UUID_SUBTYPE, Binary
|
||||
from bson.codec_options import CodecOptions
|
||||
from bson.binary import (Binary,
|
||||
STANDARD,
|
||||
UUID_SUBTYPE)
|
||||
from bson.errors import BSONError
|
||||
from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS,
|
||||
RawBSONDocument,
|
||||
from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument,
|
||||
_inflate_bson)
|
||||
from bson.son import SON
|
||||
|
||||
from pymongo.errors import (ConfigurationError,
|
||||
EncryptionError,
|
||||
InvalidOperation,
|
||||
ServerSelectionTimeoutError)
|
||||
from pymongo.daemon import _spawn_daemon
|
||||
from pymongo.encryption_options import AutoEncryptionOpts
|
||||
from pymongo.errors import (ConfigurationError, EncryptionError,
|
||||
InvalidOperation, ServerSelectionTimeoutError)
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.pool import _configured_socket, PoolOptions
|
||||
from pymongo.pool import PoolOptions, _configured_socket
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.ssl_support import get_ssl_context
|
||||
from pymongo.uri_parser import parse_host
|
||||
from pymongo.write_concern import WriteConcern
|
||||
from pymongo.daemon import _spawn_daemon
|
||||
|
||||
|
||||
_HTTPS_PORT = 443
|
||||
_KMS_CONNECT_TIMEOUT = 10 # TODO: CDRIVER-3262 will define this value.
|
||||
@ -80,9 +74,10 @@ def _wrap_encryption_errors():
|
||||
raise EncryptionError(exc)
|
||||
|
||||
|
||||
class _EncryptionIO(MongoCryptCallback):
|
||||
class _EncryptionIO(MongoCryptCallback): # type: ignore
|
||||
def __init__(self, client, key_vault_coll, mongocryptd_client, opts):
|
||||
"""Internal class to perform I/O on behalf of pymongocrypt."""
|
||||
self.client_ref: Any
|
||||
# Use a weak ref to break reference cycle.
|
||||
if client is not None:
|
||||
self.client_ref = weakref.ref(client)
|
||||
@ -355,11 +350,17 @@ class Algorithm(object):
|
||||
"AEAD_AES_256_CBC_HMAC_SHA_512-Random")
|
||||
|
||||
|
||||
|
||||
class ClientEncryption(object):
|
||||
"""Explicit client-side field level encryption."""
|
||||
|
||||
def __init__(self, kms_providers, key_vault_namespace, key_vault_client,
|
||||
codec_options, kms_tls_options=None):
|
||||
def __init__(self,
|
||||
kms_providers: Mapping[str, Any],
|
||||
key_vault_namespace: str,
|
||||
key_vault_client: MongoClient,
|
||||
codec_options: CodecOptions,
|
||||
kms_tls_options: Optional[Mapping[str, Any]] = None
|
||||
) -> None:
|
||||
"""Explicit client-side field level encryption.
|
||||
|
||||
The ClientEncryption class encapsulates explicit operations on a key
|
||||
@ -449,12 +450,15 @@ class ClientEncryption(object):
|
||||
|
||||
opts = AutoEncryptionOpts(kms_providers, key_vault_namespace,
|
||||
kms_tls_options=kms_tls_options)
|
||||
self._io_callbacks = _EncryptionIO(None, key_vault_coll, None, opts)
|
||||
self._io_callbacks: Optional[_EncryptionIO] = _EncryptionIO(None, key_vault_coll, None, opts)
|
||||
self._encryption = ExplicitEncrypter(
|
||||
self._io_callbacks, MongoCryptOptions(kms_providers, None))
|
||||
|
||||
def create_data_key(self, kms_provider, master_key=None,
|
||||
key_alt_names=None):
|
||||
def create_data_key(self,
|
||||
kms_provider: str,
|
||||
master_key: Optional[Mapping[str, Any]] = None,
|
||||
key_alt_names: Optional[Sequence[str]] = None
|
||||
) -> Binary:
|
||||
"""Create and insert a new data key into the key vault collection.
|
||||
|
||||
:Parameters:
|
||||
@ -526,7 +530,12 @@ class ClientEncryption(object):
|
||||
kms_provider, master_key=master_key,
|
||||
key_alt_names=key_alt_names)
|
||||
|
||||
def encrypt(self, value, algorithm, key_id=None, key_alt_name=None):
|
||||
def encrypt(self,
|
||||
value: Any,
|
||||
algorithm: str,
|
||||
key_id: Optional[Binary] = None,
|
||||
key_alt_name: Optional[str] = None
|
||||
) -> Binary:
|
||||
"""Encrypt a BSON value with a given key and algorithm.
|
||||
|
||||
Note that exactly one of ``key_id`` or ``key_alt_name`` must be
|
||||
@ -557,7 +566,7 @@ class ClientEncryption(object):
|
||||
doc, algorithm, key_id=key_id, key_alt_name=key_alt_name)
|
||||
return decode(encrypted_doc)['v']
|
||||
|
||||
def decrypt(self, value):
|
||||
def decrypt(self, value: Binary) -> Any:
|
||||
"""Decrypt an encrypted value.
|
||||
|
||||
:Parameters:
|
||||
@ -578,17 +587,17 @@ class ClientEncryption(object):
|
||||
return decode(decrypted_doc,
|
||||
codec_options=self._codec_options)['v']
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "ClientEncryption":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def _check_closed(self):
|
||||
if self._encryption is None:
|
||||
raise InvalidOperation("Cannot use closed ClientEncryption")
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Release resources.
|
||||
|
||||
Note that using this class in a with-statement will automatically call
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Support for automatic client-side field level encryption."""
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
|
||||
|
||||
try:
|
||||
import pymongocrypt
|
||||
@ -25,19 +26,25 @@ except ImportError:
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.uri_parser import _parse_kms_tls_options
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.mongo_client import MongoClient
|
||||
|
||||
|
||||
class AutoEncryptionOpts(object):
|
||||
"""Options to configure automatic client-side field level encryption."""
|
||||
|
||||
def __init__(self, kms_providers, key_vault_namespace,
|
||||
key_vault_client=None,
|
||||
schema_map=None,
|
||||
bypass_auto_encryption=False,
|
||||
mongocryptd_uri='mongodb://localhost:27020',
|
||||
mongocryptd_bypass_spawn=False,
|
||||
mongocryptd_spawn_path='mongocryptd',
|
||||
mongocryptd_spawn_args=None,
|
||||
kms_tls_options=None):
|
||||
def __init__(self,
|
||||
kms_providers: Mapping[str, Any],
|
||||
key_vault_namespace: str,
|
||||
key_vault_client: Optional["MongoClient"] = None,
|
||||
schema_map: Optional[Mapping[str, Any]] = None,
|
||||
bypass_auto_encryption: Optional[bool] = False,
|
||||
mongocryptd_uri: str = 'mongodb://localhost:27020',
|
||||
mongocryptd_bypass_spawn: bool = False,
|
||||
mongocryptd_spawn_path: str = 'mongocryptd',
|
||||
mongocryptd_spawn_args: Optional[List[str]] = None,
|
||||
kms_tls_options: Optional[Mapping[str, Any]] = None
|
||||
) -> None:
|
||||
"""Options to configure automatic client-side field level encryption.
|
||||
|
||||
Automatic client-side field level encryption requires MongoDB 4.2
|
||||
@ -152,8 +159,9 @@ class AutoEncryptionOpts(object):
|
||||
self._mongocryptd_uri = mongocryptd_uri
|
||||
self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
|
||||
self._mongocryptd_spawn_path = mongocryptd_spawn_path
|
||||
self._mongocryptd_spawn_args = (copy.copy(mongocryptd_spawn_args) or
|
||||
['--idleShutdownTimeoutSecs=60'])
|
||||
if mongocryptd_spawn_args is None:
|
||||
mongocryptd_spawn_args = ['--idleShutdownTimeoutSecs=60']
|
||||
self._mongocryptd_spawn_args = mongocryptd_spawn_args
|
||||
if not isinstance(self._mongocryptd_spawn_args, list):
|
||||
raise TypeError('mongocryptd_spawn_args must be a list')
|
||||
if not any('idleShutdownTimeoutSecs' in s
|
||||
|
||||
@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Exceptions raised by PyMongo."""
|
||||
from typing import (Any, Iterable, List, Mapping, Optional, Sequence, Tuple,
|
||||
Union)
|
||||
|
||||
from bson.errors import *
|
||||
|
||||
@ -23,18 +25,21 @@ except ImportError:
|
||||
try:
|
||||
from ssl import CertificateError as _CertificateError
|
||||
except ImportError:
|
||||
class _CertificateError(ValueError):
|
||||
class _CertificateError(ValueError): # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
class PyMongoError(Exception):
|
||||
"""Base class for all PyMongo exceptions."""
|
||||
def __init__(self, message='', error_labels=None):
|
||||
def __init__(self,
|
||||
message: str = '',
|
||||
error_labels: Optional[Iterable[str]] = None
|
||||
) -> None:
|
||||
super(PyMongoError, self).__init__(message)
|
||||
self._message = message
|
||||
self._error_labels = set(error_labels or [])
|
||||
|
||||
def has_error_label(self, label):
|
||||
def has_error_label(self, label: str) -> bool:
|
||||
"""Return True if this error contains the given label.
|
||||
|
||||
.. versionadded:: 3.7
|
||||
@ -70,10 +75,17 @@ class AutoReconnect(ConnectionFailure):
|
||||
|
||||
Subclass of :exc:`~pymongo.errors.ConnectionFailure`.
|
||||
"""
|
||||
def __init__(self, message='', errors=None):
|
||||
errors: Union[Mapping[str, Any], Sequence]
|
||||
details: Union[Mapping[str, Any], Sequence]
|
||||
|
||||
def __init__(self,
|
||||
message: str = '',
|
||||
errors: Optional[Union[Mapping[str, Any], Sequence]] = None
|
||||
) -> None:
|
||||
error_labels = None
|
||||
if errors is not None and isinstance(errors, dict):
|
||||
error_labels = errors.get('errorLabels')
|
||||
if errors is not None:
|
||||
if isinstance(errors, dict):
|
||||
error_labels = errors.get('errorLabels')
|
||||
super(AutoReconnect, self).__init__(message, error_labels)
|
||||
self.errors = self.details = errors or []
|
||||
|
||||
@ -109,7 +121,10 @@ class NotPrimaryError(AutoReconnect):
|
||||
|
||||
.. versionadded:: 3.12
|
||||
"""
|
||||
def __init__(self, message='', errors=None):
|
||||
def __init__(self,
|
||||
message: str = '',
|
||||
errors: Optional[Union[Mapping[str, Any], List]] = None
|
||||
) -> None:
|
||||
super(NotPrimaryError, self).__init__(
|
||||
_format_detailed_error(message, errors), errors=errors)
|
||||
|
||||
@ -139,7 +154,12 @@ class OperationFailure(PyMongoError):
|
||||
The :attr:`details` attribute.
|
||||
"""
|
||||
|
||||
def __init__(self, error, code=None, details=None, max_wire_version=None):
|
||||
def __init__(self,
|
||||
error: str,
|
||||
code: Optional[int] = None,
|
||||
details: Optional[Mapping[str, Any]] = None,
|
||||
max_wire_version: Optional[int] = None,
|
||||
) -> None:
|
||||
error_labels = None
|
||||
if details is not None:
|
||||
error_labels = details.get('errorLabels')
|
||||
@ -154,13 +174,13 @@ class OperationFailure(PyMongoError):
|
||||
return self.__max_wire_version
|
||||
|
||||
@property
|
||||
def code(self):
|
||||
def code(self) -> Optional[int]:
|
||||
"""The error code returned by the server, if any.
|
||||
"""
|
||||
return self.__code
|
||||
|
||||
@property
|
||||
def details(self):
|
||||
def details(self) -> Optional[Mapping[str, Any]]:
|
||||
"""The complete error document returned by the server.
|
||||
|
||||
Depending on the error that occurred, the error document
|
||||
@ -225,14 +245,17 @@ class BulkWriteError(OperationFailure):
|
||||
|
||||
.. versionadded:: 2.7
|
||||
"""
|
||||
def __init__(self, results):
|
||||
details: Mapping[str, Any]
|
||||
|
||||
def __init__(self, results: Mapping[str, Any]) -> None:
|
||||
super(BulkWriteError, self).__init__(
|
||||
"batch op errors occurred", 65, results)
|
||||
|
||||
def __reduce__(self):
|
||||
def __reduce__(self) -> Tuple[Any, Any]:
|
||||
return self.__class__, (self.details,)
|
||||
|
||||
|
||||
|
||||
class InvalidOperation(PyMongoError):
|
||||
"""Raised when a client attempts to perform an invalid operation."""
|
||||
|
||||
@ -264,12 +287,12 @@ class EncryptionError(PyMongoError):
|
||||
.. versionadded:: 3.9
|
||||
"""
|
||||
|
||||
def __init__(self, cause):
|
||||
def __init__(self, cause: Exception) -> None:
|
||||
super(EncryptionError, self).__init__(str(cause))
|
||||
self.__cause = cause
|
||||
|
||||
@property
|
||||
def cause(self):
|
||||
def cause(self) -> Exception:
|
||||
"""The exception that caused this encryption or decryption error."""
|
||||
return self.__cause
|
||||
|
||||
|
||||
@ -26,8 +26,6 @@ or
|
||||
|
||||
``MongoClient(event_listeners=[CommandLogger()])``
|
||||
"""
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from pymongo import monitoring
|
||||
@ -42,18 +40,18 @@ class CommandLogger(monitoring.CommandListener):
|
||||
logs them at the `INFO` severity level using :mod:`logging`.
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
def started(self, event):
|
||||
def started(self, event: monitoring.CommandStartedEvent) -> None:
|
||||
logging.info("Command {0.command_name} with request id "
|
||||
"{0.request_id} started on server "
|
||||
"{0.connection_id}".format(event))
|
||||
|
||||
def succeeded(self, event):
|
||||
def succeeded(self, event: monitoring.CommandSucceededEvent) -> None:
|
||||
logging.info("Command {0.command_name} with request id "
|
||||
"{0.request_id} on server {0.connection_id} "
|
||||
"succeeded in {0.duration_micros} "
|
||||
"microseconds".format(event))
|
||||
|
||||
def failed(self, event):
|
||||
def failed(self, event: monitoring.CommandFailedEvent) -> None:
|
||||
logging.info("Command {0.command_name} with request id "
|
||||
"{0.request_id} on server {0.connection_id} "
|
||||
"failed in {0.duration_micros} "
|
||||
@ -70,11 +68,11 @@ class ServerLogger(monitoring.ServerListener):
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
def opened(self, event):
|
||||
def opened(self, event: monitoring.ServerOpeningEvent) -> None:
|
||||
logging.info("Server {0.server_address} added to topology "
|
||||
"{0.topology_id}".format(event))
|
||||
|
||||
def description_changed(self, event):
|
||||
def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None:
|
||||
previous_server_type = event.previous_description.server_type
|
||||
new_server_type = event.new_description.server_type
|
||||
if new_server_type != previous_server_type:
|
||||
@ -84,7 +82,7 @@ class ServerLogger(monitoring.ServerListener):
|
||||
"{0.previous_description.server_type_name} to "
|
||||
"{0.new_description.server_type_name}".format(event))
|
||||
|
||||
def closed(self, event):
|
||||
def closed(self, event: monitoring.ServerClosedEvent) -> None:
|
||||
logging.warning("Server {0.server_address} removed from topology "
|
||||
"{0.topology_id}".format(event))
|
||||
|
||||
@ -99,17 +97,17 @@ class HeartbeatLogger(monitoring.ServerHeartbeatListener):
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
def started(self, event):
|
||||
def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
|
||||
logging.info("Heartbeat sent to server "
|
||||
"{0.connection_id}".format(event))
|
||||
|
||||
def succeeded(self, event):
|
||||
def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
|
||||
# The reply.document attribute was added in PyMongo 3.4.
|
||||
logging.info("Heartbeat to server {0.connection_id} "
|
||||
"succeeded with reply "
|
||||
"{0.reply.document}".format(event))
|
||||
|
||||
def failed(self, event):
|
||||
def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
|
||||
logging.warning("Heartbeat to server {0.connection_id} "
|
||||
"failed with error {0.reply}".format(event))
|
||||
|
||||
@ -124,11 +122,11 @@ class TopologyLogger(monitoring.TopologyListener):
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
def opened(self, event):
|
||||
def opened(self, event: monitoring.TopologyOpenedEvent) -> None:
|
||||
logging.info("Topology with id {0.topology_id} "
|
||||
"opened".format(event))
|
||||
|
||||
def description_changed(self, event):
|
||||
def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None:
|
||||
logging.info("Topology description updated for "
|
||||
"topology id {0.topology_id}".format(event))
|
||||
previous_topology_type = event.previous_description.topology_type
|
||||
@ -146,7 +144,7 @@ class TopologyLogger(monitoring.TopologyListener):
|
||||
if not event.new_description.has_readable_server():
|
||||
logging.warning("No readable servers available.")
|
||||
|
||||
def closed(self, event):
|
||||
def closed(self, event: monitoring.TopologyClosedEvent) -> None:
|
||||
logging.info("Topology with id {0.topology_id} "
|
||||
"closed".format(event))
|
||||
|
||||
@ -168,43 +166,43 @@ class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
|
||||
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
def pool_created(self, event):
|
||||
def pool_created(self, event: monitoring.PoolCreatedEvent) -> None:
|
||||
logging.info("[pool {0.address}] pool created".format(event))
|
||||
|
||||
def pool_ready(self, event):
|
||||
logging.info("[pool {0.address}] pool ready".format(event))
|
||||
|
||||
def pool_cleared(self, event):
|
||||
def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None:
|
||||
logging.info("[pool {0.address}] pool cleared".format(event))
|
||||
|
||||
def pool_closed(self, event):
|
||||
def pool_closed(self, event: monitoring.PoolClosedEvent) -> None:
|
||||
logging.info("[pool {0.address}] pool closed".format(event))
|
||||
|
||||
def connection_created(self, event):
|
||||
def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None:
|
||||
logging.info("[pool {0.address}][conn #{0.connection_id}] "
|
||||
"connection created".format(event))
|
||||
|
||||
def connection_ready(self, event):
|
||||
def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None:
|
||||
logging.info("[pool {0.address}][conn #{0.connection_id}] "
|
||||
"connection setup succeeded".format(event))
|
||||
|
||||
def connection_closed(self, event):
|
||||
def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None:
|
||||
logging.info("[pool {0.address}][conn #{0.connection_id}] "
|
||||
"connection closed, reason: "
|
||||
"{0.reason}".format(event))
|
||||
|
||||
def connection_check_out_started(self, event):
|
||||
def connection_check_out_started(self, event: monitoring.ConnectionCheckOutStartedEvent) -> None:
|
||||
logging.info("[pool {0.address}] connection check out "
|
||||
"started".format(event))
|
||||
|
||||
def connection_check_out_failed(self, event):
|
||||
def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None:
|
||||
logging.info("[pool {0.address}] connection check out "
|
||||
"failed, reason: {0.reason}".format(event))
|
||||
|
||||
def connection_checked_out(self, event):
|
||||
def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None:
|
||||
logging.info("[pool {0.address}][conn #{0.connection_id}] "
|
||||
"connection checked out of pool".format(event))
|
||||
|
||||
def connection_checked_in(self, event):
|
||||
def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None:
|
||||
logging.info("[pool {0.address}][conn #{0.connection_id}] "
|
||||
"connection checked into pool".format(event))
|
||||
|
||||
@ -14,10 +14,15 @@
|
||||
|
||||
"""Helpers for the 'hello' and legacy hello commands."""
|
||||
|
||||
import copy
|
||||
import datetime
|
||||
import itertools
|
||||
from typing import Any, Generic, List, Mapping, Optional, Set, Tuple
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo import common
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.typings import _DocumentType
|
||||
|
||||
|
||||
class HelloCompat:
|
||||
@ -56,7 +61,7 @@ def _get_server_type(doc):
|
||||
return SERVER_TYPE.Standalone
|
||||
|
||||
|
||||
class Hello(object):
|
||||
class Hello(Generic[_DocumentType]):
|
||||
"""Parse a hello response from the server.
|
||||
|
||||
.. versionadded:: 3.12
|
||||
@ -64,9 +69,9 @@ class Hello(object):
|
||||
__slots__ = ('_doc', '_server_type', '_is_writable', '_is_readable',
|
||||
'_awaitable')
|
||||
|
||||
def __init__(self, doc, awaitable=False):
|
||||
def __init__(self, doc: _DocumentType, awaitable: bool = False) -> None:
|
||||
self._server_type = _get_server_type(doc)
|
||||
self._doc = doc
|
||||
self._doc: _DocumentType = doc
|
||||
self._is_writable = self._server_type in (
|
||||
SERVER_TYPE.RSPrimary,
|
||||
SERVER_TYPE.Standalone,
|
||||
@ -79,19 +84,19 @@ class Hello(object):
|
||||
self._awaitable = awaitable
|
||||
|
||||
@property
|
||||
def document(self):
|
||||
def document(self) -> _DocumentType:
|
||||
"""The complete hello command response document.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
"""
|
||||
return self._doc.copy()
|
||||
return copy.copy(self._doc)
|
||||
|
||||
@property
|
||||
def server_type(self):
|
||||
def server_type(self) -> int:
|
||||
return self._server_type
|
||||
|
||||
@property
|
||||
def all_hosts(self):
|
||||
def all_hosts(self) -> Set[Tuple[str, int]]:
|
||||
"""List of hosts, passives, and arbiters known to this server."""
|
||||
return set(map(common.clean_node, itertools.chain(
|
||||
self._doc.get('hosts', []),
|
||||
@ -99,12 +104,12 @@ class Hello(object):
|
||||
self._doc.get('arbiters', []))))
|
||||
|
||||
@property
|
||||
def tags(self):
|
||||
def tags(self) -> Mapping[str, Any]:
|
||||
"""Replica set member tags or empty dict."""
|
||||
return self._doc.get('tags', {})
|
||||
|
||||
@property
|
||||
def primary(self):
|
||||
def primary(self) -> Optional[Tuple[str, int]]:
|
||||
"""This server's opinion about who the primary is, or None."""
|
||||
if self._doc.get('primary'):
|
||||
return common.partition_node(self._doc['primary'])
|
||||
@ -112,70 +117,71 @@ class Hello(object):
|
||||
return None
|
||||
|
||||
@property
|
||||
def replica_set_name(self):
|
||||
def replica_set_name(self) -> Optional[str]:
|
||||
"""Replica set name or None."""
|
||||
return self._doc.get('setName')
|
||||
|
||||
@property
|
||||
def max_bson_size(self):
|
||||
def max_bson_size(self) -> int:
|
||||
return self._doc.get('maxBsonObjectSize', common.MAX_BSON_SIZE)
|
||||
|
||||
@property
|
||||
def max_message_size(self):
|
||||
def max_message_size(self) -> int:
|
||||
return self._doc.get('maxMessageSizeBytes', 2 * self.max_bson_size)
|
||||
|
||||
@property
|
||||
def max_write_batch_size(self):
|
||||
def max_write_batch_size(self) -> int:
|
||||
return self._doc.get('maxWriteBatchSize', common.MAX_WRITE_BATCH_SIZE)
|
||||
|
||||
@property
|
||||
def min_wire_version(self):
|
||||
def min_wire_version(self) -> int:
|
||||
return self._doc.get('minWireVersion', common.MIN_WIRE_VERSION)
|
||||
|
||||
@property
|
||||
def max_wire_version(self):
|
||||
def max_wire_version(self) -> int:
|
||||
return self._doc.get('maxWireVersion', common.MAX_WIRE_VERSION)
|
||||
|
||||
@property
|
||||
def set_version(self):
|
||||
def set_version(self) -> Optional[int]:
|
||||
return self._doc.get('setVersion')
|
||||
|
||||
@property
|
||||
def election_id(self):
|
||||
def election_id(self) -> Optional[ObjectId]:
|
||||
return self._doc.get('electionId')
|
||||
|
||||
@property
|
||||
def cluster_time(self):
|
||||
def cluster_time(self) -> Optional[Mapping[str, Any]]:
|
||||
return self._doc.get('$clusterTime')
|
||||
|
||||
@property
|
||||
def logical_session_timeout_minutes(self):
|
||||
def logical_session_timeout_minutes(self) -> Optional[int]:
|
||||
return self._doc.get('logicalSessionTimeoutMinutes')
|
||||
|
||||
@property
|
||||
def is_writable(self):
|
||||
def is_writable(self) -> bool:
|
||||
return self._is_writable
|
||||
|
||||
@property
|
||||
def is_readable(self):
|
||||
def is_readable(self) -> bool:
|
||||
return self._is_readable
|
||||
|
||||
@property
|
||||
def me(self):
|
||||
def me(self) -> Optional[Tuple[str, int]]:
|
||||
me = self._doc.get('me')
|
||||
if me:
|
||||
return common.clean_node(me)
|
||||
return None
|
||||
|
||||
@property
|
||||
def last_write_date(self):
|
||||
def last_write_date(self) -> Optional[datetime.datetime]:
|
||||
return self._doc.get('lastWrite', {}).get('lastWriteDate')
|
||||
|
||||
@property
|
||||
def compressors(self):
|
||||
def compressors(self) -> Optional[List[str]]:
|
||||
return self._doc.get('compression')
|
||||
|
||||
@property
|
||||
def sasl_supported_mechs(self):
|
||||
def sasl_supported_mechs(self) -> List[str]:
|
||||
"""Supported authentication mechanisms for the current user.
|
||||
|
||||
For example::
|
||||
@ -187,22 +193,22 @@ class Hello(object):
|
||||
return self._doc.get('saslSupportedMechs', [])
|
||||
|
||||
@property
|
||||
def speculative_authenticate(self):
|
||||
def speculative_authenticate(self) -> Optional[Mapping[str, Any]]:
|
||||
"""The speculativeAuthenticate field."""
|
||||
return self._doc.get('speculativeAuthenticate')
|
||||
|
||||
@property
|
||||
def topology_version(self):
|
||||
def topology_version(self) -> Optional[Mapping[str, Any]]:
|
||||
return self._doc.get('topologyVersion')
|
||||
|
||||
@property
|
||||
def awaitable(self):
|
||||
def awaitable(self) -> bool:
|
||||
return self._awaitable
|
||||
|
||||
@property
|
||||
def service_id(self):
|
||||
def service_id(self) -> Optional[ObjectId]:
|
||||
return self._doc.get('serviceId')
|
||||
|
||||
@property
|
||||
def hello_ok(self):
|
||||
def hello_ok(self) -> bool:
|
||||
return self._doc.get('helloOk', False)
|
||||
|
||||
@ -16,18 +16,14 @@
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from collections import abc
|
||||
from typing import Any
|
||||
|
||||
from bson.son import SON
|
||||
from pymongo import ASCENDING
|
||||
from pymongo.errors import (CursorNotFound,
|
||||
DuplicateKeyError,
|
||||
ExecutionTimeout,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
WriteError,
|
||||
WriteConcernError,
|
||||
from pymongo.errors import (CursorNotFound, DuplicateKeyError,
|
||||
ExecutionTimeout, NotPrimaryError,
|
||||
OperationFailure, WriteConcernError, WriteError,
|
||||
WTimeoutError)
|
||||
from pymongo.hello import HelloCompat
|
||||
|
||||
@ -95,7 +91,7 @@ def _index_document(index_list):
|
||||
if not len(index_list):
|
||||
raise ValueError("key_or_list must not be the empty list")
|
||||
|
||||
index = SON()
|
||||
index: SON[str, Any] = SON()
|
||||
for (key, value) in index_list:
|
||||
if not isinstance(key, str):
|
||||
raise TypeError(
|
||||
|
||||
@ -23,8 +23,8 @@ MongoDB.
|
||||
import datetime
|
||||
import random
|
||||
import struct
|
||||
|
||||
from io import BytesIO as _BytesIO
|
||||
from typing import Any
|
||||
|
||||
import bson
|
||||
from bson import (CodecOptions,
|
||||
@ -32,30 +32,24 @@ from bson import (CodecOptions,
|
||||
_decode_selective,
|
||||
_dict_to_bson,
|
||||
_make_c_string)
|
||||
from bson import codec_options
|
||||
from bson.int64 import Int64
|
||||
from bson.raw_bson import (_inflate_bson, DEFAULT_RAW_BSON_OPTIONS,
|
||||
RawBSONDocument)
|
||||
from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument,
|
||||
_inflate_bson)
|
||||
from bson.son import SON
|
||||
|
||||
try:
|
||||
from pymongo import _cmessage
|
||||
from pymongo import _cmessage # type: ignore[attr-defined]
|
||||
_use_c = True
|
||||
except ImportError:
|
||||
_use_c = False
|
||||
from pymongo.errors import (ConfigurationError,
|
||||
CursorNotFound,
|
||||
DocumentTooLarge,
|
||||
ExecutionTimeout,
|
||||
InvalidOperation,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
ProtocolError)
|
||||
from pymongo.errors import (ConfigurationError, CursorNotFound,
|
||||
DocumentTooLarge, ExecutionTimeout,
|
||||
InvalidOperation, NotPrimaryError,
|
||||
OperationFailure, ProtocolError)
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
|
||||
MAX_INT32 = 2147483647
|
||||
MIN_INT32 = -2147483648
|
||||
|
||||
@ -457,6 +451,7 @@ class _RawBatchGetMore(_GetMore):
|
||||
|
||||
class _CursorAddress(tuple):
|
||||
"""The server address (host, port) of a cursor, with namespace property."""
|
||||
__namespace: Any
|
||||
|
||||
def __new__(cls, address, namespace):
|
||||
self = tuple.__new__(cls, address)
|
||||
@ -762,6 +757,7 @@ class _BulkWriteContext(object):
|
||||
"""A proxy for SocketInfo.unack_write that handles event publishing.
|
||||
"""
|
||||
if self.publish:
|
||||
assert self.start_time is not None
|
||||
duration = datetime.datetime.now() - self.start_time
|
||||
cmd = self._start(cmd, request_id, docs)
|
||||
start = datetime.datetime.now()
|
||||
@ -777,6 +773,7 @@ class _BulkWriteContext(object):
|
||||
self._succeed(request_id, reply, duration)
|
||||
except Exception as exc:
|
||||
if self.publish:
|
||||
assert self.start_time is not None
|
||||
duration = (datetime.datetime.now() - start) + duration
|
||||
if isinstance(exc, OperationFailure):
|
||||
failure = _convert_write_result(
|
||||
@ -795,6 +792,7 @@ class _BulkWriteContext(object):
|
||||
"""A proxy for SocketInfo.write_command that handles event publishing.
|
||||
"""
|
||||
if self.publish:
|
||||
assert self.start_time is not None
|
||||
duration = datetime.datetime.now() - self.start_time
|
||||
self._start(cmd, request_id, docs)
|
||||
start = datetime.datetime.now()
|
||||
@ -1171,7 +1169,8 @@ class _OpReply(object):
|
||||
if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR):
|
||||
raise NotPrimaryError(error_object["$err"], error_object)
|
||||
elif error_object.get("code") == 50:
|
||||
raise ExecutionTimeout(error_object.get("$err"),
|
||||
default_msg = "operation exceeded time limit"
|
||||
raise ExecutionTimeout(error_object.get("$err", default_msg),
|
||||
error_object.get("code"),
|
||||
error_object)
|
||||
raise OperationFailure("database error: %s" %
|
||||
|
||||
@ -34,46 +34,41 @@ access:
|
||||
import contextlib
|
||||
import threading
|
||||
import weakref
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import (TYPE_CHECKING, Any, Dict, FrozenSet, Generic, List,
|
||||
Mapping, Optional, Sequence, Set, Tuple, Type, Union, cast)
|
||||
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS
|
||||
import bson
|
||||
from bson.codec_options import (DEFAULT_CODEC_OPTIONS, CodecOptions,
|
||||
TypeRegistry)
|
||||
from bson.son import SON
|
||||
from pymongo import (common,
|
||||
database,
|
||||
helpers,
|
||||
message,
|
||||
periodic_executor,
|
||||
uri_parser,
|
||||
client_session)
|
||||
from pymongo.change_stream import ClusterChangeStream
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import (client_session, common, database, helpers, message,
|
||||
periodic_executor, uri_parser)
|
||||
from pymongo.change_stream import ChangeStream, ClusterChangeStream
|
||||
from pymongo.client_options import ClientOptions
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
from pymongo.errors import (AutoReconnect,
|
||||
BulkWriteError,
|
||||
ConfigurationError,
|
||||
ConnectionFailure,
|
||||
InvalidOperation,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
from pymongo.errors import (AutoReconnect, BulkWriteError, ConfigurationError,
|
||||
ConnectionFailure, InvalidOperation,
|
||||
NotPrimaryError, OperationFailure, PyMongoError,
|
||||
ServerSelectionTimeoutError)
|
||||
from pymongo.pool import ConnectionClosedReason
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.topology import (Topology,
|
||||
_ErrorContext)
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE
|
||||
from pymongo.settings import TopologySettings
|
||||
from pymongo.uri_parser import (_handle_option_deprecations,
|
||||
_handle_security_options,
|
||||
_normalize_options,
|
||||
_check_options)
|
||||
from pymongo.write_concern import DEFAULT_WRITE_CONCERN
|
||||
from pymongo.topology import Topology, _ErrorContext
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
from pymongo.uri_parser import (_check_options, _handle_option_deprecations,
|
||||
_handle_security_options, _normalize_options)
|
||||
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.read_concern import ReadConcern
|
||||
|
||||
|
||||
class MongoClient(common.BaseObject):
|
||||
class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
"""
|
||||
A client-side representation of a MongoDB cluster.
|
||||
|
||||
@ -89,15 +84,15 @@ class MongoClient(common.BaseObject):
|
||||
# No host/port; these are retrieved from TopologySettings.
|
||||
_constructor_args = ('document_class', 'tz_aware', 'connect')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host=None,
|
||||
port=None,
|
||||
document_class=dict,
|
||||
tz_aware=None,
|
||||
connect=None,
|
||||
type_registry=None,
|
||||
**kwargs):
|
||||
def __init__(self,
|
||||
host: Optional[Union[str, Sequence[str]]] = None,
|
||||
port: Optional[int] = None,
|
||||
document_class: Type[_DocumentType] = dict,
|
||||
tz_aware: Optional[bool] = None,
|
||||
connect: Optional[bool] = None,
|
||||
type_registry: Optional[TypeRegistry] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Client for a MongoDB instance, a replica set, or a set of mongoses.
|
||||
|
||||
The client object is thread-safe and has connection-pooling built in.
|
||||
@ -621,7 +616,7 @@ class MongoClient(common.BaseObject):
|
||||
|
||||
client.__my_database__
|
||||
"""
|
||||
self.__init_kwargs = {'host': host,
|
||||
self.__init_kwargs: Dict[str, Any] = {'host': host,
|
||||
'port': port,
|
||||
'document_class': document_class,
|
||||
'tz_aware': tz_aware,
|
||||
@ -722,7 +717,7 @@ class MongoClient(common.BaseObject):
|
||||
|
||||
self.__default_database_name = dbase
|
||||
self.__lock = threading.Lock()
|
||||
self.__kill_cursors_queue = []
|
||||
self.__kill_cursors_queue: List = []
|
||||
|
||||
self._event_listeners = options.pool_options._event_listeners
|
||||
super(MongoClient, self).__init__(options.codec_options,
|
||||
@ -765,7 +760,7 @@ class MongoClient(common.BaseObject):
|
||||
|
||||
# We strongly reference the executor and it weakly references us via
|
||||
# this closure. When the client is freed, stop the executor soon.
|
||||
self_ref = weakref.ref(self, executor.close)
|
||||
self_ref: Any = weakref.ref(self, executor.close)
|
||||
self._kill_cursors_executor = executor
|
||||
|
||||
if connect:
|
||||
@ -798,9 +793,17 @@ class MongoClient(common.BaseObject):
|
||||
|
||||
return getattr(server.description, attr_name)
|
||||
|
||||
def watch(self, pipeline=None, full_document=None, resume_after=None,
|
||||
max_await_time_ms=None, batch_size=None, collation=None,
|
||||
start_at_operation_time=None, session=None, start_after=None):
|
||||
def watch(self,
|
||||
pipeline: Optional[_Pipeline] = None,
|
||||
full_document: Optional[str] = None,
|
||||
resume_after: Optional[Mapping[str, Any]] = None,
|
||||
max_await_time_ms: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
start_at_operation_time: Optional[Timestamp] = None,
|
||||
session: Optional[client_session.ClientSession] = None,
|
||||
start_after: Optional[Mapping[str, Any]] = None,
|
||||
) -> ChangeStream[_DocumentType]:
|
||||
"""Watch changes on this cluster.
|
||||
|
||||
Performs an aggregation with an implicit initial ``$changeStream``
|
||||
@ -891,7 +894,7 @@ class MongoClient(common.BaseObject):
|
||||
start_after)
|
||||
|
||||
@property
|
||||
def topology_description(self):
|
||||
def topology_description(self) -> TopologyDescription:
|
||||
"""The description of the connected MongoDB deployment.
|
||||
|
||||
>>> client.topology_description
|
||||
@ -913,7 +916,7 @@ class MongoClient(common.BaseObject):
|
||||
return self._topology.description
|
||||
|
||||
@property
|
||||
def address(self):
|
||||
def address(self) -> Optional[Tuple[str, int]]:
|
||||
"""(host, port) of the current standalone, primary, or mongos, or None.
|
||||
|
||||
Accessing :attr:`address` raises :exc:`~.errors.InvalidOperation` if
|
||||
@ -940,7 +943,7 @@ class MongoClient(common.BaseObject):
|
||||
return self._server_property('address')
|
||||
|
||||
@property
|
||||
def primary(self):
|
||||
def primary(self) -> Optional[Tuple[str, int]]:
|
||||
"""The (host, port) of the current primary of the replica set.
|
||||
|
||||
Returns ``None`` if this client is not connected to a replica set,
|
||||
@ -953,7 +956,7 @@ class MongoClient(common.BaseObject):
|
||||
return self._topology.get_primary()
|
||||
|
||||
@property
|
||||
def secondaries(self):
|
||||
def secondaries(self) -> Set[Tuple[str, int]]:
|
||||
"""The secondary members known to this client.
|
||||
|
||||
A sequence of (host, port) pairs. Empty if this client is not
|
||||
@ -966,7 +969,7 @@ class MongoClient(common.BaseObject):
|
||||
return self._topology.get_secondaries()
|
||||
|
||||
@property
|
||||
def arbiters(self):
|
||||
def arbiters(self) -> Set[Tuple[str, int]]:
|
||||
"""Arbiters in the replica set.
|
||||
|
||||
A sequence of (host, port) pairs. Empty if this client is not
|
||||
@ -976,7 +979,7 @@ class MongoClient(common.BaseObject):
|
||||
return self._topology.get_arbiters()
|
||||
|
||||
@property
|
||||
def is_primary(self):
|
||||
def is_primary(self) -> bool:
|
||||
"""If this client is connected to a server that can accept writes.
|
||||
|
||||
True if the current server is a standalone, mongos, or the primary of
|
||||
@ -987,7 +990,7 @@ class MongoClient(common.BaseObject):
|
||||
return self._server_property('is_writable')
|
||||
|
||||
@property
|
||||
def is_mongos(self):
|
||||
def is_mongos(self) -> bool:
|
||||
"""If this client is connected to mongos. If the client is not
|
||||
connected, this will block until a connection is established or raise
|
||||
ServerSelectionTimeoutError if no server is available..
|
||||
@ -995,7 +998,7 @@ class MongoClient(common.BaseObject):
|
||||
return self._server_property('server_type') == SERVER_TYPE.Mongos
|
||||
|
||||
@property
|
||||
def nodes(self):
|
||||
def nodes(self) -> FrozenSet[Tuple[str, Optional[int]]]:
|
||||
"""Set of all currently connected servers.
|
||||
|
||||
.. warning:: When connected to a replica set the value of :attr:`nodes`
|
||||
@ -1009,7 +1012,7 @@ class MongoClient(common.BaseObject):
|
||||
return frozenset(s.address for s in description.known_servers)
|
||||
|
||||
@property
|
||||
def options(self):
|
||||
def options(self) -> ClientOptions:
|
||||
"""The configuration options for this client.
|
||||
|
||||
:Returns:
|
||||
@ -1040,7 +1043,7 @@ class MongoClient(common.BaseObject):
|
||||
# command.
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Cleanup client resources and disconnect from MongoDB.
|
||||
|
||||
End all server sessions created by this client by sending one or more
|
||||
@ -1214,7 +1217,7 @@ class MongoClient(common.BaseObject):
|
||||
def _retry_internal(self, retryable, func, session, bulk):
|
||||
"""Internal retryable write helper."""
|
||||
max_wire_version = 0
|
||||
last_error = None
|
||||
last_error: Optional[Exception] = None
|
||||
retrying = False
|
||||
|
||||
def is_retrying():
|
||||
@ -1239,6 +1242,7 @@ class MongoClient(common.BaseObject):
|
||||
if is_retrying():
|
||||
# A retry is not possible because this server does
|
||||
# not support sessions raise the last error.
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
retryable = False
|
||||
return func(session, sock_info, retryable)
|
||||
@ -1247,6 +1251,7 @@ class MongoClient(common.BaseObject):
|
||||
# The application may think the write was never attempted
|
||||
# if we raise ServerSelectionTimeoutError on the retry
|
||||
# attempt. Raise the original exception instead.
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
# A ServerSelectionTimeoutError error indicates that there may
|
||||
# be a persistent outage. Attempting to retry in this case will
|
||||
@ -1280,7 +1285,7 @@ class MongoClient(common.BaseObject):
|
||||
retryable = (retryable and
|
||||
self.options.retry_reads
|
||||
and not (session and session.in_transaction))
|
||||
last_error = None
|
||||
last_error: Optional[Exception] = None
|
||||
retrying = False
|
||||
|
||||
while True:
|
||||
@ -1292,6 +1297,7 @@ class MongoClient(common.BaseObject):
|
||||
if retrying and not retryable:
|
||||
# A retry is not possible because this server does
|
||||
# not support retryable reads, raise the last error.
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
return func(session, server, sock_info, read_pref)
|
||||
except ServerSelectionTimeoutError:
|
||||
@ -1299,6 +1305,7 @@ class MongoClient(common.BaseObject):
|
||||
# The application may think the write was never attempted
|
||||
# if we raise ServerSelectionTimeoutError on the retry
|
||||
# attempt. Raise the original exception instead.
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
# A ServerSelectionTimeoutError error indicates that there may
|
||||
# be a persistent outage. Attempting to retry in this case will
|
||||
@ -1322,15 +1329,15 @@ class MongoClient(common.BaseObject):
|
||||
with self._tmp_session(session) as s:
|
||||
return self._retry_with_session(retryable, func, s, None)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, self.__class__):
|
||||
return self._topology == other._topology
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(self._topology)
|
||||
|
||||
def _repr_helper(self):
|
||||
@ -1366,7 +1373,7 @@ class MongoClient(common.BaseObject):
|
||||
def __repr__(self):
|
||||
return ("MongoClient(%s)" % (self._repr_helper(),))
|
||||
|
||||
def __getattr__(self, name):
|
||||
def __getattr__(self, name: str) -> database.Database[_DocumentType]:
|
||||
"""Get a database by name.
|
||||
|
||||
Raises :class:`~pymongo.errors.InvalidName` if an invalid
|
||||
@ -1381,7 +1388,7 @@ class MongoClient(common.BaseObject):
|
||||
" database, use client[%r]." % (name, name, name))
|
||||
return self.__getitem__(name)
|
||||
|
||||
def __getitem__(self, name):
|
||||
def __getitem__(self, name: str) -> database.Database[_DocumentType]:
|
||||
"""Get a database by name.
|
||||
|
||||
Raises :class:`~pymongo.errors.InvalidName` if an invalid
|
||||
@ -1539,9 +1546,10 @@ class MongoClient(common.BaseObject):
|
||||
self, server_session, opts, implicit)
|
||||
|
||||
def start_session(self,
|
||||
causal_consistency=None,
|
||||
default_transaction_options=None,
|
||||
snapshot=False):
|
||||
causal_consistency: Optional[bool] = None,
|
||||
default_transaction_options: Optional[client_session.TransactionOptions] = None,
|
||||
snapshot: Optional[bool] = False,
|
||||
) -> client_session.ClientSession[_DocumentType]:
|
||||
"""Start a logical session.
|
||||
|
||||
This method takes the same parameters as
|
||||
@ -1630,7 +1638,9 @@ class MongoClient(common.BaseObject):
|
||||
if session is not None:
|
||||
session._process_response(reply)
|
||||
|
||||
def server_info(self, session=None):
|
||||
def server_info(self,
|
||||
session: Optional[client_session.ClientSession] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get information about the MongoDB server we're connected to.
|
||||
|
||||
:Parameters:
|
||||
@ -1644,7 +1654,10 @@ class MongoClient(common.BaseObject):
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
session=session)
|
||||
|
||||
def list_databases(self, session=None, **kwargs):
|
||||
def list_databases(self,
|
||||
session: Optional[client_session.ClientSession] = None,
|
||||
**kwargs: Any
|
||||
) -> CommandCursor[Dict[str, Any]]:
|
||||
"""Get a cursor over the databases of the connected server.
|
||||
|
||||
:Parameters:
|
||||
@ -1673,7 +1686,9 @@ class MongoClient(common.BaseObject):
|
||||
}
|
||||
return CommandCursor(admin["$cmd"], cursor, None)
|
||||
|
||||
def list_database_names(self, session=None):
|
||||
def list_database_names(self,
|
||||
session: Optional[client_session.ClientSession] = None
|
||||
) -> List[str]:
|
||||
"""Get a list of the names of all databases on the connected server.
|
||||
|
||||
:Parameters:
|
||||
@ -1685,7 +1700,10 @@ class MongoClient(common.BaseObject):
|
||||
return [doc["name"]
|
||||
for doc in self.list_databases(session, nameOnly=True)]
|
||||
|
||||
def drop_database(self, name_or_database, session=None):
|
||||
def drop_database(self,
|
||||
name_or_database: Union[str, database.Database],
|
||||
session: Optional[client_session.ClientSession] = None
|
||||
) -> None:
|
||||
"""Drop a database.
|
||||
|
||||
Raises :class:`TypeError` if `name_or_database` is not an instance of
|
||||
@ -1727,8 +1745,13 @@ class MongoClient(common.BaseObject):
|
||||
parse_write_concern_error=True,
|
||||
session=session)
|
||||
|
||||
def get_default_database(self, default=None, codec_options=None,
|
||||
read_preference=None, write_concern=None, read_concern=None):
|
||||
def get_default_database(self,
|
||||
default: Optional[str] = None,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_concern: Optional["ReadConcern"] = None,
|
||||
) -> database.Database[_DocumentType]:
|
||||
"""Get the database named in the MongoDB connection URI.
|
||||
|
||||
>>> uri = 'mongodb://host/my_database'
|
||||
@ -1773,12 +1796,18 @@ class MongoClient(common.BaseObject):
|
||||
raise ConfigurationError(
|
||||
'No default database name defined or provided.')
|
||||
|
||||
name = cast(str, self.__default_database_name or default)
|
||||
return database.Database(
|
||||
self, self.__default_database_name or default, codec_options,
|
||||
self, name, codec_options,
|
||||
read_preference, write_concern, read_concern)
|
||||
|
||||
def get_database(self, name=None, codec_options=None, read_preference=None,
|
||||
write_concern=None, read_concern=None):
|
||||
def get_database(self,
|
||||
name: Optional[str] = None,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_concern: Optional["ReadConcern"] = None,
|
||||
) -> database.Database[_DocumentType]:
|
||||
"""Get a :class:`~pymongo.database.Database` with the given name and
|
||||
options.
|
||||
|
||||
@ -1838,16 +1867,16 @@ class MongoClient(common.BaseObject):
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
write_concern=DEFAULT_WRITE_CONCERN)
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "MongoClient[_DocumentType]":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> "MongoClient[_DocumentType]":
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
def __next__(self) -> None:
|
||||
raise TypeError("'MongoClient' object is not iterable")
|
||||
|
||||
next = __next__
|
||||
|
||||
@ -18,10 +18,10 @@ import atexit
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Mapping, cast
|
||||
|
||||
from pymongo import common, periodic_executor
|
||||
from pymongo.errors import (NotPrimaryError,
|
||||
OperationFailure,
|
||||
from pymongo.errors import (NotPrimaryError, OperationFailure,
|
||||
_OperationCancelled)
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.periodic_executor import _shutdown_executors
|
||||
@ -50,7 +50,7 @@ class MonitorBase(object):
|
||||
monitor = self_ref()
|
||||
if monitor is None:
|
||||
return False # Stop the executor.
|
||||
monitor._run()
|
||||
monitor._run() # type:ignore[attr-defined]
|
||||
return True
|
||||
|
||||
executor = periodic_executor.PeriodicExecutor(
|
||||
@ -214,8 +214,8 @@ class Monitor(MonitorBase):
|
||||
return self._check_once()
|
||||
except (OperationFailure, NotPrimaryError) as exc:
|
||||
# Update max cluster time even when hello fails.
|
||||
self._topology.receive_cluster_time(
|
||||
exc.details.get('$clusterTime'))
|
||||
details = cast(Mapping[str, Any], exc.details)
|
||||
self._topology.receive_cluster_time(details.get('$clusterTime'))
|
||||
raise
|
||||
except ReferenceError:
|
||||
raise
|
||||
|
||||
@ -180,12 +180,21 @@ will not add that listener to existing client instances.
|
||||
handler first.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from collections import abc, namedtuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional
|
||||
|
||||
from pymongo.hello import HelloCompat
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.helpers import _handle_exception
|
||||
from pymongo.typings import _Address
|
||||
|
||||
_Listeners = namedtuple('Listeners',
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.topology_description import TopologyDescription
|
||||
|
||||
|
||||
_Listeners = namedtuple('_Listeners',
|
||||
('command_listeners', 'server_listeners',
|
||||
'server_heartbeat_listeners', 'topology_listeners',
|
||||
'cmap_listeners'))
|
||||
@ -193,6 +202,9 @@ _Listeners = namedtuple('Listeners',
|
||||
_LISTENERS = _Listeners([], [], [], [], [])
|
||||
|
||||
|
||||
_DocumentOut = Mapping[str, Any]
|
||||
|
||||
|
||||
class _EventListener(object):
|
||||
"""Abstract base class for all event listeners."""
|
||||
|
||||
@ -204,7 +216,7 @@ class CommandListener(_EventListener):
|
||||
and `CommandFailedEvent`.
|
||||
"""
|
||||
|
||||
def started(self, event):
|
||||
def started(self, event: "CommandStartedEvent") -> None:
|
||||
"""Abstract method to handle a `CommandStartedEvent`.
|
||||
|
||||
:Parameters:
|
||||
@ -212,7 +224,7 @@ class CommandListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def succeeded(self, event):
|
||||
def succeeded(self, event: "CommandSucceededEvent") -> None:
|
||||
"""Abstract method to handle a `CommandSucceededEvent`.
|
||||
|
||||
:Parameters:
|
||||
@ -220,7 +232,7 @@ class CommandListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def failed(self, event):
|
||||
def failed(self, event: "CommandFailedEvent") -> None:
|
||||
"""Abstract method to handle a `CommandFailedEvent`.
|
||||
|
||||
:Parameters:
|
||||
@ -245,7 +257,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
.. versionadded:: 3.9
|
||||
"""
|
||||
|
||||
def pool_created(self, event):
|
||||
def pool_created(self, event: "PoolCreatedEvent") -> None:
|
||||
"""Abstract method to handle a :class:`PoolCreatedEvent`.
|
||||
|
||||
Emitted when a Connection Pool is created.
|
||||
@ -255,7 +267,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def pool_ready(self, event):
|
||||
def pool_ready(self, event: "PoolReadyEvent") -> None:
|
||||
"""Abstract method to handle a :class:`PoolReadyEvent`.
|
||||
|
||||
Emitted when a Connection Pool is marked ready.
|
||||
@ -267,7 +279,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def pool_cleared(self, event):
|
||||
def pool_cleared(self, event: "PoolClearedEvent") -> None:
|
||||
"""Abstract method to handle a `PoolClearedEvent`.
|
||||
|
||||
Emitted when a Connection Pool is cleared.
|
||||
@ -277,7 +289,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def pool_closed(self, event):
|
||||
def pool_closed(self, event: "PoolClosedEvent") -> None:
|
||||
"""Abstract method to handle a `PoolClosedEvent`.
|
||||
|
||||
Emitted when a Connection Pool is closed.
|
||||
@ -287,7 +299,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def connection_created(self, event):
|
||||
def connection_created(self, event: "ConnectionCreatedEvent") -> None:
|
||||
"""Abstract method to handle a :class:`ConnectionCreatedEvent`.
|
||||
|
||||
Emitted when a Connection Pool creates a Connection object.
|
||||
@ -297,7 +309,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def connection_ready(self, event):
|
||||
def connection_ready(self, event: "ConnectionReadyEvent") -> None:
|
||||
"""Abstract method to handle a :class:`ConnectionReadyEvent`.
|
||||
|
||||
Emitted when a Connection has finished its setup, and is now ready to
|
||||
@ -308,7 +320,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def connection_closed(self, event):
|
||||
def connection_closed(self, event: "ConnectionClosedEvent") -> None:
|
||||
"""Abstract method to handle a :class:`ConnectionClosedEvent`.
|
||||
|
||||
Emitted when a Connection Pool closes a Connection.
|
||||
@ -318,7 +330,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def connection_check_out_started(self, event):
|
||||
def connection_check_out_started(self, event: "ConnectionCheckOutStartedEvent") -> None:
|
||||
"""Abstract method to handle a :class:`ConnectionCheckOutStartedEvent`.
|
||||
|
||||
Emitted when the driver starts attempting to check out a connection.
|
||||
@ -328,7 +340,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def connection_check_out_failed(self, event):
|
||||
def connection_check_out_failed(self, event: "ConnectionCheckOutFailedEvent") -> None:
|
||||
"""Abstract method to handle a :class:`ConnectionCheckOutFailedEvent`.
|
||||
|
||||
Emitted when the driver's attempt to check out a connection fails.
|
||||
@ -338,7 +350,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def connection_checked_out(self, event):
|
||||
def connection_checked_out(self, event: "ConnectionCheckedOutEvent") -> None:
|
||||
"""Abstract method to handle a :class:`ConnectionCheckedOutEvent`.
|
||||
|
||||
Emitted when the driver successfully checks out a Connection.
|
||||
@ -348,7 +360,7 @@ class ConnectionPoolListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def connection_checked_in(self, event):
|
||||
def connection_checked_in(self, event: "ConnectionCheckedInEvent") -> None:
|
||||
"""Abstract method to handle a :class:`ConnectionCheckedInEvent`.
|
||||
|
||||
Emitted when the driver checks in a Connection back to the Connection
|
||||
@ -369,7 +381,7 @@ class ServerHeartbeatListener(_EventListener):
|
||||
.. versionadded:: 3.3
|
||||
"""
|
||||
|
||||
def started(self, event):
|
||||
def started(self, event: "ServerHeartbeatStartedEvent") -> None:
|
||||
"""Abstract method to handle a `ServerHeartbeatStartedEvent`.
|
||||
|
||||
:Parameters:
|
||||
@ -377,7 +389,7 @@ class ServerHeartbeatListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def succeeded(self, event):
|
||||
def succeeded(self, event: "ServerHeartbeatSucceededEvent") -> None:
|
||||
"""Abstract method to handle a `ServerHeartbeatSucceededEvent`.
|
||||
|
||||
:Parameters:
|
||||
@ -385,7 +397,7 @@ class ServerHeartbeatListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def failed(self, event):
|
||||
def failed(self, event: "ServerHeartbeatFailedEvent") -> None:
|
||||
"""Abstract method to handle a `ServerHeartbeatFailedEvent`.
|
||||
|
||||
:Parameters:
|
||||
@ -402,7 +414,7 @@ class TopologyListener(_EventListener):
|
||||
.. versionadded:: 3.3
|
||||
"""
|
||||
|
||||
def opened(self, event):
|
||||
def opened(self, event: "TopologyOpenedEvent") -> None:
|
||||
"""Abstract method to handle a `TopologyOpenedEvent`.
|
||||
|
||||
:Parameters:
|
||||
@ -410,7 +422,7 @@ class TopologyListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def description_changed(self, event):
|
||||
def description_changed(self, event: "TopologyDescriptionChangedEvent") -> None:
|
||||
"""Abstract method to handle a `TopologyDescriptionChangedEvent`.
|
||||
|
||||
:Parameters:
|
||||
@ -418,7 +430,7 @@ class TopologyListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def closed(self, event):
|
||||
def closed(self, event: "TopologyClosedEvent") -> None:
|
||||
"""Abstract method to handle a `TopologyClosedEvent`.
|
||||
|
||||
:Parameters:
|
||||
@ -435,7 +447,7 @@ class ServerListener(_EventListener):
|
||||
.. versionadded:: 3.3
|
||||
"""
|
||||
|
||||
def opened(self, event):
|
||||
def opened(self, event: "ServerOpeningEvent") -> None:
|
||||
"""Abstract method to handle a `ServerOpeningEvent`.
|
||||
|
||||
:Parameters:
|
||||
@ -443,7 +455,7 @@ class ServerListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def description_changed(self, event):
|
||||
def description_changed(self, event: "ServerDescriptionChangedEvent") -> None:
|
||||
"""Abstract method to handle a `ServerDescriptionChangedEvent`.
|
||||
|
||||
:Parameters:
|
||||
@ -451,7 +463,7 @@ class ServerListener(_EventListener):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def closed(self, event):
|
||||
def closed(self, event: "ServerClosedEvent") -> None:
|
||||
"""Abstract method to handle a `ServerClosedEvent`.
|
||||
|
||||
:Parameters:
|
||||
@ -478,7 +490,7 @@ def _validate_event_listeners(option, listeners):
|
||||
return listeners
|
||||
|
||||
|
||||
def register(listener):
|
||||
def register(listener: _EventListener) -> None:
|
||||
"""Register a global event listener.
|
||||
|
||||
:Parameters:
|
||||
@ -525,8 +537,14 @@ class _CommandEvent(object):
|
||||
__slots__ = ("__cmd_name", "__rqst_id", "__conn_id", "__op_id",
|
||||
"__service_id")
|
||||
|
||||
def __init__(self, command_name, request_id, connection_id, operation_id,
|
||||
service_id=None):
|
||||
def __init__(
|
||||
self,
|
||||
command_name: str,
|
||||
request_id: int,
|
||||
connection_id: _Address,
|
||||
operation_id: Optional[int],
|
||||
service_id: Optional[ObjectId] = None,
|
||||
) -> None:
|
||||
self.__cmd_name = command_name
|
||||
self.__rqst_id = request_id
|
||||
self.__conn_id = connection_id
|
||||
@ -534,22 +552,22 @@ class _CommandEvent(object):
|
||||
self.__service_id = service_id
|
||||
|
||||
@property
|
||||
def command_name(self):
|
||||
def command_name(self) -> str:
|
||||
"""The command name."""
|
||||
return self.__cmd_name
|
||||
|
||||
@property
|
||||
def request_id(self):
|
||||
def request_id(self) -> int:
|
||||
"""The request id for this operation."""
|
||||
return self.__rqst_id
|
||||
|
||||
@property
|
||||
def connection_id(self):
|
||||
def connection_id(self) -> _Address:
|
||||
"""The address (host, port) of the server this command was sent to."""
|
||||
return self.__conn_id
|
||||
|
||||
@property
|
||||
def service_id(self):
|
||||
def service_id(self) -> Optional[ObjectId]:
|
||||
"""The service_id this command was sent to, or ``None``.
|
||||
|
||||
.. versionadded:: 3.12
|
||||
@ -557,7 +575,7 @@ class _CommandEvent(object):
|
||||
return self.__service_id
|
||||
|
||||
@property
|
||||
def operation_id(self):
|
||||
def operation_id(self) -> Optional[int]:
|
||||
"""An id for this series of events or None."""
|
||||
return self.__op_id
|
||||
|
||||
@ -576,28 +594,36 @@ class CommandStartedEvent(_CommandEvent):
|
||||
"""
|
||||
__slots__ = ("__cmd", "__db")
|
||||
|
||||
def __init__(self, command, database_name, *args, service_id=None):
|
||||
def __init__(
|
||||
self,
|
||||
command: _DocumentOut,
|
||||
database_name: str,
|
||||
request_id: int,
|
||||
connection_id: _Address,
|
||||
operation_id: Optional[int],
|
||||
service_id: Optional[ObjectId] = None,
|
||||
) -> None:
|
||||
if not command:
|
||||
raise ValueError("%r is not a valid command" % (command,))
|
||||
# Command name must be first key.
|
||||
command_name = next(iter(command))
|
||||
super(CommandStartedEvent, self).__init__(
|
||||
command_name, *args, service_id=service_id)
|
||||
command_name, request_id, connection_id, operation_id, service_id=service_id)
|
||||
cmd_name, cmd_doc = command_name.lower(), command[command_name]
|
||||
if (cmd_name in _SENSITIVE_COMMANDS or
|
||||
_is_speculative_authenticate(cmd_name, command)):
|
||||
self.__cmd = {}
|
||||
self.__cmd: Mapping[str, Any] = {}
|
||||
else:
|
||||
self.__cmd = command
|
||||
self.__db = database_name
|
||||
|
||||
@property
|
||||
def command(self):
|
||||
def command(self) -> _DocumentOut:
|
||||
"""The command document."""
|
||||
return self.__cmd
|
||||
|
||||
@property
|
||||
def database_name(self):
|
||||
def database_name(self) -> str:
|
||||
"""The name of the database this command was run against."""
|
||||
return self.__db
|
||||
|
||||
@ -625,8 +651,16 @@ class CommandSucceededEvent(_CommandEvent):
|
||||
"""
|
||||
__slots__ = ("__duration_micros", "__reply")
|
||||
|
||||
def __init__(self, duration, reply, command_name,
|
||||
request_id, connection_id, operation_id, service_id=None):
|
||||
def __init__(
|
||||
self,
|
||||
duration: datetime.timedelta,
|
||||
reply: _DocumentOut,
|
||||
command_name: str,
|
||||
request_id: int,
|
||||
connection_id: _Address,
|
||||
operation_id: Optional[int],
|
||||
service_id: Optional[ObjectId] = None,
|
||||
) -> None:
|
||||
super(CommandSucceededEvent, self).__init__(
|
||||
command_name, request_id, connection_id, operation_id,
|
||||
service_id=service_id)
|
||||
@ -634,17 +668,17 @@ class CommandSucceededEvent(_CommandEvent):
|
||||
cmd_name = command_name.lower()
|
||||
if (cmd_name in _SENSITIVE_COMMANDS or
|
||||
_is_speculative_authenticate(cmd_name, reply)):
|
||||
self.__reply = {}
|
||||
self.__reply: Mapping[str, Any] = {}
|
||||
else:
|
||||
self.__reply = reply
|
||||
|
||||
@property
|
||||
def duration_micros(self):
|
||||
def duration_micros(self) -> int:
|
||||
"""The duration of this operation in microseconds."""
|
||||
return self.__duration_micros
|
||||
|
||||
@property
|
||||
def reply(self):
|
||||
def reply(self) -> _DocumentOut:
|
||||
"""The server failure document for this operation."""
|
||||
return self.__reply
|
||||
|
||||
@ -672,18 +706,27 @@ class CommandFailedEvent(_CommandEvent):
|
||||
"""
|
||||
__slots__ = ("__duration_micros", "__failure")
|
||||
|
||||
def __init__(self, duration, failure, *args, service_id=None):
|
||||
super(CommandFailedEvent, self).__init__(*args, service_id=service_id)
|
||||
def __init__(
|
||||
self,
|
||||
duration: datetime.timedelta,
|
||||
failure: _DocumentOut,
|
||||
command_name: str,
|
||||
request_id: int,
|
||||
connection_id: _Address,
|
||||
operation_id: Optional[int],
|
||||
service_id: Optional[ObjectId] = None,
|
||||
) -> None:
|
||||
super(CommandFailedEvent, self).__init__(command_name, request_id, connection_id, operation_id, service_id=service_id)
|
||||
self.__duration_micros = _to_micros(duration)
|
||||
self.__failure = failure
|
||||
|
||||
@property
|
||||
def duration_micros(self):
|
||||
def duration_micros(self) -> int:
|
||||
"""The duration of this operation in microseconds."""
|
||||
return self.__duration_micros
|
||||
|
||||
@property
|
||||
def failure(self):
|
||||
def failure(self) -> _DocumentOut:
|
||||
"""The server failure document for this operation."""
|
||||
return self.__failure
|
||||
|
||||
@ -700,11 +743,11 @@ class _PoolEvent(object):
|
||||
"""Base class for pool events."""
|
||||
__slots__ = ("__address",)
|
||||
|
||||
def __init__(self, address):
|
||||
def __init__(self, address: _Address) -> None:
|
||||
self.__address = address
|
||||
|
||||
@property
|
||||
def address(self):
|
||||
def address(self) -> _Address:
|
||||
"""The address (host, port) pair of the server the pool is attempting
|
||||
to connect to.
|
||||
"""
|
||||
@ -725,12 +768,12 @@ class PoolCreatedEvent(_PoolEvent):
|
||||
"""
|
||||
__slots__ = ("__options",)
|
||||
|
||||
def __init__(self, address, options):
|
||||
def __init__(self, address: _Address, options: Dict[str, Any]) -> None:
|
||||
super(PoolCreatedEvent, self).__init__(address)
|
||||
self.__options = options
|
||||
|
||||
@property
|
||||
def options(self):
|
||||
def options(self) -> Dict[str, Any]:
|
||||
"""Any non-default pool options that were set on this Connection Pool.
|
||||
"""
|
||||
return self.__options
|
||||
@ -764,12 +807,12 @@ class PoolClearedEvent(_PoolEvent):
|
||||
"""
|
||||
__slots__ = ("__service_id",)
|
||||
|
||||
def __init__(self, address, service_id=None):
|
||||
def __init__(self, address: _Address, service_id: Optional[ObjectId] = None) -> None:
|
||||
super(PoolClearedEvent, self).__init__(address)
|
||||
self.__service_id = service_id
|
||||
|
||||
@property
|
||||
def service_id(self):
|
||||
def service_id(self) -> Optional[ObjectId]:
|
||||
"""Connections with this service_id are cleared.
|
||||
|
||||
When service_id is ``None``, all connections in the pool are cleared.
|
||||
@ -839,19 +882,19 @@ class _ConnectionEvent(object):
|
||||
"""Private base class for some connection events."""
|
||||
__slots__ = ("__address", "__connection_id")
|
||||
|
||||
def __init__(self, address, connection_id):
|
||||
def __init__(self, address: _Address, connection_id: int) -> None:
|
||||
self.__address = address
|
||||
self.__connection_id = connection_id
|
||||
|
||||
@property
|
||||
def address(self):
|
||||
def address(self) -> _Address:
|
||||
"""The address (host, port) pair of the server this connection is
|
||||
attempting to connect to.
|
||||
"""
|
||||
return self.__address
|
||||
|
||||
@property
|
||||
def connection_id(self):
|
||||
def connection_id(self) -> int:
|
||||
"""The ID of the Connection."""
|
||||
return self.__connection_id
|
||||
|
||||
@ -958,19 +1001,19 @@ class ConnectionCheckOutFailedEvent(object):
|
||||
"""
|
||||
__slots__ = ("__address", "__reason")
|
||||
|
||||
def __init__(self, address, reason):
|
||||
def __init__(self, address: _Address, reason: str) -> None:
|
||||
self.__address = address
|
||||
self.__reason = reason
|
||||
|
||||
@property
|
||||
def address(self):
|
||||
def address(self) -> _Address:
|
||||
"""The address (host, port) pair of the server this connection is
|
||||
attempting to connect to.
|
||||
"""
|
||||
return self.__address
|
||||
|
||||
@property
|
||||
def reason(self):
|
||||
def reason(self) -> str:
|
||||
"""A reason explaining why connection check out failed.
|
||||
|
||||
The reason must be one of the strings from the
|
||||
@ -1014,17 +1057,17 @@ class _ServerEvent(object):
|
||||
|
||||
__slots__ = ("__server_address", "__topology_id")
|
||||
|
||||
def __init__(self, server_address, topology_id):
|
||||
def __init__(self, server_address: _Address, topology_id: ObjectId) -> None:
|
||||
self.__server_address = server_address
|
||||
self.__topology_id = topology_id
|
||||
|
||||
@property
|
||||
def server_address(self):
|
||||
def server_address(self) -> _Address:
|
||||
"""The address (host, port) pair of the server"""
|
||||
return self.__server_address
|
||||
|
||||
@property
|
||||
def topology_id(self):
|
||||
def topology_id(self) -> ObjectId:
|
||||
"""A unique identifier for the topology this server is a part of."""
|
||||
return self.__topology_id
|
||||
|
||||
@ -1041,19 +1084,19 @@ class ServerDescriptionChangedEvent(_ServerEvent):
|
||||
|
||||
__slots__ = ('__previous_description', '__new_description')
|
||||
|
||||
def __init__(self, previous_description, new_description, *args):
|
||||
def __init__(self, previous_description: "ServerDescription", new_description: "ServerDescription", *args: Any) -> None:
|
||||
super(ServerDescriptionChangedEvent, self).__init__(*args)
|
||||
self.__previous_description = previous_description
|
||||
self.__new_description = new_description
|
||||
|
||||
@property
|
||||
def previous_description(self):
|
||||
def previous_description(self) -> "ServerDescription":
|
||||
"""The previous
|
||||
:class:`~pymongo.server_description.ServerDescription`."""
|
||||
return self.__previous_description
|
||||
|
||||
@property
|
||||
def new_description(self):
|
||||
def new_description(self) -> "ServerDescription":
|
||||
"""The new
|
||||
:class:`~pymongo.server_description.ServerDescription`."""
|
||||
return self.__new_description
|
||||
@ -1087,11 +1130,11 @@ class TopologyEvent(object):
|
||||
|
||||
__slots__ = ('__topology_id')
|
||||
|
||||
def __init__(self, topology_id):
|
||||
def __init__(self, topology_id: ObjectId) -> None:
|
||||
self.__topology_id = topology_id
|
||||
|
||||
@property
|
||||
def topology_id(self):
|
||||
def topology_id(self) -> ObjectId:
|
||||
"""A unique identifier for the topology this server is a part of."""
|
||||
return self.__topology_id
|
||||
|
||||
@ -1108,19 +1151,19 @@ class TopologyDescriptionChangedEvent(TopologyEvent):
|
||||
|
||||
__slots__ = ('__previous_description', '__new_description')
|
||||
|
||||
def __init__(self, previous_description, new_description, *args):
|
||||
def __init__(self, previous_description: "TopologyDescription", new_description: "TopologyDescription", *args: Any) -> None:
|
||||
super(TopologyDescriptionChangedEvent, self).__init__(*args)
|
||||
self.__previous_description = previous_description
|
||||
self.__new_description = new_description
|
||||
|
||||
@property
|
||||
def previous_description(self):
|
||||
def previous_description(self) -> "TopologyDescription":
|
||||
"""The previous
|
||||
:class:`~pymongo.topology_description.TopologyDescription`."""
|
||||
return self.__previous_description
|
||||
|
||||
@property
|
||||
def new_description(self):
|
||||
def new_description(self) -> "TopologyDescription":
|
||||
"""The new
|
||||
:class:`~pymongo.topology_description.TopologyDescription`."""
|
||||
return self.__new_description
|
||||
@ -1154,11 +1197,11 @@ class _ServerHeartbeatEvent(object):
|
||||
|
||||
__slots__ = ('__connection_id')
|
||||
|
||||
def __init__(self, connection_id):
|
||||
def __init__(self, connection_id: _Address) -> None:
|
||||
self.__connection_id = connection_id
|
||||
|
||||
@property
|
||||
def connection_id(self):
|
||||
def connection_id(self) -> _Address:
|
||||
"""The address (host, port) of the server this heartbeat was sent
|
||||
to."""
|
||||
return self.__connection_id
|
||||
@ -1184,24 +1227,24 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent):
|
||||
|
||||
__slots__ = ('__duration', '__reply', '__awaited')
|
||||
|
||||
def __init__(self, duration, reply, connection_id, awaited=False):
|
||||
def __init__(self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False) -> None:
|
||||
super(ServerHeartbeatSucceededEvent, self).__init__(connection_id)
|
||||
self.__duration = duration
|
||||
self.__reply = reply
|
||||
self.__awaited = awaited
|
||||
|
||||
@property
|
||||
def duration(self):
|
||||
def duration(self) -> float:
|
||||
"""The duration of this heartbeat in microseconds."""
|
||||
return self.__duration
|
||||
|
||||
@property
|
||||
def reply(self):
|
||||
def reply(self) -> Hello:
|
||||
"""An instance of :class:`~pymongo.hello.Hello`."""
|
||||
return self.__reply
|
||||
|
||||
@property
|
||||
def awaited(self):
|
||||
def awaited(self) -> bool:
|
||||
"""Whether the heartbeat was awaited.
|
||||
|
||||
If true, then :meth:`duration` reflects the sum of the round trip time
|
||||
@ -1225,24 +1268,24 @@ class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent):
|
||||
|
||||
__slots__ = ('__duration', '__reply', '__awaited')
|
||||
|
||||
def __init__(self, duration, reply, connection_id, awaited=False):
|
||||
def __init__(self, duration: float, reply: Exception, connection_id: _Address, awaited: bool = False) -> None:
|
||||
super(ServerHeartbeatFailedEvent, self).__init__(connection_id)
|
||||
self.__duration = duration
|
||||
self.__reply = reply
|
||||
self.__awaited = awaited
|
||||
|
||||
@property
|
||||
def duration(self):
|
||||
def duration(self) -> float:
|
||||
"""The duration of this heartbeat in microseconds."""
|
||||
return self.__duration
|
||||
|
||||
@property
|
||||
def reply(self):
|
||||
def reply(self) -> Exception:
|
||||
"""A subclass of :exc:`Exception`."""
|
||||
return self.__reply
|
||||
|
||||
@property
|
||||
def awaited(self):
|
||||
def awaited(self) -> bool:
|
||||
"""Whether the heartbeat was awaited.
|
||||
|
||||
If true, then :meth:`duration` reflects the sum of the round trip time
|
||||
|
||||
@ -20,21 +20,16 @@ import socket
|
||||
import struct
|
||||
import time
|
||||
|
||||
|
||||
from bson import _decode_all_selective
|
||||
|
||||
from pymongo import helpers, message
|
||||
from pymongo.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.compression_support import decompress, _NO_COMPRESSION
|
||||
from pymongo.errors import (NotPrimaryError,
|
||||
OperationFailure,
|
||||
ProtocolError,
|
||||
from pymongo.compression_support import _NO_COMPRESSION, decompress
|
||||
from pymongo.errors import (NotPrimaryError, OperationFailure, ProtocolError,
|
||||
_OperationCancelled)
|
||||
from pymongo.message import _UNPACK_REPLY, _OpMsg
|
||||
from pymongo.monitoring import _is_speculative_authenticate
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
|
||||
|
||||
_UNPACK_HEADER = struct.Struct("<iiii").unpack
|
||||
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ from threading import Lock
|
||||
|
||||
class _OCSPCache(object):
|
||||
"""A cache for OCSP responses."""
|
||||
CACHE_KEY_TYPE = namedtuple('OcspResponseCacheKey',
|
||||
CACHE_KEY_TYPE = namedtuple('OcspResponseCacheKey', # type: ignore
|
||||
['hash_algorithm', 'issuer_name_hash',
|
||||
'issuer_key_hash', 'serial_number'])
|
||||
|
||||
|
||||
@ -16,41 +16,40 @@
|
||||
|
||||
import logging as _logging
|
||||
import re as _re
|
||||
|
||||
from datetime import datetime as _datetime
|
||||
|
||||
from cryptography.exceptions import InvalidSignature as _InvalidSignature
|
||||
from cryptography.hazmat.backends import default_backend as _default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric.dsa import (
|
||||
DSAPublicKey as _DSAPublicKey)
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||
ECDSA as _ECDSA,
|
||||
EllipticCurvePublicKey as _EllipticCurvePublicKey)
|
||||
from cryptography.hazmat.primitives.asymmetric.padding import (
|
||||
PKCS1v15 as _PKCS1v15)
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import (
|
||||
RSAPublicKey as _RSAPublicKey)
|
||||
from cryptography.hazmat.primitives.hashes import (
|
||||
Hash as _Hash,
|
||||
SHA1 as _SHA1)
|
||||
from cryptography.hazmat.primitives.serialization import (
|
||||
Encoding as _Encoding,
|
||||
PublicFormat as _PublicFormat)
|
||||
from cryptography.x509 import (
|
||||
AuthorityInformationAccess as _AuthorityInformationAccess,
|
||||
ExtendedKeyUsage as _ExtendedKeyUsage,
|
||||
ExtensionNotFound as _ExtensionNotFound,
|
||||
load_pem_x509_certificate as _load_pem_x509_certificate,
|
||||
TLSFeature as _TLSFeature,
|
||||
TLSFeatureType as _TLSFeatureType)
|
||||
from cryptography.x509.oid import (
|
||||
AuthorityInformationAccessOID as _AuthorityInformationAccessOID,
|
||||
ExtendedKeyUsageOID as _ExtendedKeyUsageOID)
|
||||
from cryptography.x509.ocsp import (
|
||||
load_der_ocsp_response as _load_der_ocsp_response,
|
||||
OCSPCertStatus as _OCSPCertStatus,
|
||||
OCSPRequestBuilder as _OCSPRequestBuilder,
|
||||
OCSPResponseStatus as _OCSPResponseStatus)
|
||||
from cryptography.hazmat.primitives.asymmetric.dsa import \
|
||||
DSAPublicKey as _DSAPublicKey
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA as _ECDSA
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import \
|
||||
EllipticCurvePublicKey as _EllipticCurvePublicKey
|
||||
from cryptography.hazmat.primitives.asymmetric.padding import \
|
||||
PKCS1v15 as _PKCS1v15
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import \
|
||||
RSAPublicKey as _RSAPublicKey
|
||||
from cryptography.hazmat.primitives.hashes import SHA1 as _SHA1
|
||||
from cryptography.hazmat.primitives.hashes import Hash as _Hash
|
||||
from cryptography.hazmat.primitives.serialization import Encoding as _Encoding
|
||||
from cryptography.hazmat.primitives.serialization import \
|
||||
PublicFormat as _PublicFormat
|
||||
from cryptography.x509 import \
|
||||
AuthorityInformationAccess as _AuthorityInformationAccess
|
||||
from cryptography.x509 import ExtendedKeyUsage as _ExtendedKeyUsage
|
||||
from cryptography.x509 import ExtensionNotFound as _ExtensionNotFound
|
||||
from cryptography.x509 import TLSFeature as _TLSFeature
|
||||
from cryptography.x509 import TLSFeatureType as _TLSFeatureType
|
||||
from cryptography.x509 import \
|
||||
load_pem_x509_certificate as _load_pem_x509_certificate
|
||||
from cryptography.x509.ocsp import OCSPCertStatus as _OCSPCertStatus
|
||||
from cryptography.x509.ocsp import OCSPRequestBuilder as _OCSPRequestBuilder
|
||||
from cryptography.x509.ocsp import OCSPResponseStatus as _OCSPResponseStatus
|
||||
from cryptography.x509.ocsp import \
|
||||
load_der_ocsp_response as _load_der_ocsp_response
|
||||
from cryptography.x509.oid import \
|
||||
AuthorityInformationAccessOID as _AuthorityInformationAccessOID
|
||||
from cryptography.x509.oid import ExtendedKeyUsageOID as _ExtendedKeyUsageOID
|
||||
from requests import post as _post
|
||||
from requests.exceptions import RequestException as _RequestException
|
||||
|
||||
|
||||
@ -13,11 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Operation class definitions."""
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
|
||||
|
||||
from pymongo import helpers
|
||||
from pymongo.common import validate_boolean, validate_is_mapping, validate_list
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.common import validate_boolean, validate_is_mapping, validate_list
|
||||
from pymongo.helpers import _gen_index_name, _index_document, _index_list
|
||||
from pymongo.typings import _CollationIn, _DocumentIn, _Pipeline
|
||||
|
||||
|
||||
class InsertOne(object):
|
||||
@ -25,7 +27,7 @@ class InsertOne(object):
|
||||
|
||||
__slots__ = ("_doc",)
|
||||
|
||||
def __init__(self, document):
|
||||
def __init__(self, document: _DocumentIn) -> None:
|
||||
"""Create an InsertOne instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
|
||||
@ -43,21 +45,25 @@ class InsertOne(object):
|
||||
def __repr__(self):
|
||||
return "InsertOne(%r)" % (self._doc,)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) == type(self):
|
||||
return other._doc == self._doc
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
_IndexList = Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]]
|
||||
_IndexKeyHint = Union[str, _IndexList]
|
||||
|
||||
|
||||
class DeleteOne(object):
|
||||
"""Represents a delete_one operation."""
|
||||
|
||||
__slots__ = ("_filter", "_collation", "_hint")
|
||||
|
||||
def __init__(self, filter, collation=None, hint=None):
|
||||
def __init__(self, filter: Mapping[str, Any], collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None) -> None:
|
||||
"""Create a DeleteOne instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
|
||||
@ -95,13 +101,13 @@ class DeleteOne(object):
|
||||
def __repr__(self):
|
||||
return "DeleteOne(%r, %r)" % (self._filter, self._collation)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) == type(self):
|
||||
return ((other._filter, other._collation) ==
|
||||
(self._filter, self._collation))
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
@ -110,7 +116,7 @@ class DeleteMany(object):
|
||||
|
||||
__slots__ = ("_filter", "_collation", "_hint")
|
||||
|
||||
def __init__(self, filter, collation=None, hint=None):
|
||||
def __init__(self, filter: Mapping[str, Any], collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None) -> None:
|
||||
"""Create a DeleteMany instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
|
||||
@ -148,13 +154,13 @@ class DeleteMany(object):
|
||||
def __repr__(self):
|
||||
return "DeleteMany(%r, %r)" % (self._filter, self._collation)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) == type(self):
|
||||
return ((other._filter, other._collation) ==
|
||||
(self._filter, self._collation))
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
@ -163,8 +169,8 @@ class ReplaceOne(object):
|
||||
|
||||
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint")
|
||||
|
||||
def __init__(self, filter, replacement, upsert=False, collation=None,
|
||||
hint=None):
|
||||
def __init__(self, filter: Mapping[str, Any], replacement: Mapping[str, Any], upsert: bool = False, collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None) -> None:
|
||||
"""Create a ReplaceOne instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
|
||||
@ -207,7 +213,7 @@ class ReplaceOne(object):
|
||||
bulkobj.add_replace(self._filter, self._doc, self._upsert,
|
||||
collation=self._collation, hint=self._hint)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if type(other) == type(self):
|
||||
return (
|
||||
(other._filter, other._doc, other._upsert, other._collation,
|
||||
@ -215,7 +221,7 @@ class ReplaceOne(object):
|
||||
self._collation, other._hint))
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __repr__(self):
|
||||
@ -241,7 +247,6 @@ class _UpdateOp(object):
|
||||
if not isinstance(hint, str):
|
||||
hint = helpers._index_document(hint)
|
||||
|
||||
|
||||
self._filter = filter
|
||||
self._doc = doc
|
||||
self._upsert = upsert
|
||||
@ -272,8 +277,8 @@ class UpdateOne(_UpdateOp):
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, filter, update, upsert=False, collation=None,
|
||||
array_filters=None, hint=None):
|
||||
def __init__(self, filter: Mapping[str, Any], update: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[List[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None) -> None:
|
||||
"""Represents an update_one operation.
|
||||
|
||||
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
|
||||
@ -319,8 +324,8 @@ class UpdateMany(_UpdateOp):
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, filter, update, upsert=False, collation=None,
|
||||
array_filters=None, hint=None):
|
||||
def __init__(self, filter: Mapping[str, Any], update: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[List[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None) -> None:
|
||||
"""Create an UpdateMany instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
|
||||
@ -366,7 +371,7 @@ class IndexModel(object):
|
||||
|
||||
__slots__ = ("__document",)
|
||||
|
||||
def __init__(self, keys, **kwargs):
|
||||
def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None:
|
||||
"""Create an Index instance.
|
||||
|
||||
For use with :meth:`~pymongo.collection.Collection.create_indexes`.
|
||||
@ -437,7 +442,7 @@ class IndexModel(object):
|
||||
self.__document['collation'] = collation
|
||||
|
||||
@property
|
||||
def document(self):
|
||||
def document(self) -> Dict[str, Any]:
|
||||
"""An index document suitable for passing to the createIndexes
|
||||
command.
|
||||
"""
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class PeriodicExecutor(object):
|
||||
@ -41,7 +42,7 @@ class PeriodicExecutor(object):
|
||||
self._min_interval = min_interval
|
||||
self._target = target
|
||||
self._stopped = False
|
||||
self._thread = None
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._name = name
|
||||
self._skip_sleep = False
|
||||
|
||||
@ -52,7 +53,7 @@ class PeriodicExecutor(object):
|
||||
return '<%s(name=%s) object at 0x%x>' % (
|
||||
self.__class__.__name__, self._name, id(self))
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
"""Start. Multiple calls have no effect.
|
||||
|
||||
Not safe to call from multiple threads at once.
|
||||
@ -64,13 +65,14 @@ class PeriodicExecutor(object):
|
||||
# join should not block indefinitely because there is no
|
||||
# other work done outside the while loop in self._run.
|
||||
try:
|
||||
assert self._thread is not None
|
||||
self._thread.join()
|
||||
except ReferenceError:
|
||||
# Thread terminated.
|
||||
pass
|
||||
self._thread_will_exit = False
|
||||
self._stopped = False
|
||||
started = False
|
||||
started: Any = False
|
||||
try:
|
||||
started = self._thread and self._thread.is_alive()
|
||||
except ReferenceError:
|
||||
@ -84,7 +86,7 @@ class PeriodicExecutor(object):
|
||||
_register_executor(self)
|
||||
thread.start()
|
||||
|
||||
def close(self, dummy=None):
|
||||
def close(self, dummy: Any = None) -> None:
|
||||
"""Stop. To restart, call open().
|
||||
|
||||
The dummy parameter allows an executor's close method to be a weakref
|
||||
@ -92,7 +94,7 @@ class PeriodicExecutor(object):
|
||||
"""
|
||||
self._stopped = True
|
||||
|
||||
def join(self, timeout=None):
|
||||
def join(self, timeout: Optional[int] = None) -> None:
|
||||
if self._thread is not None:
|
||||
try:
|
||||
self._thread.join(timeout)
|
||||
@ -100,14 +102,14 @@ class PeriodicExecutor(object):
|
||||
# Thread already terminated, or not yet started.
|
||||
pass
|
||||
|
||||
def wake(self):
|
||||
def wake(self) -> None:
|
||||
"""Execute the target function soon."""
|
||||
self._event = True
|
||||
|
||||
def update_interval(self, new_interval):
|
||||
def update_interval(self, new_interval: int) -> None:
|
||||
self._interval = new_interval
|
||||
|
||||
def skip_sleep(self):
|
||||
def skip_sleep(self) -> None:
|
||||
self._skip_sleep = True
|
||||
|
||||
def __should_stop(self):
|
||||
|
||||
@ -18,50 +18,38 @@ import copy
|
||||
import ipaddress
|
||||
import os
|
||||
import platform
|
||||
import ssl
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any
|
||||
|
||||
from bson import DEFAULT_CODEC_OPTIONS
|
||||
from bson.son import SON
|
||||
from pymongo import auth, helpers, __version__
|
||||
from pymongo import __version__, auth, helpers
|
||||
from pymongo.client_session import _validate_session_write_concern
|
||||
from pymongo.common import (MAX_BSON_SIZE,
|
||||
MAX_CONNECTING,
|
||||
MAX_IDLE_TIME_SEC,
|
||||
MAX_MESSAGE_SIZE,
|
||||
MAX_POOL_SIZE,
|
||||
MAX_WIRE_VERSION,
|
||||
MAX_WRITE_BATCH_SIZE,
|
||||
MIN_POOL_SIZE,
|
||||
ORDERED_TYPES,
|
||||
from pymongo.common import (MAX_BSON_SIZE, MAX_CONNECTING, MAX_IDLE_TIME_SEC,
|
||||
MAX_MESSAGE_SIZE, MAX_POOL_SIZE, MAX_WIRE_VERSION,
|
||||
MAX_WRITE_BATCH_SIZE, MIN_POOL_SIZE, ORDERED_TYPES,
|
||||
WAIT_QUEUE_TIMEOUT)
|
||||
from pymongo.errors import (AutoReconnect,
|
||||
_CertificateError,
|
||||
ConnectionFailure,
|
||||
ConfigurationError,
|
||||
InvalidOperation,
|
||||
DocumentTooLarge,
|
||||
NetworkTimeout,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
PyMongoError)
|
||||
from pymongo.hello import HelloCompat, Hello
|
||||
from pymongo.errors import (AutoReconnect, ConfigurationError,
|
||||
ConnectionFailure, DocumentTooLarge,
|
||||
InvalidOperation, NetworkTimeout, NotPrimaryError,
|
||||
OperationFailure, PyMongoError, _CertificateError)
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.monitoring import (ConnectionCheckOutFailedReason,
|
||||
ConnectionClosedReason)
|
||||
from pymongo.network import (command,
|
||||
receive_message)
|
||||
from pymongo.network import command, receive_message
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.server_api import _add_to_command
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.socket_checker import SocketChecker
|
||||
from pymongo.ssl_support import (
|
||||
SSLError as _SSLError,
|
||||
HAS_SNI as _HAVE_SNI,
|
||||
IPADDR_SAFE as _IPADDR_SAFE)
|
||||
from pymongo.ssl_support import HAS_SNI as _HAVE_SNI
|
||||
from pymongo.ssl_support import IPADDR_SAFE as _IPADDR_SAFE
|
||||
from pymongo.ssl_support import SSLError as _SSLError
|
||||
|
||||
|
||||
# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
|
||||
# not permitted for SNI hostname.
|
||||
@ -73,7 +61,7 @@ def is_ip_address(address):
|
||||
return False
|
||||
|
||||
try:
|
||||
from fcntl import fcntl, F_GETFD, F_SETFD, FD_CLOEXEC
|
||||
from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl
|
||||
def _set_non_inheritable_non_atomic(fd):
|
||||
"""Set the close-on-exec flag on the given file descriptor."""
|
||||
flags = fcntl(fd, F_GETFD)
|
||||
@ -82,7 +70,7 @@ except ImportError:
|
||||
# Windows, various platforms we don't claim to support
|
||||
# (Jython, IronPython, ...), systems that don't provide
|
||||
# everything we need from fcntl, etc.
|
||||
def _set_non_inheritable_non_atomic(dummy):
|
||||
def _set_non_inheritable_non_atomic(fd):
|
||||
"""Dummy function for platforms that don't provide fcntl."""
|
||||
pass
|
||||
|
||||
@ -145,7 +133,7 @@ else:
|
||||
_set_tcp_option(sock, 'TCP_KEEPINTVL', _MAX_TCP_KEEPINTVL)
|
||||
_set_tcp_option(sock, 'TCP_KEEPCNT', _MAX_TCP_KEEPCNT)
|
||||
|
||||
_METADATA = SON([
|
||||
_METADATA: SON[str, Any] = SON([
|
||||
('driver', SON([('name', 'PyMongo'), ('version', __version__)])),
|
||||
])
|
||||
|
||||
@ -205,7 +193,7 @@ else:
|
||||
if platform.python_implementation().startswith('PyPy'):
|
||||
_METADATA['platform'] = ' '.join(
|
||||
(platform.python_implementation(),
|
||||
'.'.join(map(str, sys.pypy_version_info)),
|
||||
'.'.join(map(str, sys.pypy_version_info)), # type: ignore
|
||||
'(Python %s)' % '.'.join(map(str, sys.version_info))))
|
||||
elif sys.platform.startswith('java'):
|
||||
_METADATA['platform'] = ' '.join(
|
||||
@ -688,7 +676,7 @@ class SocketInfo(object):
|
||||
session = _validate_session_write_concern(session, write_concern)
|
||||
|
||||
# Ensure command name remains in first place.
|
||||
if not isinstance(spec, ORDERED_TYPES):
|
||||
if not isinstance(spec, ORDERED_TYPES): # type:ignore[arg-type]
|
||||
spec = SON(spec)
|
||||
|
||||
if not (write_concern is None or write_concern.acknowledged or
|
||||
@ -1088,7 +1076,7 @@ class Pool:
|
||||
# LIFO pool. Sockets are ordered on idle time. Sockets claimed
|
||||
# and returned to pool from the left side. Stale sockets removed
|
||||
# from the right side.
|
||||
self.sockets = collections.deque()
|
||||
self.sockets: collections.deque = collections.deque()
|
||||
self.lock = threading.Lock()
|
||||
self.active_sockets = 0
|
||||
# Monotonically increasing connection ID required for CMAP Events.
|
||||
@ -1165,8 +1153,8 @@ class Pool:
|
||||
if service_id is None:
|
||||
sockets, self.sockets = self.sockets, collections.deque()
|
||||
else:
|
||||
discard = collections.deque()
|
||||
keep = collections.deque()
|
||||
discard: collections.deque = collections.deque()
|
||||
keep: collections.deque = collections.deque()
|
||||
for sock_info in self.sockets:
|
||||
if sock_info.service_id == service_id:
|
||||
discard.append(sock_info)
|
||||
|
||||
@ -20,29 +20,28 @@ import socket as _socket
|
||||
import ssl as _stdlibssl
|
||||
import sys as _sys
|
||||
import time as _time
|
||||
|
||||
from errno import EINTR as _EINTR
|
||||
|
||||
from ipaddress import ip_address as _ip_address
|
||||
|
||||
from cryptography.x509 import load_der_x509_certificate as _load_der_x509_certificate
|
||||
from OpenSSL import crypto as _crypto, SSL as _SSL
|
||||
from service_identity.pyopenssl import (
|
||||
verify_hostname as _verify_hostname,
|
||||
verify_ip_address as _verify_ip_address)
|
||||
from cryptography.x509 import \
|
||||
load_der_x509_certificate as _load_der_x509_certificate
|
||||
from OpenSSL import SSL as _SSL
|
||||
from OpenSSL import crypto as _crypto
|
||||
from service_identity import (
|
||||
CertificateError as _SICertificateError,
|
||||
VerificationError as _SIVerificationError)
|
||||
CertificateError as _SICertificateError
|
||||
)
|
||||
from service_identity import VerificationError as _SIVerificationError
|
||||
from service_identity.pyopenssl import ( #
|
||||
verify_hostname as _verify_hostname
|
||||
)
|
||||
from service_identity.pyopenssl import verify_ip_address as _verify_ip_address
|
||||
|
||||
from pymongo.errors import (
|
||||
_CertificateError,
|
||||
ConfigurationError as _ConfigurationError)
|
||||
from pymongo.ocsp_support import (
|
||||
_load_trusted_ca_certs,
|
||||
_ocsp_callback)
|
||||
from pymongo.errors import ConfigurationError as _ConfigurationError
|
||||
from pymongo.errors import _CertificateError
|
||||
from pymongo.ocsp_cache import _OCSPCache
|
||||
from pymongo.socket_checker import (
|
||||
_errno_from_exception, SocketChecker as _SocketChecker)
|
||||
from pymongo.ocsp_support import _load_trusted_ca_certs, _ocsp_callback
|
||||
from pymongo.socket_checker import SocketChecker as _SocketChecker
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
|
||||
try:
|
||||
import certifi
|
||||
@ -132,7 +131,7 @@ class _sslConn(_SSL.Connection):
|
||||
|
||||
def recv_into(self, *args, **kwargs):
|
||||
try:
|
||||
return self._call(super(_sslConn, self).recv_into, *args, **kwargs)
|
||||
return self._call(super(_sslConn, self).recv_into, *args, **kwargs) # type: ignore
|
||||
except _SSL.SysCallError as exc:
|
||||
# Suppress ragged EOFs to match the stdlib.
|
||||
if self.suppress_ragged_eofs and _ragged_eof(exc):
|
||||
@ -147,7 +146,7 @@ class _sslConn(_SSL.Connection):
|
||||
while total_sent < total_length:
|
||||
try:
|
||||
sent = self._call(
|
||||
super(_sslConn, self).send, view[total_sent:], flags)
|
||||
super(_sslConn, self).send, view[total_sent:], flags) # type: ignore
|
||||
# XXX: It's not clear if this can actually happen. PyOpenSSL
|
||||
# doesn't appear to have any interrupt handling, nor any interrupt
|
||||
# errors for OpenSSL connections.
|
||||
@ -296,7 +295,7 @@ class SSLContext(object):
|
||||
"""Attempt to load CA certs from Windows trust store."""
|
||||
cert_store = self._ctx.get_cert_store()
|
||||
oid = _stdlibssl.Purpose.SERVER_AUTH.oid
|
||||
for cert, encoding, trust in _stdlibssl.enum_certificates(store):
|
||||
for cert, encoding, trust in _stdlibssl.enum_certificates(store): # type: ignore
|
||||
if encoding == "x509_asn":
|
||||
if trust is True or oid in trust:
|
||||
cert_store.add_cert(
|
||||
|
||||
@ -14,6 +14,8 @@
|
||||
|
||||
"""Tools for working with read concerns."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class ReadConcern(object):
|
||||
"""ReadConcern
|
||||
@ -29,7 +31,7 @@ class ReadConcern(object):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, level=None):
|
||||
def __init__(self, level: Optional[str] = None) -> None:
|
||||
if level is None or isinstance(level, str):
|
||||
self.__level = level
|
||||
else:
|
||||
@ -37,18 +39,18 @@ class ReadConcern(object):
|
||||
'level must be a string or None.')
|
||||
|
||||
@property
|
||||
def level(self):
|
||||
def level(self) -> Optional[str]:
|
||||
"""The read concern level."""
|
||||
return self.__level
|
||||
|
||||
@property
|
||||
def ok_for_legacy(self):
|
||||
def ok_for_legacy(self) -> bool:
|
||||
"""Return ``True`` if this read concern is compatible with
|
||||
old wire protocol versions."""
|
||||
return self.level is None or self.level == 'local'
|
||||
|
||||
@property
|
||||
def document(self):
|
||||
def document(self) -> Dict[str, Any]:
|
||||
"""The document representation of this read concern.
|
||||
|
||||
.. note::
|
||||
@ -60,7 +62,7 @@ class ReadConcern(object):
|
||||
doc['level'] = self.level
|
||||
return doc
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, ReadConcern):
|
||||
return self.document == other.document
|
||||
return NotImplemented
|
||||
|
||||
@ -15,13 +15,13 @@
|
||||
"""Utilities for choosing which member of a replica set to read from."""
|
||||
|
||||
from collections import abc
|
||||
from typing import Any, Dict, Mapping, Optional, Sequence
|
||||
|
||||
from pymongo import max_staleness_selectors
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.server_selectors import (member_with_tags_server_selector,
|
||||
secondary_with_tags_server_selector)
|
||||
|
||||
|
||||
_PRIMARY = 0
|
||||
_PRIMARY_PREFERRED = 1
|
||||
_SECONDARY = 2
|
||||
@ -44,9 +44,9 @@ def _validate_tag_sets(tag_sets):
|
||||
if tag_sets is None:
|
||||
return tag_sets
|
||||
|
||||
if not isinstance(tag_sets, list):
|
||||
if not isinstance(tag_sets, (list, tuple)):
|
||||
raise TypeError((
|
||||
"Tag sets %r invalid, must be a list") % (tag_sets,))
|
||||
"Tag sets %r invalid, must be a sequence") % (tag_sets,))
|
||||
if len(tag_sets) == 0:
|
||||
raise ValueError((
|
||||
"Tag sets %r invalid, must be None or contain at least one set of"
|
||||
@ -59,7 +59,7 @@ def _validate_tag_sets(tag_sets):
|
||||
"bson.son.SON or other type that inherits from "
|
||||
"collection.Mapping" % (tags,))
|
||||
|
||||
return tag_sets
|
||||
return list(tag_sets)
|
||||
|
||||
|
||||
def _invalid_max_staleness_msg(max_staleness):
|
||||
@ -93,6 +93,10 @@ def _validate_hedge(hedge):
|
||||
return hedge
|
||||
|
||||
|
||||
_Hedge = Mapping[str, Any]
|
||||
_TagSets = Sequence[Mapping[str, Any]]
|
||||
|
||||
|
||||
class _ServerMode(object):
|
||||
"""Base class for all read preferences.
|
||||
"""
|
||||
@ -100,7 +104,7 @@ class _ServerMode(object):
|
||||
__slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness",
|
||||
"__hedge")
|
||||
|
||||
def __init__(self, mode, tag_sets=None, max_staleness=-1, hedge=None):
|
||||
def __init__(self, mode: int, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None:
|
||||
self.__mongos_mode = _MONGOS_MODES[mode]
|
||||
self.__mode = mode
|
||||
self.__tag_sets = _validate_tag_sets(tag_sets)
|
||||
@ -108,22 +112,22 @@ class _ServerMode(object):
|
||||
self.__hedge = _validate_hedge(hedge)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
"""The name of this read preference.
|
||||
"""
|
||||
return self.__class__.__name__
|
||||
|
||||
@property
|
||||
def mongos_mode(self):
|
||||
def mongos_mode(self) -> str:
|
||||
"""The mongos mode of this read preference.
|
||||
"""
|
||||
return self.__mongos_mode
|
||||
|
||||
@property
|
||||
def document(self):
|
||||
def document(self) -> Dict[str, Any]:
|
||||
"""Read preference as a document.
|
||||
"""
|
||||
doc = {'mode': self.__mongos_mode}
|
||||
doc: Dict[str, Any] = {'mode': self.__mongos_mode}
|
||||
if self.__tag_sets not in (None, [{}]):
|
||||
doc['tags'] = self.__tag_sets
|
||||
if self.__max_staleness != -1:
|
||||
@ -133,13 +137,13 @@ class _ServerMode(object):
|
||||
return doc
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
def mode(self) -> int:
|
||||
"""The mode of this read preference instance.
|
||||
"""
|
||||
return self.__mode
|
||||
|
||||
@property
|
||||
def tag_sets(self):
|
||||
def tag_sets(self) -> _TagSets:
|
||||
"""Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to
|
||||
read only from members whose ``dc`` tag has the value ``"ny"``.
|
||||
To specify a priority-order for tag sets, provide a list of
|
||||
@ -154,14 +158,14 @@ class _ServerMode(object):
|
||||
return list(self.__tag_sets) if self.__tag_sets else [{}]
|
||||
|
||||
@property
|
||||
def max_staleness(self):
|
||||
def max_staleness(self) -> int:
|
||||
"""The maximum estimated length of time (in seconds) a replica set
|
||||
secondary can fall behind the primary in replication before it will
|
||||
no longer be selected for operations, or -1 for no maximum."""
|
||||
return self.__max_staleness
|
||||
|
||||
@property
|
||||
def hedge(self):
|
||||
def hedge(self) -> Optional[_Hedge]:
|
||||
"""The read preference ``hedge`` parameter.
|
||||
|
||||
A dictionary that configures how the server will perform hedged reads.
|
||||
@ -185,7 +189,7 @@ class _ServerMode(object):
|
||||
return self.__hedge
|
||||
|
||||
@property
|
||||
def min_wire_version(self):
|
||||
def min_wire_version(self) -> int:
|
||||
"""The wire protocol version the server must support.
|
||||
|
||||
Some read preferences impose version requirements on all servers (e.g.
|
||||
@ -201,7 +205,7 @@ class _ServerMode(object):
|
||||
return "%s(tag_sets=%r, max_staleness=%r, hedge=%r)" % (
|
||||
self.name, self.__tag_sets, self.__max_staleness, self.__hedge)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, _ServerMode):
|
||||
return (self.mode == other.mode and
|
||||
self.tag_sets == other.tag_sets and
|
||||
@ -209,7 +213,7 @@ class _ServerMode(object):
|
||||
self.hedge == other.hedge)
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __getstate__(self):
|
||||
@ -243,17 +247,17 @@ class Primary(_ServerMode):
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super(Primary, self).__init__(_PRIMARY)
|
||||
|
||||
def __call__(self, selection):
|
||||
def __call__(self, selection: Any) -> Any:
|
||||
"""Apply this read preference to a Selection."""
|
||||
return selection.primary_selection
|
||||
|
||||
def __repr__(self):
|
||||
return "Primary()"
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, _ServerMode):
|
||||
return other.mode == _PRIMARY
|
||||
return NotImplemented
|
||||
@ -289,11 +293,11 @@ class PrimaryPreferred(_ServerMode):
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, tag_sets=None, max_staleness=-1, hedge=None):
|
||||
def __init__(self, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None:
|
||||
super(PrimaryPreferred, self).__init__(
|
||||
_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge)
|
||||
|
||||
def __call__(self, selection):
|
||||
def __call__(self, selection: Any) -> Any:
|
||||
"""Apply this read preference to Selection."""
|
||||
if selection.primary:
|
||||
return selection.primary_selection
|
||||
@ -329,11 +333,11 @@ class Secondary(_ServerMode):
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, tag_sets=None, max_staleness=-1, hedge=None):
|
||||
def __init__(self, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None:
|
||||
super(Secondary, self).__init__(
|
||||
_SECONDARY, tag_sets, max_staleness, hedge)
|
||||
|
||||
def __call__(self, selection):
|
||||
def __call__(self, selection: Any) -> Any:
|
||||
"""Apply this read preference to Selection."""
|
||||
return secondary_with_tags_server_selector(
|
||||
self.tag_sets,
|
||||
@ -370,11 +374,11 @@ class SecondaryPreferred(_ServerMode):
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, tag_sets=None, max_staleness=-1, hedge=None):
|
||||
def __init__(self, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None:
|
||||
super(SecondaryPreferred, self).__init__(
|
||||
_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge)
|
||||
|
||||
def __call__(self, selection):
|
||||
def __call__(self, selection: Any) -> Any:
|
||||
"""Apply this read preference to Selection."""
|
||||
secondaries = secondary_with_tags_server_selector(
|
||||
self.tag_sets,
|
||||
@ -412,11 +416,11 @@ class Nearest(_ServerMode):
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, tag_sets=None, max_staleness=-1, hedge=None):
|
||||
def __init__(self, tag_sets: Optional[_TagSets] = None, max_staleness: int = -1, hedge: Optional[_Hedge] = None) -> None:
|
||||
super(Nearest, self).__init__(
|
||||
_NEAREST, tag_sets, max_staleness, hedge)
|
||||
|
||||
def __call__(self, selection):
|
||||
def __call__(self, selection: Any) -> Any:
|
||||
"""Apply this read preference to Selection."""
|
||||
return member_with_tags_server_selector(
|
||||
self.tag_sets,
|
||||
@ -467,7 +471,7 @@ _ALL_READ_PREFERENCES = (Primary, PrimaryPreferred,
|
||||
Secondary, SecondaryPreferred, Nearest)
|
||||
|
||||
|
||||
def make_read_preference(mode, tag_sets, max_staleness=-1):
|
||||
def make_read_preference(mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1) -> _ServerMode:
|
||||
if mode == _PRIMARY:
|
||||
if tag_sets not in (None, [{}]):
|
||||
raise ConfigurationError("Read preference primary "
|
||||
@ -476,7 +480,7 @@ def make_read_preference(mode, tag_sets, max_staleness=-1):
|
||||
raise ConfigurationError("Read preference primary cannot be "
|
||||
"combined with maxStalenessSeconds")
|
||||
return Primary()
|
||||
return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness)
|
||||
return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore
|
||||
|
||||
|
||||
_MODES = (
|
||||
@ -545,7 +549,7 @@ class ReadPreference(object):
|
||||
NEAREST = Nearest()
|
||||
|
||||
|
||||
def read_pref_mode_from_name(name):
|
||||
def read_pref_mode_from_name(name: str) -> int:
|
||||
"""Get the read preference mode from mongos/uri name.
|
||||
"""
|
||||
return _MONGOS_MODES.index(name)
|
||||
@ -553,10 +557,12 @@ def read_pref_mode_from_name(name):
|
||||
|
||||
class MovingAverage(object):
|
||||
"""Tracks an exponentially-weighted moving average."""
|
||||
def __init__(self):
|
||||
average: Optional[float]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.average = None
|
||||
|
||||
def add_sample(self, sample):
|
||||
def add_sample(self, sample: float) -> None:
|
||||
if sample < 0:
|
||||
# Likely system time change while waiting for hello response
|
||||
# and not using time.monotonic. Ignore it, the next one will
|
||||
@ -569,9 +575,9 @@ class MovingAverage(object):
|
||||
# average with alpha = 0.2.
|
||||
self.average = 0.8 * self.average + 0.2 * sample
|
||||
|
||||
def get(self):
|
||||
def get(self) -> Optional[float]:
|
||||
"""Get the calculated average, or None if no samples yet."""
|
||||
return self.average
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self.average = None
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Result class definitions."""
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence, cast
|
||||
|
||||
from pymongo.errors import InvalidOperation
|
||||
|
||||
@ -22,7 +23,7 @@ class _WriteResult(object):
|
||||
|
||||
__slots__ = ("__acknowledged",)
|
||||
|
||||
def __init__(self, acknowledged):
|
||||
def __init__(self, acknowledged: bool) -> None:
|
||||
self.__acknowledged = acknowledged
|
||||
|
||||
def _raise_if_unacknowledged(self, property_name):
|
||||
@ -34,7 +35,7 @@ class _WriteResult(object):
|
||||
"error." % (property_name,))
|
||||
|
||||
@property
|
||||
def acknowledged(self):
|
||||
def acknowledged(self) -> bool:
|
||||
"""Is this the result of an acknowledged write operation?
|
||||
|
||||
The :attr:`acknowledged` attribute will be ``False`` when using
|
||||
@ -59,12 +60,12 @@ class InsertOneResult(_WriteResult):
|
||||
|
||||
__slots__ = ("__inserted_id", "__acknowledged")
|
||||
|
||||
def __init__(self, inserted_id, acknowledged):
|
||||
def __init__(self, inserted_id: Any, acknowledged: bool) -> None:
|
||||
self.__inserted_id = inserted_id
|
||||
super(InsertOneResult, self).__init__(acknowledged)
|
||||
|
||||
@property
|
||||
def inserted_id(self):
|
||||
def inserted_id(self) -> Any:
|
||||
"""The inserted document's _id."""
|
||||
return self.__inserted_id
|
||||
|
||||
@ -75,12 +76,12 @@ class InsertManyResult(_WriteResult):
|
||||
|
||||
__slots__ = ("__inserted_ids", "__acknowledged")
|
||||
|
||||
def __init__(self, inserted_ids, acknowledged):
|
||||
def __init__(self, inserted_ids: List[Any], acknowledged: bool) -> None:
|
||||
self.__inserted_ids = inserted_ids
|
||||
super(InsertManyResult, self).__init__(acknowledged)
|
||||
|
||||
@property
|
||||
def inserted_ids(self):
|
||||
def inserted_ids(self) -> List:
|
||||
"""A list of _ids of the inserted documents, in the order provided.
|
||||
|
||||
.. note:: If ``False`` is passed for the `ordered` parameter to
|
||||
@ -99,17 +100,17 @@ class UpdateResult(_WriteResult):
|
||||
|
||||
__slots__ = ("__raw_result", "__acknowledged")
|
||||
|
||||
def __init__(self, raw_result, acknowledged):
|
||||
def __init__(self, raw_result: Dict[str, Any], acknowledged: bool) -> None:
|
||||
self.__raw_result = raw_result
|
||||
super(UpdateResult, self).__init__(acknowledged)
|
||||
|
||||
@property
|
||||
def raw_result(self):
|
||||
def raw_result(self) -> Dict[str, Any]:
|
||||
"""The raw result document returned by the server."""
|
||||
return self.__raw_result
|
||||
|
||||
@property
|
||||
def matched_count(self):
|
||||
def matched_count(self) -> int:
|
||||
"""The number of documents matched for this update."""
|
||||
self._raise_if_unacknowledged("matched_count")
|
||||
if self.upserted_id is not None:
|
||||
@ -117,13 +118,13 @@ class UpdateResult(_WriteResult):
|
||||
return self.__raw_result.get("n", 0)
|
||||
|
||||
@property
|
||||
def modified_count(self):
|
||||
def modified_count(self) -> int:
|
||||
"""The number of documents modified. """
|
||||
self._raise_if_unacknowledged("modified_count")
|
||||
return self.__raw_result.get("nModified")
|
||||
return cast(int, self.__raw_result.get("nModified"))
|
||||
|
||||
@property
|
||||
def upserted_id(self):
|
||||
def upserted_id(self) -> Any:
|
||||
"""The _id of the inserted document if an upsert took place. Otherwise
|
||||
``None``.
|
||||
"""
|
||||
@ -137,17 +138,17 @@ class DeleteResult(_WriteResult):
|
||||
|
||||
__slots__ = ("__raw_result", "__acknowledged")
|
||||
|
||||
def __init__(self, raw_result, acknowledged):
|
||||
def __init__(self, raw_result: Dict[str, Any], acknowledged: bool) -> None:
|
||||
self.__raw_result = raw_result
|
||||
super(DeleteResult, self).__init__(acknowledged)
|
||||
|
||||
@property
|
||||
def raw_result(self):
|
||||
def raw_result(self) -> Dict[str, Any]:
|
||||
"""The raw result document returned by the server."""
|
||||
return self.__raw_result
|
||||
|
||||
@property
|
||||
def deleted_count(self):
|
||||
def deleted_count(self) -> int:
|
||||
"""The number of documents deleted."""
|
||||
self._raise_if_unacknowledged("deleted_count")
|
||||
return self.__raw_result.get("n", 0)
|
||||
@ -158,7 +159,7 @@ class BulkWriteResult(_WriteResult):
|
||||
|
||||
__slots__ = ("__bulk_api_result", "__acknowledged")
|
||||
|
||||
def __init__(self, bulk_api_result, acknowledged):
|
||||
def __init__(self, bulk_api_result: Dict[str, Any], acknowledged: bool) -> None:
|
||||
"""Create a BulkWriteResult instance.
|
||||
|
||||
:Parameters:
|
||||
@ -171,44 +172,45 @@ class BulkWriteResult(_WriteResult):
|
||||
super(BulkWriteResult, self).__init__(acknowledged)
|
||||
|
||||
@property
|
||||
def bulk_api_result(self):
|
||||
def bulk_api_result(self) -> Dict[str, Any]:
|
||||
"""The raw bulk API result."""
|
||||
return self.__bulk_api_result
|
||||
|
||||
@property
|
||||
def inserted_count(self):
|
||||
def inserted_count(self) -> int:
|
||||
"""The number of documents inserted."""
|
||||
self._raise_if_unacknowledged("inserted_count")
|
||||
return self.__bulk_api_result.get("nInserted")
|
||||
return cast(int, self.__bulk_api_result.get("nInserted"))
|
||||
|
||||
@property
|
||||
def matched_count(self):
|
||||
def matched_count(self) -> int:
|
||||
"""The number of documents matched for an update."""
|
||||
self._raise_if_unacknowledged("matched_count")
|
||||
return self.__bulk_api_result.get("nMatched")
|
||||
return cast(int, self.__bulk_api_result.get("nMatched"))
|
||||
|
||||
@property
|
||||
def modified_count(self):
|
||||
def modified_count(self) -> int:
|
||||
"""The number of documents modified."""
|
||||
self._raise_if_unacknowledged("modified_count")
|
||||
return self.__bulk_api_result.get("nModified")
|
||||
return cast(int, self.__bulk_api_result.get("nModified"))
|
||||
|
||||
@property
|
||||
def deleted_count(self):
|
||||
def deleted_count(self) -> int:
|
||||
"""The number of documents deleted."""
|
||||
self._raise_if_unacknowledged("deleted_count")
|
||||
return self.__bulk_api_result.get("nRemoved")
|
||||
return cast(int, self.__bulk_api_result.get("nRemoved"))
|
||||
|
||||
@property
|
||||
def upserted_count(self):
|
||||
def upserted_count(self) -> int:
|
||||
"""The number of documents upserted."""
|
||||
self._raise_if_unacknowledged("upserted_count")
|
||||
return self.__bulk_api_result.get("nUpserted")
|
||||
return cast(int, self.__bulk_api_result.get("nUpserted"))
|
||||
|
||||
@property
|
||||
def upserted_ids(self):
|
||||
def upserted_ids(self) -> Optional[Dict[int, Any]]:
|
||||
"""A map of operation index to the _id of the upserted document."""
|
||||
self._raise_if_unacknowledged("upserted_ids")
|
||||
if self.__bulk_api_result:
|
||||
return dict((upsert["index"], upsert["_id"])
|
||||
for upsert in self.bulk_api_result["upserted"])
|
||||
return None
|
||||
|
||||
@ -13,13 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""An implementation of RFC4013 SASLprep."""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
try:
|
||||
import stringprep
|
||||
except ImportError:
|
||||
HAVE_STRINGPREP = False
|
||||
def saslprep(data):
|
||||
def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) -> str:
|
||||
"""SASLprep dummy"""
|
||||
if isinstance(data, str):
|
||||
raise TypeError(
|
||||
@ -29,6 +29,7 @@ except ImportError:
|
||||
else:
|
||||
HAVE_STRINGPREP = True
|
||||
import unicodedata
|
||||
|
||||
# RFC4013 section 2.3 prohibited output.
|
||||
_PROHIBITED = (
|
||||
# A strict reading of RFC 4013 requires table c12 here, but
|
||||
@ -44,7 +45,7 @@ else:
|
||||
stringprep.in_table_c8,
|
||||
stringprep.in_table_c9)
|
||||
|
||||
def saslprep(data, prohibit_unassigned_code_points=True):
|
||||
def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) -> str:
|
||||
"""An implementation of RFC4013 SASLprep.
|
||||
|
||||
:Parameters:
|
||||
@ -60,6 +61,8 @@ else:
|
||||
:Returns:
|
||||
The SASLprep'ed version of `data`.
|
||||
"""
|
||||
prohibited: Any
|
||||
|
||||
if not isinstance(data, str):
|
||||
return data
|
||||
|
||||
|
||||
@ -17,11 +17,10 @@
|
||||
from datetime import datetime
|
||||
|
||||
from bson import _decode_all_selective
|
||||
|
||||
from pymongo.errors import NotPrimaryError, OperationFailure
|
||||
from pymongo.helpers import _check_command_response
|
||||
from pymongo.message import _convert_exception, _OpMsg
|
||||
from pymongo.response import Response, PinnedResponse
|
||||
from pymongo.response import PinnedResponse, Response
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}}
|
||||
@ -59,6 +58,8 @@ class Server(object):
|
||||
Reconnect with open().
|
||||
"""
|
||||
if self._publish:
|
||||
assert self._listener is not None
|
||||
assert self._events is not None
|
||||
self._events.put((self._listener.publish_server_closed,
|
||||
(self._description.address, self._topology_id)))
|
||||
self._monitor.close()
|
||||
@ -169,6 +170,8 @@ class Server(object):
|
||||
docs = _decode_all_selective(
|
||||
decrypted, operation.codec_options, user_fields)
|
||||
|
||||
response: Response
|
||||
|
||||
if client._should_pin_cursor(operation.session) or operation.exhaust:
|
||||
sock_info.pin_cursor()
|
||||
if isinstance(reply, _OpMsg):
|
||||
|
||||
@ -15,10 +15,13 @@
|
||||
"""Represent one server the driver is connected to."""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, Mapping, Optional, Set, Tuple, cast
|
||||
|
||||
from bson import EPOCH_NAIVE
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.typings import _Address
|
||||
|
||||
|
||||
class ServerDescription(object):
|
||||
@ -41,11 +44,12 @@ class ServerDescription(object):
|
||||
'_topology_version')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
address,
|
||||
hello=None,
|
||||
round_trip_time=None,
|
||||
error=None):
|
||||
self,
|
||||
address: _Address,
|
||||
hello: Optional[Hello] = None,
|
||||
round_trip_time: Optional[float] = None,
|
||||
error: Optional[Exception] = None,
|
||||
) -> None:
|
||||
self._address = address
|
||||
if not hello:
|
||||
hello = Hello({})
|
||||
@ -72,9 +76,11 @@ class ServerDescription(object):
|
||||
self._error = error
|
||||
self._topology_version = hello.topology_version
|
||||
if error:
|
||||
if hasattr(error, 'details') and isinstance(error.details, dict):
|
||||
self._topology_version = error.details.get('topologyVersion')
|
||||
details = getattr(error, 'details', None)
|
||||
if isinstance(details, dict):
|
||||
self._topology_version = details.get('topologyVersion')
|
||||
|
||||
self._last_write_date: Optional[float]
|
||||
if hello.last_write_date:
|
||||
# Convert from datetime to seconds.
|
||||
delta = hello.last_write_date - EPOCH_NAIVE
|
||||
@ -83,17 +89,17 @@ class ServerDescription(object):
|
||||
self._last_write_date = None
|
||||
|
||||
@property
|
||||
def address(self):
|
||||
def address(self) -> _Address:
|
||||
"""The address (host, port) of this server."""
|
||||
return self._address
|
||||
|
||||
@property
|
||||
def server_type(self):
|
||||
def server_type(self) -> int:
|
||||
"""The type of this server."""
|
||||
return self._server_type
|
||||
|
||||
@property
|
||||
def server_type_name(self):
|
||||
def server_type_name(self) -> str:
|
||||
"""The server type as a human readable string.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
@ -101,78 +107,78 @@ class ServerDescription(object):
|
||||
return SERVER_TYPE._fields[self._server_type]
|
||||
|
||||
@property
|
||||
def all_hosts(self):
|
||||
def all_hosts(self) -> Set[Tuple[str, int]]:
|
||||
"""List of hosts, passives, and arbiters known to this server."""
|
||||
return self._all_hosts
|
||||
|
||||
@property
|
||||
def tags(self):
|
||||
def tags(self) -> Mapping[str, Any]:
|
||||
return self._tags
|
||||
|
||||
@property
|
||||
def replica_set_name(self):
|
||||
def replica_set_name(self) -> Optional[str]:
|
||||
"""Replica set name or None."""
|
||||
return self._replica_set_name
|
||||
|
||||
@property
|
||||
def primary(self):
|
||||
def primary(self) -> Optional[Tuple[str, int]]:
|
||||
"""This server's opinion about who the primary is, or None."""
|
||||
return self._primary
|
||||
|
||||
@property
|
||||
def max_bson_size(self):
|
||||
def max_bson_size(self) -> int:
|
||||
return self._max_bson_size
|
||||
|
||||
@property
|
||||
def max_message_size(self):
|
||||
def max_message_size(self) -> int:
|
||||
return self._max_message_size
|
||||
|
||||
@property
|
||||
def max_write_batch_size(self):
|
||||
def max_write_batch_size(self) -> int:
|
||||
return self._max_write_batch_size
|
||||
|
||||
@property
|
||||
def min_wire_version(self):
|
||||
def min_wire_version(self) -> int:
|
||||
return self._min_wire_version
|
||||
|
||||
@property
|
||||
def max_wire_version(self):
|
||||
def max_wire_version(self) -> int:
|
||||
return self._max_wire_version
|
||||
|
||||
@property
|
||||
def set_version(self):
|
||||
def set_version(self) -> Optional[int]:
|
||||
return self._set_version
|
||||
|
||||
@property
|
||||
def election_id(self):
|
||||
def election_id(self) -> Optional[ObjectId]:
|
||||
return self._election_id
|
||||
|
||||
@property
|
||||
def cluster_time(self):
|
||||
def cluster_time(self)-> Optional[Mapping[str, Any]]:
|
||||
return self._cluster_time
|
||||
|
||||
@property
|
||||
def election_tuple(self):
|
||||
def election_tuple(self) -> Tuple[Optional[int], Optional[ObjectId]]:
|
||||
return self._set_version, self._election_id
|
||||
|
||||
@property
|
||||
def me(self):
|
||||
def me(self) -> Optional[Tuple[str, int]]:
|
||||
return self._me
|
||||
|
||||
@property
|
||||
def logical_session_timeout_minutes(self):
|
||||
def logical_session_timeout_minutes(self) -> Optional[int]:
|
||||
return self._ls_timeout_minutes
|
||||
|
||||
@property
|
||||
def last_write_date(self):
|
||||
def last_write_date(self) -> Optional[float]:
|
||||
return self._last_write_date
|
||||
|
||||
@property
|
||||
def last_update_time(self):
|
||||
def last_update_time(self) -> float:
|
||||
return self._last_update_time
|
||||
|
||||
@property
|
||||
def round_trip_time(self):
|
||||
def round_trip_time(self) -> Optional[float]:
|
||||
"""The current average latency or None."""
|
||||
# This override is for unittesting only!
|
||||
if self._address in self._host_to_round_trip_time:
|
||||
@ -181,28 +187,28 @@ class ServerDescription(object):
|
||||
return self._round_trip_time
|
||||
|
||||
@property
|
||||
def error(self):
|
||||
def error(self) -> Optional[Exception]:
|
||||
"""The last error attempting to connect to the server, or None."""
|
||||
return self._error
|
||||
|
||||
@property
|
||||
def is_writable(self):
|
||||
def is_writable(self) -> bool:
|
||||
return self._is_writable
|
||||
|
||||
@property
|
||||
def is_readable(self):
|
||||
def is_readable(self) -> bool:
|
||||
return self._is_readable
|
||||
|
||||
@property
|
||||
def mongos(self):
|
||||
def mongos(self) -> bool:
|
||||
return self._server_type == SERVER_TYPE.Mongos
|
||||
|
||||
@property
|
||||
def is_server_type_known(self):
|
||||
def is_server_type_known(self) -> bool:
|
||||
return self.server_type != SERVER_TYPE.Unknown
|
||||
|
||||
@property
|
||||
def retryable_writes_supported(self):
|
||||
def retryable_writes_supported(self) -> bool:
|
||||
"""Checks if this server supports retryable writes."""
|
||||
return ((
|
||||
self._ls_timeout_minutes is not None and
|
||||
@ -210,20 +216,20 @@ class ServerDescription(object):
|
||||
or self._server_type == SERVER_TYPE.LoadBalancer)
|
||||
|
||||
@property
|
||||
def retryable_reads_supported(self):
|
||||
def retryable_reads_supported(self) -> bool:
|
||||
"""Checks if this server supports retryable writes."""
|
||||
return self._max_wire_version >= 6
|
||||
|
||||
@property
|
||||
def topology_version(self):
|
||||
def topology_version(self) -> Optional[Mapping[str, Any]]:
|
||||
return self._topology_version
|
||||
|
||||
def to_unknown(self, error=None):
|
||||
def to_unknown(self, error: Optional[Exception] = None) -> "ServerDescription":
|
||||
unknown = ServerDescription(self.address, error=error)
|
||||
unknown._topology_version = self.topology_version
|
||||
return unknown
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, ServerDescription):
|
||||
return ((self._address == other.address) and
|
||||
(self._server_type == other.server_type) and
|
||||
@ -242,7 +248,7 @@ class ServerDescription(object):
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __repr__(self):
|
||||
@ -254,4 +260,4 @@ class ServerDescription(object):
|
||||
self.round_trip_time, errmsg)
|
||||
|
||||
# For unittesting only. Use under no circumstances!
|
||||
_host_to_round_trip_time = {}
|
||||
_host_to_round_trip_time: Dict = {}
|
||||
|
||||
@ -14,10 +14,19 @@
|
||||
|
||||
"""Type codes for MongoDB servers."""
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
SERVER_TYPE = namedtuple('ServerType',
|
||||
['Unknown', 'Mongos', 'RSPrimary', 'RSSecondary',
|
||||
'RSArbiter', 'RSOther', 'RSGhost',
|
||||
'Standalone', 'LoadBalancer'])(*range(9))
|
||||
class _ServerType(NamedTuple):
|
||||
Unknown: int
|
||||
Mongos: int
|
||||
RSPrimary: int
|
||||
RSSecondary: int
|
||||
RSArbiter: int
|
||||
RSOther: int
|
||||
RSGhost: int
|
||||
Standalone: int
|
||||
LoadBalancer: int
|
||||
|
||||
|
||||
SERVER_TYPE = _ServerType(*range(9))
|
||||
|
||||
@ -16,7 +16,9 @@
|
||||
|
||||
import errno
|
||||
import select
|
||||
import socket
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
# PYTHON-2320: Jython does not fully support poll on SSL sockets,
|
||||
# https://bugs.jython.org/issue2900
|
||||
@ -34,17 +36,19 @@ def _errno_from_exception(exc):
|
||||
|
||||
class SocketChecker(object):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._poller: Optional[select.poll]
|
||||
if _HAVE_POLL:
|
||||
self._poller = select.poll()
|
||||
else:
|
||||
self._poller = None
|
||||
|
||||
def select(self, sock, read=False, write=False, timeout=0):
|
||||
def select(self, sock: Any, read: bool = False, write: bool = False, timeout: int = 0) -> bool:
|
||||
"""Select for reads or writes with a timeout in seconds (or None).
|
||||
|
||||
Returns True if the socket is readable/writable, False on timeout.
|
||||
"""
|
||||
res: Any
|
||||
while True:
|
||||
try:
|
||||
if self._poller:
|
||||
@ -74,12 +78,12 @@ class SocketChecker(object):
|
||||
# ready: subsets of the first three arguments. Return
|
||||
# True if any of the lists are not empty.
|
||||
return any(res)
|
||||
except (_SelectError, IOError) as exc:
|
||||
except (_SelectError, IOError) as exc: # type: ignore
|
||||
if _errno_from_exception(exc) in (errno.EINTR, errno.EAGAIN):
|
||||
continue
|
||||
raise
|
||||
|
||||
def socket_closed(self, sock):
|
||||
def socket_closed(self, sock: Any) -> bool:
|
||||
"""Return True if we know socket has been closed, False otherwise.
|
||||
"""
|
||||
try:
|
||||
|
||||
@ -26,6 +26,7 @@ except ImportError:
|
||||
from pymongo.common import CONNECT_TIMEOUT
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
|
||||
# dnspython can return bytes or str from various parts
|
||||
# of its API depending on version. We always want str.
|
||||
def maybe_decode(text):
|
||||
@ -38,7 +39,7 @@ def maybe_decode(text):
|
||||
def _resolve(*args, **kwargs):
|
||||
if hasattr(resolver, 'resolve'):
|
||||
# dnspython >= 2
|
||||
return resolver.resolve(*args, **kwargs)
|
||||
return resolver.resolve(*args, **kwargs) # type: ignore
|
||||
# dnspython 1.X
|
||||
return resolver.query(*args, **kwargs)
|
||||
|
||||
|
||||
@ -32,6 +32,7 @@ IS_PYOPENSSL = False
|
||||
SSLError = _ssl.SSLError
|
||||
|
||||
from ssl import SSLContext
|
||||
|
||||
if hasattr(_ssl, "VERIFY_CRL_CHECK_LEAF"):
|
||||
from ssl import VERIFY_CRL_CHECK_LEAF
|
||||
# Python 3.7 uses OpenSSL's hostname matching implementation
|
||||
|
||||
@ -24,7 +24,7 @@ try:
|
||||
import pymongo.pyopenssl_context as _ssl
|
||||
except ImportError:
|
||||
try:
|
||||
import pymongo.ssl_context as _ssl
|
||||
import pymongo.ssl_context as _ssl # type: ignore[no-redef]
|
||||
except ImportError:
|
||||
HAVE_SSL = False
|
||||
|
||||
@ -74,7 +74,7 @@ if HAVE_SSL:
|
||||
raise ConfigurationError(
|
||||
"tlsCRLFile cannot be used with PyOpenSSL")
|
||||
# Match the server's behavior.
|
||||
ctx.verify_flags = getattr(_ssl, "VERIFY_CRL_CHECK_LEAF", 0)
|
||||
setattr(ctx, 'verify_flags', getattr(_ssl, "VERIFY_CRL_CHECK_LEAF", 0))
|
||||
ctx.load_verify_locations(crlfile)
|
||||
if ca_certs is not None:
|
||||
ctx.load_verify_locations(ca_certs)
|
||||
@ -83,11 +83,11 @@ if HAVE_SSL:
|
||||
ctx.verify_mode = verify_mode
|
||||
return ctx
|
||||
else:
|
||||
class SSLError(Exception):
|
||||
class SSLError(Exception): # type: ignore
|
||||
pass
|
||||
HAS_SNI = False
|
||||
IPADDR_SAFE = False
|
||||
|
||||
def get_ssl_context(*dummy):
|
||||
def get_ssl_context(*dummy): # type: ignore
|
||||
"""No ssl module, raise ConfigurationError."""
|
||||
raise ConfigurationError("The ssl module is not available.")
|
||||
|
||||
@ -21,35 +21,27 @@ import threading
|
||||
import time
|
||||
import warnings
|
||||
import weakref
|
||||
from typing import Any
|
||||
|
||||
from pymongo import (common,
|
||||
helpers,
|
||||
periodic_executor)
|
||||
from pymongo import common, helpers, periodic_executor
|
||||
from pymongo.client_session import _ServerSessionPool
|
||||
from pymongo.errors import (ConnectionFailure,
|
||||
ConfigurationError,
|
||||
NetworkTimeout,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
ServerSelectionTimeoutError,
|
||||
WriteError,
|
||||
InvalidOperation)
|
||||
from pymongo.errors import (ConfigurationError, ConnectionFailure,
|
||||
InvalidOperation, NetworkTimeout, NotPrimaryError,
|
||||
OperationFailure, PyMongoError,
|
||||
ServerSelectionTimeoutError, WriteError)
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.monitor import SrvMonitor
|
||||
from pymongo.pool import PoolOptions
|
||||
from pymongo.server import Server
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import (any_server_selector,
|
||||
from pymongo.server_selectors import (Selection, any_server_selector,
|
||||
arbiter_server_selector,
|
||||
secondary_server_selector,
|
||||
readable_server_selector,
|
||||
writable_server_selector,
|
||||
Selection)
|
||||
from pymongo.topology_description import (updated_topology_description,
|
||||
_updated_topology_description_srv_polling,
|
||||
TopologyDescription,
|
||||
SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE)
|
||||
secondary_server_selector,
|
||||
writable_server_selector)
|
||||
from pymongo.topology_description import (
|
||||
SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE, TopologyDescription,
|
||||
_updated_topology_description_srv_polling, updated_topology_description)
|
||||
|
||||
|
||||
def process_events_queue(queue_ref):
|
||||
@ -80,12 +72,13 @@ class Topology(object):
|
||||
|
||||
# Create events queue if there are publishers.
|
||||
self._events = None
|
||||
self.__events_executor = None
|
||||
self.__events_executor: Any = None
|
||||
|
||||
if self._publish_server or self._publish_tp:
|
||||
self._events = queue.Queue(maxsize=100)
|
||||
|
||||
if self._publish_tp:
|
||||
assert self._events is not None
|
||||
self._events.put((self._listeners.publish_topology_opened,
|
||||
(self._topology_id,)))
|
||||
self._settings = topology_settings
|
||||
@ -99,6 +92,7 @@ class Topology(object):
|
||||
|
||||
self._description = topology_description
|
||||
if self._publish_tp:
|
||||
assert self._events is not None
|
||||
initial_td = TopologyDescription(TOPOLOGY_TYPE.Unknown, {}, None,
|
||||
None, None, self._settings)
|
||||
self._events.put((
|
||||
@ -107,6 +101,7 @@ class Topology(object):
|
||||
|
||||
for seed in topology_settings.seeds:
|
||||
if self._publish_server:
|
||||
assert self._events is not None
|
||||
self._events.put((self._listeners.publish_server_opened,
|
||||
(seed, self._topology_id)))
|
||||
|
||||
@ -296,6 +291,7 @@ class Topology(object):
|
||||
suppress_event = ((self._publish_server or self._publish_tp)
|
||||
and sd_old == server_description)
|
||||
if self._publish_server and not suppress_event:
|
||||
assert self._events is not None
|
||||
self._events.put((
|
||||
self._listeners.publish_server_description_changed,
|
||||
(sd_old, server_description,
|
||||
@ -306,6 +302,7 @@ class Topology(object):
|
||||
self._receive_cluster_time_no_lock(server_description.cluster_time)
|
||||
|
||||
if self._publish_tp and not suppress_event:
|
||||
assert self._events is not None
|
||||
self._events.put((
|
||||
self._listeners.publish_topology_description_changed,
|
||||
(td_old, self._description, self._topology_id)))
|
||||
@ -354,6 +351,7 @@ class Topology(object):
|
||||
self._update_servers()
|
||||
|
||||
if self._publish_tp:
|
||||
assert self._events is not None
|
||||
self._events.put((
|
||||
self._listeners.publish_topology_description_changed,
|
||||
(td_old, self._description, self._topology_id)))
|
||||
@ -485,6 +483,7 @@ class Topology(object):
|
||||
|
||||
# Publish only after releasing the lock.
|
||||
if self._publish_tp:
|
||||
assert self._events is not None
|
||||
self._events.put((self._listeners.publish_topology_closed,
|
||||
(self._topology_id,)))
|
||||
if self._publish_server or self._publish_tp:
|
||||
|
||||
@ -14,34 +14,48 @@
|
||||
|
||||
"""Represent a deployment of MongoDB servers."""
|
||||
|
||||
from collections import namedtuple
|
||||
from random import sample
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo import common
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.read_preferences import ReadPreference, _AggWritePref
|
||||
from pymongo.read_preferences import ReadPreference, _AggWritePref, _ServerMode
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import Selection
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.typings import _Address
|
||||
|
||||
|
||||
# Enumeration for various kinds of MongoDB cluster topologies.
|
||||
TOPOLOGY_TYPE = namedtuple('TopologyType', [
|
||||
'Single', 'ReplicaSetNoPrimary', 'ReplicaSetWithPrimary', 'Sharded',
|
||||
'Unknown', 'LoadBalanced'])(*range(6))
|
||||
class _TopologyType(NamedTuple):
|
||||
Single: int
|
||||
ReplicaSetNoPrimary: int
|
||||
ReplicaSetWithPrimary: int
|
||||
Sharded: int
|
||||
Unknown: int
|
||||
LoadBalanced: int
|
||||
|
||||
|
||||
TOPOLOGY_TYPE = _TopologyType(*range(6))
|
||||
|
||||
# Topologies compatible with SRV record polling.
|
||||
SRV_POLLING_TOPOLOGIES = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
|
||||
SRV_POLLING_TOPOLOGIES: Tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
|
||||
|
||||
|
||||
_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]]
|
||||
|
||||
|
||||
class TopologyDescription(object):
|
||||
def __init__(self,
|
||||
topology_type,
|
||||
server_descriptions,
|
||||
replica_set_name,
|
||||
max_set_version,
|
||||
max_election_id,
|
||||
topology_settings):
|
||||
def __init__(
|
||||
self,
|
||||
topology_type: int,
|
||||
server_descriptions: Dict[_Address, ServerDescription],
|
||||
replica_set_name: Optional[str],
|
||||
max_set_version: Optional[int],
|
||||
max_election_id: Optional[ObjectId],
|
||||
topology_settings: Any,
|
||||
) -> None:
|
||||
"""Representation of a deployment of MongoDB servers.
|
||||
|
||||
:Parameters:
|
||||
@ -81,7 +95,7 @@ class TopologyDescription(object):
|
||||
for s in readable_servers):
|
||||
self._ls_timeout_minutes = None
|
||||
else:
|
||||
self._ls_timeout_minutes = min(s.logical_session_timeout_minutes
|
||||
self._ls_timeout_minutes = min(s.logical_session_timeout_minutes # type: ignore
|
||||
for s in readable_servers)
|
||||
|
||||
def _init_incompatible_err(self):
|
||||
@ -104,23 +118,23 @@ class TopologyDescription(object):
|
||||
|
||||
if server_too_new:
|
||||
self._incompatible_err = (
|
||||
"Server at %s:%d requires wire version %d, but this "
|
||||
"Server at %s:%d requires wire version %d, but this " # type: ignore
|
||||
"version of PyMongo only supports up to %d."
|
||||
% (s.address[0], s.address[1],
|
||||
% (s.address[0], s.address[1] or 0,
|
||||
s.min_wire_version, common.MAX_SUPPORTED_WIRE_VERSION))
|
||||
|
||||
elif server_too_old:
|
||||
self._incompatible_err = (
|
||||
"Server at %s:%d reports wire version %d, but this "
|
||||
"Server at %s:%d reports wire version %d, but this " # type: ignore
|
||||
"version of PyMongo requires at least %d (MongoDB %s)."
|
||||
% (s.address[0], s.address[1],
|
||||
% (s.address[0], s.address[1] or 0,
|
||||
s.max_wire_version,
|
||||
common.MIN_SUPPORTED_WIRE_VERSION,
|
||||
common.MIN_SUPPORTED_SERVER_VERSION))
|
||||
|
||||
break
|
||||
|
||||
def check_compatible(self):
|
||||
def check_compatible(self) -> None:
|
||||
"""Raise ConfigurationError if any server is incompatible.
|
||||
|
||||
A server is incompatible if its wire protocol version range does not
|
||||
@ -129,15 +143,15 @@ class TopologyDescription(object):
|
||||
if self._incompatible_err:
|
||||
raise ConfigurationError(self._incompatible_err)
|
||||
|
||||
def has_server(self, address):
|
||||
def has_server(self, address: _Address) -> bool:
|
||||
return address in self._server_descriptions
|
||||
|
||||
def reset_server(self, address):
|
||||
def reset_server(self, address: _Address) -> "TopologyDescription":
|
||||
"""A copy of this description, with one server marked Unknown."""
|
||||
unknown_sd = self._server_descriptions[address].to_unknown()
|
||||
return updated_topology_description(self, unknown_sd)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> "TopologyDescription":
|
||||
"""A copy of this description, with all servers marked Unknown."""
|
||||
if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
|
||||
topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary
|
||||
@ -156,18 +170,18 @@ class TopologyDescription(object):
|
||||
self._max_election_id,
|
||||
self._topology_settings)
|
||||
|
||||
def server_descriptions(self):
|
||||
def server_descriptions(self) -> Dict[_Address, ServerDescription]:
|
||||
"""Dict of (address,
|
||||
:class:`~pymongo.server_description.ServerDescription`)."""
|
||||
return self._server_descriptions.copy()
|
||||
|
||||
@property
|
||||
def topology_type(self):
|
||||
def topology_type(self) -> int:
|
||||
"""The type of this topology."""
|
||||
return self._topology_type
|
||||
|
||||
@property
|
||||
def topology_type_name(self):
|
||||
def topology_type_name(self) -> str:
|
||||
"""The topology type as a human readable string.
|
||||
|
||||
.. versionadded:: 3.4
|
||||
@ -175,44 +189,44 @@ class TopologyDescription(object):
|
||||
return TOPOLOGY_TYPE._fields[self._topology_type]
|
||||
|
||||
@property
|
||||
def replica_set_name(self):
|
||||
def replica_set_name(self) -> Optional[str]:
|
||||
"""The replica set name."""
|
||||
return self._replica_set_name
|
||||
|
||||
@property
|
||||
def max_set_version(self):
|
||||
def max_set_version(self) -> Optional[int]:
|
||||
"""Greatest setVersion seen from a primary, or None."""
|
||||
return self._max_set_version
|
||||
|
||||
@property
|
||||
def max_election_id(self):
|
||||
def max_election_id(self) -> Optional[ObjectId]:
|
||||
"""Greatest electionId seen from a primary, or None."""
|
||||
return self._max_election_id
|
||||
|
||||
@property
|
||||
def logical_session_timeout_minutes(self):
|
||||
def logical_session_timeout_minutes(self) -> Optional[int]:
|
||||
"""Minimum logical session timeout, or None."""
|
||||
return self._ls_timeout_minutes
|
||||
|
||||
@property
|
||||
def known_servers(self):
|
||||
def known_servers(self) -> List[ServerDescription]:
|
||||
"""List of Servers of types besides Unknown."""
|
||||
return [s for s in self._server_descriptions.values()
|
||||
if s.is_server_type_known]
|
||||
|
||||
@property
|
||||
def has_known_servers(self):
|
||||
def has_known_servers(self) -> bool:
|
||||
"""Whether there are any Servers of types besides Unknown."""
|
||||
return any(s for s in self._server_descriptions.values()
|
||||
if s.is_server_type_known)
|
||||
|
||||
@property
|
||||
def readable_servers(self):
|
||||
def readable_servers(self) -> List[ServerDescription]:
|
||||
"""List of readable Servers."""
|
||||
return [s for s in self._server_descriptions.values() if s.is_readable]
|
||||
|
||||
@property
|
||||
def common_wire_version(self):
|
||||
def common_wire_version(self) -> Optional[int]:
|
||||
"""Minimum of all servers' max wire versions, or None."""
|
||||
servers = self.known_servers
|
||||
if servers:
|
||||
@ -221,11 +235,11 @@ class TopologyDescription(object):
|
||||
return None
|
||||
|
||||
@property
|
||||
def heartbeat_frequency(self):
|
||||
def heartbeat_frequency(self) -> int:
|
||||
return self._topology_settings.heartbeat_frequency
|
||||
|
||||
@property
|
||||
def srv_max_hosts(self):
|
||||
def srv_max_hosts(self) -> int:
|
||||
return self._topology_settings._srv_max_hosts
|
||||
|
||||
def _apply_local_threshold(self, selection):
|
||||
@ -238,7 +252,12 @@ class TopologyDescription(object):
|
||||
return [s for s in selection.server_descriptions
|
||||
if (s.round_trip_time - fastest) <= threshold]
|
||||
|
||||
def apply_selector(self, selector, address=None, custom_selector=None):
|
||||
def apply_selector(
|
||||
self,
|
||||
selector: Any,
|
||||
address: Optional[_Address] = None,
|
||||
custom_selector: Optional[_ServerSelector] = None
|
||||
) -> List[ServerDescription]:
|
||||
"""List of servers matching the provided selector(s).
|
||||
|
||||
:Parameters:
|
||||
@ -288,7 +307,7 @@ class TopologyDescription(object):
|
||||
custom_selector(selection.server_descriptions))
|
||||
return self._apply_local_threshold(selection)
|
||||
|
||||
def has_readable_server(self, read_preference=ReadPreference.PRIMARY):
|
||||
def has_readable_server(self, read_preference: _ServerMode =ReadPreference.PRIMARY) -> bool:
|
||||
"""Does this topology have any readable servers available matching the
|
||||
given read preference?
|
||||
|
||||
@ -305,7 +324,7 @@ class TopologyDescription(object):
|
||||
common.validate_read_preference("read_preference", read_preference)
|
||||
return any(self.apply_selector(read_preference))
|
||||
|
||||
def has_writable_server(self):
|
||||
def has_writable_server(self) -> bool:
|
||||
"""Does this topology have a writable server available?
|
||||
|
||||
.. note:: When connected directly to a single server this method
|
||||
@ -336,7 +355,9 @@ _SERVER_TYPE_TO_TOPOLOGY_TYPE = {
|
||||
}
|
||||
|
||||
|
||||
def updated_topology_description(topology_description, server_description):
|
||||
def updated_topology_description(
|
||||
topology_description: TopologyDescription, server_description: ServerDescription
|
||||
) -> "TopologyDescription":
|
||||
"""Return an updated copy of a TopologyDescription.
|
||||
|
||||
:Parameters:
|
||||
|
||||
29
pymongo/typings.py
Normal file
29
pymongo/typings.py
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright 2022-Present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Type aliases used by PyMongo"""
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional,
|
||||
Tuple, Type, TypeVar, Union)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo.collation import Collation
|
||||
|
||||
|
||||
# Common Shared Types.
|
||||
_Address = Tuple[str, Optional[int]]
|
||||
_CollationIn = Union[Mapping[str, Any], "Collation"]
|
||||
_DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"]
|
||||
_Pipeline = List[Mapping[str, Any]]
|
||||
_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any])
|
||||
@ -15,20 +15,19 @@
|
||||
|
||||
"""Tools to parse and validate a MongoDB URI."""
|
||||
import re
|
||||
import warnings
|
||||
import sys
|
||||
|
||||
import warnings
|
||||
from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple,
|
||||
Union, cast)
|
||||
from urllib.parse import unquote_plus
|
||||
|
||||
from pymongo.client_options import _parse_ssl_options
|
||||
from pymongo.common import (
|
||||
SRV_SERVICE_NAME,
|
||||
get_validated_options, INTERNAL_URI_OPTION_NAME_MAP,
|
||||
URI_OPTIONS_DEPRECATION_MAP, _CaseInsensitiveDictionary)
|
||||
from pymongo.common import (INTERNAL_URI_OPTION_NAME_MAP, SRV_SERVICE_NAME,
|
||||
URI_OPTIONS_DEPRECATION_MAP,
|
||||
_CaseInsensitiveDictionary, get_validated_options)
|
||||
from pymongo.errors import ConfigurationError, InvalidURI
|
||||
from pymongo.srv_resolver import _HAVE_DNSPYTHON, _SrvResolver
|
||||
|
||||
|
||||
SCHEME = 'mongodb://'
|
||||
SCHEME_LEN = len(SCHEME)
|
||||
SRV_SCHEME = 'mongodb+srv://'
|
||||
@ -52,7 +51,7 @@ def _unquoted_percent(s):
|
||||
return True
|
||||
return False
|
||||
|
||||
def parse_userinfo(userinfo):
|
||||
def parse_userinfo(userinfo: str) -> Tuple[str, str]:
|
||||
"""Validates the format of user information in a MongoDB URI.
|
||||
Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
|
||||
"]", "@") as per RFC 3986 must be escaped.
|
||||
@ -76,7 +75,7 @@ def parse_userinfo(userinfo):
|
||||
return unquote_plus(user), unquote_plus(passwd)
|
||||
|
||||
|
||||
def parse_ipv6_literal_host(entity, default_port):
|
||||
def parse_ipv6_literal_host(entity: str, default_port: Optional[int]) -> Tuple[str, Optional[Union[str, int]]]:
|
||||
"""Validates an IPv6 literal host:port string.
|
||||
|
||||
Returns a 2-tuple of IPv6 literal followed by port where
|
||||
@ -98,7 +97,7 @@ def parse_ipv6_literal_host(entity, default_port):
|
||||
return entity[1: i], entity[i + 2:]
|
||||
|
||||
|
||||
def parse_host(entity, default_port=DEFAULT_PORT):
|
||||
def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> Tuple[str, Optional[int]]:
|
||||
"""Validates a host string
|
||||
|
||||
Returns a 2-tuple of host followed by port where port is default_port
|
||||
@ -111,7 +110,7 @@ def parse_host(entity, default_port=DEFAULT_PORT):
|
||||
specified in entity.
|
||||
"""
|
||||
host = entity
|
||||
port = default_port
|
||||
port: Optional[Union[str, int]] = default_port
|
||||
if entity[0] == '[':
|
||||
host, port = parse_ipv6_literal_host(entity, default_port)
|
||||
elif entity.endswith(".sock"):
|
||||
@ -279,7 +278,7 @@ def _normalize_options(options):
|
||||
return options
|
||||
|
||||
|
||||
def validate_options(opts, warn=False):
|
||||
def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
|
||||
"""Validates and normalizes options passed in a MongoDB URI.
|
||||
|
||||
Returns a new dictionary of validated and normalized options. If warn is
|
||||
@ -295,7 +294,7 @@ def validate_options(opts, warn=False):
|
||||
return get_validated_options(opts, warn)
|
||||
|
||||
|
||||
def split_options(opts, validate=True, warn=False, normalize=True):
|
||||
def split_options(opts: str, validate: bool = True, warn: bool = False, normalize: bool = True) -> MutableMapping[str, Any]:
|
||||
"""Takes the options portion of a MongoDB URI, validates each option
|
||||
and returns the options in a dictionary.
|
||||
|
||||
@ -340,7 +339,7 @@ def split_options(opts, validate=True, warn=False, normalize=True):
|
||||
return options
|
||||
|
||||
|
||||
def split_hosts(hosts, default_port=DEFAULT_PORT):
|
||||
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> List[Tuple[str, Optional[int]]]:
|
||||
"""Takes a string of the form host1[:port],host2[:port]... and
|
||||
splits it into (host, port) tuples. If [:port] isn't present the
|
||||
default_port is used.
|
||||
@ -393,9 +392,16 @@ def _check_options(nodes, options):
|
||||
'Cannot specify replicaSet with loadBalanced=true')
|
||||
|
||||
|
||||
def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
|
||||
normalize=True, connect_timeout=None, srv_service_name=None,
|
||||
srv_max_hosts=None):
|
||||
def parse_uri(
|
||||
uri: str,
|
||||
default_port: Optional[int] = DEFAULT_PORT,
|
||||
validate: bool = True,
|
||||
warn: bool = False,
|
||||
normalize: bool = True,
|
||||
connect_timeout: Optional[float] = None,
|
||||
srv_service_name: Optional[str] = None,
|
||||
srv_max_hosts: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Parse and validate a MongoDB URI.
|
||||
|
||||
Returns a dict of the form::
|
||||
|
||||
@ -14,6 +14,8 @@
|
||||
|
||||
"""Tools for working with write concerns."""
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
|
||||
@ -45,8 +47,8 @@ class WriteConcern(object):
|
||||
|
||||
__slots__ = ("__document", "__acknowledged", "__server_default")
|
||||
|
||||
def __init__(self, w=None, wtimeout=None, j=None, fsync=None):
|
||||
self.__document = {}
|
||||
def __init__(self, w: Optional[Union[int, str]] = None, wtimeout: Optional[int] = None, j: Optional[bool] = None, fsync: Optional[bool] = None) -> None:
|
||||
self.__document: Dict[str, Any] = {}
|
||||
self.__acknowledged = True
|
||||
|
||||
if wtimeout is not None:
|
||||
@ -84,12 +86,12 @@ class WriteConcern(object):
|
||||
self.__server_default = not self.__document
|
||||
|
||||
@property
|
||||
def is_server_default(self):
|
||||
def is_server_default(self) -> bool:
|
||||
"""Does this WriteConcern match the server default."""
|
||||
return self.__server_default
|
||||
|
||||
@property
|
||||
def document(self):
|
||||
def document(self) -> Dict[str, Any]:
|
||||
"""The document representation of this write concern.
|
||||
|
||||
.. note::
|
||||
@ -99,7 +101,7 @@ class WriteConcern(object):
|
||||
return self.__document.copy()
|
||||
|
||||
@property
|
||||
def acknowledged(self):
|
||||
def acknowledged(self) -> bool:
|
||||
"""If ``True`` write operations will wait for acknowledgement before
|
||||
returning.
|
||||
"""
|
||||
@ -109,12 +111,12 @@ class WriteConcern(object):
|
||||
return ("WriteConcern(%s)" % (
|
||||
", ".join("%s=%s" % kvt for kvt in self.__document.items()),))
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, WriteConcern):
|
||||
return self.__document == other.document
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
if isinstance(other, WriteConcern):
|
||||
return self.__document != other.document
|
||||
return NotImplemented
|
||||
|
||||
@ -25,7 +25,7 @@ import warnings
|
||||
try:
|
||||
import simplejson as json
|
||||
except ImportError:
|
||||
import json # type: ignore
|
||||
import json # type: ignore[no-redef]
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
|
||||
@ -889,7 +889,7 @@ class TestCursor(IntegrationTest):
|
||||
|
||||
# Every attribute should be the same.
|
||||
cursor2 = cursor.clone()
|
||||
self.assertEqual(cursor.__dict__, cursor2.__dict__)
|
||||
self.assertDictEqual(cursor.__dict__, cursor2.__dict__)
|
||||
|
||||
# Shallow copies can so can mutate
|
||||
cursor2 = copy.copy(cursor)
|
||||
|
||||
@ -238,7 +238,7 @@ class TestGridFile(IntegrationTest):
|
||||
cursor_dict.pop('_Cursor__session')
|
||||
cursor_clone_dict = cursor_clone.__dict__.copy()
|
||||
cursor_clone_dict.pop('_Cursor__session')
|
||||
self.assertEqual(cursor_dict, cursor_clone_dict)
|
||||
self.assertDictEqual(cursor_dict, cursor_clone_dict)
|
||||
|
||||
self.assertRaises(NotImplementedError, cursor.add_option, 0)
|
||||
self.assertRaises(NotImplementedError, cursor.remove_option, 0)
|
||||
|
||||
@ -33,7 +33,7 @@ except:
|
||||
pass
|
||||
|
||||
try:
|
||||
from pymongo import _cmessage # type: ignore
|
||||
from pymongo import _cmessage # type: ignore[attr-defined]
|
||||
sys.exit("could still import _cmessage")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
Loading…
Reference in New Issue
Block a user