PYTHON-3905 Use from __future__ import annotations in all files (#1370)

* PYTHON-3905 Use from __future__ import annotations in all files

* cleanup

* cleanup

* cleanup
This commit is contained in:
Steven Silvester 2023-09-11 10:49:24 -05:00 committed by GitHub
parent b67ca68cc5
commit 6f4e617e6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 358 additions and 370 deletions

View File

@ -84,15 +84,9 @@ jobs:
- name: Install dependencies
run: |
pip install -q tox
- name: Run mypy
- name: Run typecheck
run: |
tox -m typecheck-mypy
- name: Run pyright
run: |
tox -m typecheck-pyright
- name: Run pyright strict
run: |
tox -m typecheck-pyright-strict
tox -m typecheck
docs:
name: Docs Checks

View File

@ -19,9 +19,10 @@ The :mod:`gridfs` package is an implementation of GridFS on top of
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
"""
from __future__ import annotations
from collections import abc
from typing import Any, List, Mapping, Optional, cast
from typing import Any, Mapping, Optional, cast
from bson.objectid import ObjectId
from gridfs.errors import NoFile
@ -170,7 +171,7 @@ class GridFS:
filename: Optional[str] = None,
version: Optional[int] = -1,
session: Optional[ClientSession] = None,
**kwargs: Any
**kwargs: Any,
) -> GridOut:
"""Get a file from GridFS by ``"filename"`` or metadata fields.
@ -275,7 +276,7 @@ class GridFS:
self.__files.delete_one({"_id": file_id}, session=session)
self.__chunks.delete_many({"files_id": file_id}, session=session)
def list(self, session: Optional[ClientSession] = None) -> List[str]:
def list(self, session: Optional[ClientSession] = None) -> list[str]:
"""List the names of all files stored in this instance of
:class:`GridFS`.
@ -301,7 +302,7 @@ class GridFS:
filter: Optional[Any] = None,
session: Optional[ClientSession] = None,
*args: Any,
**kwargs: Any
**kwargs: Any,
) -> Optional[GridOut]:
"""Get a single file from gridfs.
@ -400,7 +401,7 @@ class GridFS:
self,
document_or_id: Optional[Any] = None,
session: Optional[ClientSession] = None,
**kwargs: Any
**kwargs: Any,
) -> bool:
"""Check if a file exists in this instance of :class:`GridFS`.

View File

@ -13,11 +13,13 @@
# limitations under the License.
"""Tools for representing files stored in GridFS."""
from __future__ import annotations
import datetime
import io
import math
import os
from typing import Any, Iterable, List, Mapping, NoReturn, Optional
from typing import Any, Iterable, Mapping, NoReturn, Optional
from bson.binary import Binary
from bson.int64 import Int64
@ -480,7 +482,7 @@ class GridOut(io.IOBase):
upload_date: datetime.datetime = _grid_out_property(
"uploadDate", "Date that this file was first uploaded."
)
aliases: Optional[List[str]] = _grid_out_property("aliases", "List of aliases for this file.")
aliases: Optional[list[str]] = _grid_out_property("aliases", "List of aliases for this file.")
metadata: Optional[Mapping[str, Any]] = _grid_out_property(
"metadata", "Metadata attached to this file."
)

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Python driver for MongoDB."""
from __future__ import annotations
from typing import ContextManager, Optional

View File

@ -20,7 +20,7 @@ import functools
import time
from collections import deque
from contextvars import ContextVar, Token
from typing import Any, Callable, Deque, MutableMapping, Optional, Tuple, TypeVar, cast
from typing import Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast
from pymongo.write_concern import WriteConcern
@ -72,7 +72,7 @@ class _TimeoutContext:
def __init__(self, timeout: Optional[float]):
self._timeout = timeout
self._tokens: Optional[Tuple[Token, Token, Token]] = None
self._tokens: Optional[tuple[Token, Token, Token]] = None
def __enter__(self) -> _TimeoutContext:
timeout_token = TIMEOUT.set(self._timeout)

View File

@ -24,13 +24,10 @@ from itertools import islice
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
List,
Mapping,
NoReturn,
Optional,
Tuple,
Type,
Union,
)
@ -76,7 +73,7 @@ _BAD_VALUE: int = 2
_UNKNOWN_ERROR: int = 8
_WRITE_CONCERN_ERROR: int = 64
_COMMANDS: Tuple[str, str, str] = ("insert", "update", "delete")
_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete")
class _Run:
@ -85,8 +82,8 @@ class _Run:
def __init__(self, op_type: int) -> None:
"""Initialize a new Run object."""
self.op_type: int = op_type
self.index_map: List[int] = []
self.ops: List[Any] = []
self.index_map: list[int] = []
self.ops: list[Any] = []
self.idx_offset: int = 0
def index(self, idx: int) -> int:
@ -182,7 +179,7 @@ class _Bulk:
common.validate_is_document_type("let", self.let)
self.comment: Optional[str] = comment
self.ordered = ordered
self.ops: List[Tuple[int, Mapping[str, Any]]] = []
self.ops: list[tuple[int, Mapping[str, Any]]] = []
self.executed = False
self.bypass_doc_val = bypass_document_validation
self.uses_collation = False
@ -219,12 +216,12 @@ class _Bulk:
multi: bool = False,
upsert: bool = False,
collation: Optional[Mapping[str, Any]] = None,
array_filters: Optional[List[Mapping[str, Any]]] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Union[str, SON[str, Any], None] = None,
) -> None:
"""Create an update document and add it to the list of ops."""
validate_ok_for_update(update)
cmd: Dict[str, Any] = dict(
cmd: dict[str, Any] = dict(
[("q", selector), ("u", update), ("multi", multi), ("upsert", upsert)]
)
if collation is not None:
@ -414,7 +411,7 @@ class _Bulk:
generator: Iterator[Any],
write_concern: WriteConcern,
session: Optional[ClientSession],
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Execute using write commands."""
# nModified is only reported for write commands, not legacy ops.
full_result = {

View File

@ -16,17 +16,7 @@
from __future__ import annotations
import copy
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Mapping,
Optional,
Type,
Union,
)
from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
from bson import _bson_to_dict
from bson.raw_bson import RawBSONDocument
@ -173,9 +163,9 @@ class ChangeStream(Generic[_DocumentType]):
"""
raise NotImplementedError
def _change_stream_options(self) -> Dict[str, Any]:
def _change_stream_options(self) -> dict[str, Any]:
"""Return the options dict for the $changeStream pipeline stage."""
options: Dict[str, Any] = {}
options: dict[str, Any] = {}
if self._full_document is not None:
options["fullDocument"] = self._full_document
@ -197,7 +187,7 @@ class ChangeStream(Generic[_DocumentType]):
return options
def _command_options(self) -> Dict[str, Any]:
def _command_options(self) -> dict[str, Any]:
"""Return the options dict for the aggregation command."""
options = {}
if self._max_await_time_ms is not None:
@ -206,7 +196,7 @@ class ChangeStream(Generic[_DocumentType]):
options["batchSize"] = self._batch_size
return options
def _aggregation_pipeline(self) -> List[Dict[str, Any]]:
def _aggregation_pipeline(self) -> list[dict[str, Any]]:
"""Return the full aggregation pipeline for this ChangeStream."""
options = self._change_stream_options()
full_pipeline: list = [{"$changeStream": options}]
@ -491,7 +481,7 @@ class ClusterChangeStream(DatabaseChangeStream, Generic[_DocumentType]):
.. versionadded:: 3.7
"""
def _change_stream_options(self) -> Dict[str, Any]:
def _change_stream_options(self) -> dict[str, Any]:
options = super()._change_stream_options()
options["allChangesForCluster"] = True
return options

View File

@ -15,7 +15,7 @@
"""Tools to parse mongo client options."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Tuple, cast
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast
from bson.codec_options import _parse_codec_options
from pymongo import common
@ -80,7 +80,7 @@ def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
return ReadConcern(concern)
def _parse_ssl_options(options: Mapping[str, Any]) -> Tuple[Optional[SSLContext], bool]:
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
"""Parse ssl options."""
use_tls = options.get("tls")
if use_tls is not None:
@ -309,7 +309,7 @@ class ClientOptions:
return self.__load_balanced
@property
def event_listeners(self) -> List[_EventListeners]:
def event_listeners(self) -> list[_EventListeners]:
"""The event listeners registered for this client.
See :mod:`~pymongo.monitoring` for details.

View File

@ -144,8 +144,6 @@ from typing import (
Any,
Callable,
ContextManager,
Dict,
List,
Mapping,
MutableMapping,
NoReturn,
@ -832,7 +830,7 @@ class ClientSession:
self._transaction.state = _TxnState.ABORTED
self._unpin()
def _finish_transaction_with_retry(self, command_name: str) -> Dict[str, Any]:
def _finish_transaction_with_retry(self, command_name: str) -> dict[str, Any]:
"""Run commit or abort with one retry after any retryable error.
:Parameters:
@ -841,12 +839,12 @@ class ClientSession:
def func(
session: Optional[ClientSession], conn: Connection, retryable: bool
) -> Dict[str, Any]:
) -> dict[str, Any]:
return self._finish_transaction(conn, command_name)
return self._client._retry_internal(func, self, None, retryable=True)
def _finish_transaction(self, conn: Connection, command_name: str) -> Dict[str, Any]:
def _finish_transaction(self, conn: Connection, command_name: str) -> dict[str, Any]:
self._transaction.attempt += 1
opts = self._transaction.opts
assert opts
@ -1102,7 +1100,7 @@ class _ServerSessionPool(collections.deque):
self.generation += 1
self.clear()
def pop_all(self) -> List[_ServerSession]:
def pop_all(self) -> list[_ServerSession]:
ids = []
while self:
ids.append(self.pop().session_id)

View File

@ -16,7 +16,9 @@
.. _collations: https://www.mongodb.com/docs/manual/reference/collation/
"""
from typing import Any, Dict, Mapping, Optional, Union
from __future__ import annotations
from typing import Any, Mapping, Optional, Union
from pymongo import common
@ -166,7 +168,7 @@ class Collation:
**kwargs: Any,
) -> None:
locale = common.validate_string("locale", locale)
self.__document: Dict[str, Any] = {"locale": locale}
self.__document: dict[str, Any] = {"locale": locale}
if caseLevel is not None:
self.__document["caseLevel"] = common.validate_boolean("caseLevel", caseLevel)
if caseFirst is not None:
@ -190,7 +192,7 @@ class Collation:
self.__document.update(kwargs)
@property
def document(self) -> Dict[str, Any]:
def document(self) -> dict[str, Any]:
"""The document representation of this collation.
.. note::
@ -214,7 +216,7 @@ class Collation:
def validate_collation_or_none(
value: Optional[Union[Mapping[str, Any], Collation]]
) -> Optional[Dict[str, Any]]:
) -> Optional[dict[str, Any]]:
if value is None:
return None
if isinstance(value, Collation):

View File

@ -24,13 +24,11 @@ from typing import (
Generic,
Iterable,
Iterator,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
@ -258,7 +256,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _conn_for_reads(
self, session: ClientSession
) -> ContextManager[Tuple[Connection, _ServerMode]]:
) -> ContextManager[tuple[Connection, _ServerMode]]:
return self.__database.client._conn_for_reads(self._read_preference_for(session), session)
def _conn_for_writes(self, session: Optional[ClientSession]) -> ContextManager[Connection]:
@ -739,9 +737,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
or not documents
):
raise TypeError("documents must be a non-empty list")
inserted_ids: List[ObjectId] = []
inserted_ids: list[ObjectId] = []
def gen() -> Iterator[Tuple[int, Mapping[str, Any]]]:
def gen() -> Iterator[tuple[int, Mapping[str, Any]]]:
"""A generator that validates documents and handles _ids."""
for document in documents:
common.validate_is_document_type("document", document)
@ -1930,7 +1928,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
"""Create one or more indexes on this collection.
>>> from pymongo import IndexModel, ASCENDING, DESCENDING
@ -1975,7 +1973,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
@_csot.apply
def __create_indexes(
self, indexes: Sequence[IndexModel], session: Optional[ClientSession], **kwargs: Any
) -> List[str]:
) -> list[str]:
"""Internal createIndexes helper.
:Parameters:
@ -2438,11 +2436,11 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def create_search_indexes(
self,
models: List[SearchIndexModel],
models: list[SearchIndexModel],
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
"""Create multiple search indexes for the current collection.
:Parameters:
@ -2990,7 +2988,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> List:
) -> list:
"""Get a list of distinct values for `key` among all documents
in this collection.
@ -3043,7 +3041,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
server: Server,
conn: Connection,
read_preference: Optional[_ServerMode],
) -> List:
) -> list:
return self._command(
conn,
cmd,

View File

@ -21,7 +21,6 @@ from typing import (
Any,
Generic,
Iterator,
List,
Mapping,
NoReturn,
Optional,
@ -389,7 +388,7 @@ class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]):
codec_options: CodecOptions,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[Mapping[str, Any]]:
) -> list[Mapping[str, Any]]:
raw_response = response.raw_response(cursor_id, user_fields=user_fields)
if not legacy_response:
# OP_MSG returns firstBatch/nextBatch documents as a BSON array

View File

@ -24,15 +24,12 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Sequence,
Tuple,
Type,
Union,
overload,
@ -140,7 +137,7 @@ _MAX_END_SESSIONS = 10000
SRV_SERVICE_NAME = "mongodb"
def partition_node(node: str) -> Tuple[str, int]:
def partition_node(node: str) -> tuple[str, int]:
"""Split a host:port string into (host, int(port)) pair."""
host = node
port = 27017
@ -152,7 +149,7 @@ def partition_node(node: str) -> Tuple[str, int]:
return host, port
def clean_node(node: str) -> Tuple[str, int]:
def clean_node(node: str) -> tuple[str, int]:
"""Split and normalize a node name from a hello response."""
host, port = partition_node(node)
@ -394,12 +391,12 @@ def validate_uuid_representation(dummy: Any, value: Any) -> int:
)
def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]]:
def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]]:
"""Parse readPreferenceTags if passed as a client kwarg."""
if not isinstance(value, list):
value = [value]
tag_sets: List = []
tag_sets: list = []
for tag_set in value:
if tag_set == "":
tag_sets.append({})
@ -426,9 +423,9 @@ _MECHANISM_PROPS = frozenset(
)
def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Union[bool, str]]:
def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Union[bool, str]]:
"""Validate authMechanismProperties."""
props: Dict[str, Any] = {}
props: dict[str, Any] = {}
if not isinstance(value, str):
if not isinstance(value, dict):
raise ValueError("Auth mechanism properties must be given as a string or a dictionary")
@ -515,14 +512,14 @@ def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]:
return value
def validate_list(option: str, value: Any) -> List:
def validate_list(option: str, value: Any) -> list:
"""Validates that 'value' is a list."""
if not isinstance(value, list):
raise TypeError(f"{option} must be a list")
return value
def validate_list_or_none(option: Any, value: Any) -> Optional[List]:
def validate_list_or_none(option: Any, value: Any) -> Optional[list]:
"""Validates that 'value' is a list or None."""
if value is None:
return value
@ -671,7 +668,7 @@ def validate_datetime_conversion(option: Any, value: Any) -> Optional[DatetimeCo
# Dictionary where keys are the names of public URI options, and values
# are lists of aliases for that option.
URI_OPTIONS_ALIAS_MAP: Dict[str, List[str]] = {
URI_OPTIONS_ALIAS_MAP: dict[str, list[str]] = {
"tls": ["ssl"],
}
@ -679,7 +676,7 @@ URI_OPTIONS_ALIAS_MAP: Dict[str, List[str]] = {
# are functions that validate user-input values for that option. If an option
# alias uses a different validator than its public counterpart, it should be
# included here as a key, value pair.
URI_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = {
URI_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = {
"appname": validate_appname_or_none,
"authmechanism": validate_auth_mechanism,
"authmechanismproperties": validate_auth_mechanism_properties,
@ -721,7 +718,7 @@ URI_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = {
# Dictionary where keys are the names of URI options specific to pymongo,
# and values are functions that validate user-input values for those options.
NONSPEC_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = {
NONSPEC_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = {
"connect": validate_boolean_or_string,
"driver": validate_driver_or_none,
"server_api": validate_server_api_or_none,
@ -739,7 +736,7 @@ NONSPEC_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = {
# Dictionary where keys are the names of keyword-only options for the
# MongoClient constructor, and values are functions that validate user-input
# values for those options.
KW_VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = {
KW_VALIDATORS: dict[str, Callable[[Any, Any], Any]] = {
"document_class": validate_document_class,
"type_registry": validate_type_registry,
"read_preference": validate_read_preference,
@ -756,14 +753,14 @@ KW_VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = {
# internally-used names of that URI option. Options with only one name
# variant need not be included here. Options whose public and internal
# names are the same need not be included here.
INTERNAL_URI_OPTION_NAME_MAP: Dict[str, str] = {
INTERNAL_URI_OPTION_NAME_MAP: dict[str, str] = {
"ssl": "tls",
}
# Map from deprecated URI option names to a tuple indicating the method of
# their deprecation and any additional information that may be needed to
# construct the warning message.
URI_OPTIONS_DEPRECATION_MAP: Dict[str, Tuple[str, str]] = {
URI_OPTIONS_DEPRECATION_MAP: dict[str, tuple[str, str]] = {
# format: <deprecated option name>: (<mode>, <message>),
# Supported <mode> values:
# - 'renamed': <message> should be the new option name. Note that case is
@ -782,11 +779,11 @@ for optname, aliases in URI_OPTIONS_ALIAS_MAP.items():
URI_OPTIONS_VALIDATOR_MAP[alias] = URI_OPTIONS_VALIDATOR_MAP[optname]
# Map containing all URI option and keyword argument validators.
VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = URI_OPTIONS_VALIDATOR_MAP.copy()
VALIDATORS: dict[str, Callable[[Any, Any], Any]] = URI_OPTIONS_VALIDATOR_MAP.copy()
VALIDATORS.update(KW_VALIDATORS)
# List of timeout-related options.
TIMEOUT_OPTIONS: List[str] = [
TIMEOUT_OPTIONS: list[str] = [
"connecttimeoutms",
"heartbeatfrequencyms",
"maxidletimems",
@ -800,7 +797,7 @@ TIMEOUT_OPTIONS: List[str] = [
_AUTH_OPTIONS = frozenset(["authmechanismproperties"])
def validate_auth_option(option: str, value: Any) -> Tuple[str, Any]:
def validate_auth_option(option: str, value: Any) -> tuple[str, Any]:
"""Validate optional authentication parameters."""
lower, value = validate(option, value)
if lower not in _AUTH_OPTIONS:
@ -808,7 +805,7 @@ def validate_auth_option(option: str, value: Any) -> Tuple[str, Any]:
return option, value
def validate(option: str, value: Any) -> Tuple[str, Any]:
def validate(option: str, value: Any) -> tuple[str, Any]:
"""Generic validation function."""
lower = option.lower()
validator = VALIDATORS.get(lower, raise_config_error)
@ -962,8 +959,8 @@ class BaseObject:
class _CaseInsensitiveDictionary(MutableMapping[str, Any]):
def __init__(self, *args: Any, **kwargs: Any):
self.__casedkeys: Dict[str, Any] = {}
self.__data: Dict[str, Any] = {}
self.__casedkeys: dict[str, Any] = {}
self.__data: dict[str, Any] = {}
self.update(dict(*args, **kwargs))
def __contains__(self, key: str) -> bool: # type: ignore[override]
@ -1010,7 +1007,7 @@ class _CaseInsensitiveDictionary(MutableMapping[str, Any]):
self.__casedkeys.pop(lc_key, None)
return self.__data.pop(lc_key, *args, **kwargs)
def popitem(self) -> Tuple[str, Any]:
def popitem(self) -> tuple[str, Any]:
lc_key, cased_key = self.__casedkeys.popitem()
value = self.__data.pop(lc_key)
return cased_key, value

View File

@ -14,7 +14,7 @@
from __future__ import annotations
import warnings
from typing import Any, Iterable, List, Optional, Union
from typing import Any, Iterable, Optional, Union
try:
import snappy
@ -47,7 +47,7 @@ _NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
_NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> List[str]:
def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[str]:
try:
# `value` is string.
compressors = value.split(",") # type: ignore[union-attr]
@ -91,12 +91,12 @@ def validate_zlib_compression_level(option: str, value: Any) -> int:
class CompressionSettings:
def __init__(self, compressors: List[str], zlib_compression_level: int):
def __init__(self, compressors: list[str], zlib_compression_level: int):
self.compressors = compressors
self.zlib_compression_level = zlib_compression_level
def get_compression_context(
self, compressors: Optional[List[str]]
self, compressors: Optional[list[str]]
) -> Union[SnappyContext, ZlibContext, ZstdContext, None]:
if compressors:
chosen = compressors[0]

View File

@ -21,7 +21,6 @@ from collections import deque
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
Iterable,
List,
@ -444,7 +443,7 @@ class Cursor(Generic[_DocumentType]):
def __query_spec(self) -> Mapping[str, Any]:
"""Get the spec to use for a query."""
operators: Dict[str, Any] = {}
operators: dict[str, Any] = {}
if self.__ordering:
operators["$orderby"] = self.__ordering
if self.__explain:
@ -884,7 +883,7 @@ class Cursor(Generic[_DocumentType]):
self.__ordering = helpers._index_document(keys)
return self
def distinct(self, key: str) -> List:
def distinct(self, key: str) -> list:
"""Get a list of distinct values for `key` among all documents
in the result set of this query.
@ -901,7 +900,7 @@ class Cursor(Generic[_DocumentType]):
.. seealso:: :meth:`pymongo.collection.Collection.distinct`
"""
options: Dict[str, Any] = {}
options: dict[str, Any] = {}
if self.__spec:
options["query"] = self.__spec
if self.__max_time_ms is not None:
@ -1017,11 +1016,11 @@ class Cursor(Generic[_DocumentType]):
# Avoid overwriting a filter argument that was given by the user
# when updating the spec.
spec: Dict[str, Any]
spec: dict[str, Any]
if self.__has_filter:
spec = dict(self.__spec)
else:
spec = cast(Dict, self.__spec)
spec = cast(dict, self.__spec)
spec["$where"] = code
self.__spec = spec
return self
@ -1234,7 +1233,7 @@ class Cursor(Generic[_DocumentType]):
return self.__id
@property
def address(self) -> Optional[Tuple[str, Any]]:
def address(self) -> Optional[tuple[str, Any]]:
"""The (host, port) of the server used, or None.
.. versionchanged:: 3.0
@ -1287,25 +1286,25 @@ class Cursor(Generic[_DocumentType]):
return self._clone(deepcopy=True)
@overload
def _deepcopy(self, x: Iterable, memo: Optional[Dict[int, Union[List, Dict]]] = None) -> List:
def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list:
...
@overload
def _deepcopy(
self, x: SupportsItems, memo: Optional[Dict[int, Union[List, Dict]]] = None
) -> Dict:
self, x: SupportsItems, memo: Optional[dict[int, Union[list, dict]]] = None
) -> dict:
...
def _deepcopy(
self, x: Union[Iterable, SupportsItems], memo: Optional[Dict[int, Union[List, Dict]]] = None
) -> Union[List, Dict]:
self, x: Union[Iterable, SupportsItems], memo: Optional[dict[int, Union[list, dict]]] = None
) -> Union[list, dict]:
"""Deepcopy helper for the data dictionary or list.
Regular expressions cannot be deep copied but as they are immutable we
don't have to copy them when cloning.
"""
y: Union[List, Dict]
iterator: Iterable[Tuple[Any, Any]]
y: Union[list, dict]
iterator: Iterable[tuple[Any, Any]]
if not hasattr(x, "items"):
y, is_list, iterator = [], True, enumerate(x)
else:
@ -1356,7 +1355,7 @@ class RawBatchCursor(Cursor, Generic[_DocumentType]):
codec_options: CodecOptions[Mapping[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[_DocumentOut]:
) -> list[_DocumentOut]:
raw_response = response.raw_response(cursor_id, user_fields=user_fields)
if not legacy_response:
# OP_MSG returns firstBatch/nextBatch documents as a BSON array

View File

@ -18,6 +18,7 @@ PyMongo only attempts to spawn the mongocryptd daemon process when automatic
client-side field level encryption is enabled. See
:ref:`automatic-client-side-encryption` for more info.
"""
from __future__ import annotations
import os
import subprocess

View File

@ -19,9 +19,7 @@ from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Mapping,
MutableMapping,
NoReturn,
@ -695,12 +693,12 @@ class Database(common.BaseObject, Generic[_DocumentType]):
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: _ServerMode = ReadPreference.PRIMARY,
codec_options: CodecOptions[Dict[str, Any]] = DEFAULT_CODEC_OPTIONS,
codec_options: CodecOptions[dict[str, Any]] = DEFAULT_CODEC_OPTIONS,
write_concern: Optional[WriteConcern] = None,
parse_write_concern_error: bool = False,
session: Optional[ClientSession] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
...
@overload
@ -729,13 +727,13 @@ class Database(common.BaseObject, Generic[_DocumentType]):
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: _ServerMode = ReadPreference.PRIMARY,
codec_options: Union[
CodecOptions[Dict[str, Any]], CodecOptions[_CodecDocumentType]
CodecOptions[dict[str, Any]], CodecOptions[_CodecDocumentType]
] = DEFAULT_CODEC_OPTIONS,
write_concern: Optional[WriteConcern] = None,
parse_write_concern_error: bool = False,
session: Optional[ClientSession] = None,
**kwargs: Any,
) -> Union[Dict[str, Any], _CodecDocumentType]:
) -> Union[dict[str, Any], _CodecDocumentType]:
"""Internal command helper."""
if isinstance(command, str):
command = SON([(command, value)])
@ -767,7 +765,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
...
@overload
@ -797,7 +795,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> Union[Dict[str, Any], _CodecDocumentType]:
) -> Union[dict[str, Any], _CodecDocumentType]:
"""Issue a MongoDB command.
Send command `command` to the database and return the
@ -1008,7 +1006,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
self,
command: Union[str, MutableMapping[str, Any]],
session: Optional[ClientSession] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Same as command but used for retryable read commands."""
read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
@ -1017,7 +1015,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
server: Server,
conn: Connection,
read_preference: _ServerMode,
) -> Dict[str, Any]:
) -> dict[str, Any]:
return self._command(
conn,
command,
@ -1106,7 +1104,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
filter: Optional[Mapping[str, Any]] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
"""Get a list of all the collection names in this database.
For example, to list all non-system collections::
@ -1150,7 +1148,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
def _drop_helper(
self, name: str, session: Optional[ClientSession] = None, comment: Optional[Any] = None
) -> Dict[str, Any]:
) -> dict[str, Any]:
command = SON([("drop", name)])
if comment is not None:
command["comment"] = comment
@ -1172,7 +1170,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
encrypted_fields: Optional[Mapping[str, Any]] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Drop a collection.
:Parameters:
@ -1252,7 +1250,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
session: Optional[ClientSession] = None,
background: Optional[bool] = None,
comment: Optional[Any] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Validate a collection.
Returns a dict of validation info. Raises CollectionInvalid if

View File

@ -13,6 +13,7 @@
# permissions and limitations under the License.
"""Advanced options for MongoDB drivers implemented on top of PyMongo."""
from __future__ import annotations
from collections import namedtuple
from typing import Optional

View File

@ -29,7 +29,6 @@ from typing import (
MutableMapping,
Optional,
Sequence,
Tuple,
)
try:
@ -594,7 +593,7 @@ class ClientEncryption(Generic[_DocumentType]):
kms_provider: Optional[str] = None,
master_key: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> Tuple[Collection[_DocumentType], Mapping[str, Any]]:
) -> tuple[Collection[_DocumentType], Mapping[str, Any]]:
"""Create a collection with encryptedFields.
.. warning::

View File

@ -15,7 +15,7 @@
"""Support for automatic client-side field level encryption."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
from typing import TYPE_CHECKING, Any, Mapping, Optional
try:
import pymongocrypt # noqa: F401
@ -45,7 +45,7 @@ class AutoEncryptionOpts:
mongocryptd_uri: str = "mongodb://localhost:27020",
mongocryptd_bypass_spawn: bool = False,
mongocryptd_spawn_path: str = "mongocryptd",
mongocryptd_spawn_args: Optional[List[str]] = None,
mongocryptd_spawn_args: Optional[list[str]] = None,
kms_tls_options: Optional[Mapping[str, Any]] = None,
crypt_shared_lib_path: Optional[str] = None,
crypt_shared_lib_required: bool = False,
@ -245,7 +245,7 @@ class RangeOpts:
self.precision = precision
@property
def document(self) -> Dict[str, Any]:
def document(self) -> dict[str, Any]:
doc = {}
for k, v in [
("sparsity", int64.Int64(self.sparsity)),

View File

@ -15,17 +15,7 @@
"""Exceptions raised by PyMongo."""
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional, Sequence, Union
from bson.errors import InvalidDocument
@ -138,7 +128,7 @@ class NetworkTimeout(AutoReconnect):
return True
def _format_detailed_error(message: str, details: Optional[Union[Mapping[str, Any], List]]) -> str:
def _format_detailed_error(message: str, details: Optional[Union[Mapping[str, Any], list]]) -> str:
if details is not None:
message = f"{message}, full error: {details}"
return message
@ -161,7 +151,7 @@ class NotPrimaryError(AutoReconnect):
"""
def __init__(
self, message: str = "", errors: Optional[Union[Mapping[str, Any], List]] = None
self, message: str = "", errors: Optional[Union[Mapping[str, Any], list]] = None
) -> None:
super().__init__(_format_detailed_error(message, errors), errors=errors)
@ -306,7 +296,7 @@ class BulkWriteError(OperationFailure):
def __init__(self, results: _DocumentOut) -> None:
super().__init__("batch op errors occurred", 65, results)
def __reduce__(self) -> Tuple[Any, Any]:
def __reduce__(self) -> tuple[Any, Any]:
return self.__class__, (self.details,)
@property

View File

@ -26,6 +26,8 @@ or
``MongoClient(event_listeners=[CommandLogger()])``
"""
from __future__ import annotations
import logging
from pymongo import monitoring

View File

@ -13,11 +13,12 @@
# limitations under the License.
"""Helpers for the 'hello' and legacy hello commands."""
from __future__ import annotations
import copy
import datetime
import itertools
from typing import Any, Generic, List, Mapping, Optional, Set, Tuple
from typing import Any, Generic, Mapping, Optional
from bson.objectid import ObjectId
from pymongo import common
@ -95,7 +96,7 @@ class Hello(Generic[_DocumentType]):
return self._server_type
@property
def all_hosts(self) -> Set[Tuple[str, int]]:
def all_hosts(self) -> set[tuple[str, int]]:
"""List of hosts, passives, and arbiters known to this server."""
return set(
map(
@ -114,7 +115,7 @@ class Hello(Generic[_DocumentType]):
return self._doc.get("tags", {})
@property
def primary(self) -> Optional[Tuple[str, int]]:
def primary(self) -> Optional[tuple[str, int]]:
"""This server's opinion about who the primary is, or None."""
if self._doc.get("primary"):
return common.partition_node(self._doc["primary"])
@ -171,7 +172,7 @@ class Hello(Generic[_DocumentType]):
return self._is_readable
@property
def me(self) -> Optional[Tuple[str, int]]:
def me(self) -> Optional[tuple[str, int]]:
me = self._doc.get("me")
if me:
return common.clean_node(me)
@ -182,11 +183,11 @@ class Hello(Generic[_DocumentType]):
return self._doc.get("lastWrite", {}).get("lastWriteDate")
@property
def compressors(self) -> Optional[List[str]]:
def compressors(self) -> Optional[list[str]]:
return self._doc.get("compression")
@property
def sasl_supported_mechs(self) -> List[str]:
def sasl_supported_mechs(self) -> list[str]:
"""Supported authentication mechanisms for the current user.
For example::

View File

@ -24,12 +24,10 @@ from typing import (
Callable,
Container,
Iterable,
List,
Mapping,
NoReturn,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
@ -100,7 +98,7 @@ def _gen_index_name(keys: _IndexList) -> str:
def _index_list(
key_or_list: _Hint, direction: Optional[Union[int, str]] = None
) -> Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]]:
) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]:
"""Helper to generate a list of (key, direction) pairs.
Takes such a list, or a single key, or a single key and direction.
@ -116,7 +114,7 @@ def _index_list(
return list(key_or_list)
elif not isinstance(key_or_list, (list, tuple)):
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
values: List[Tuple[str, int]] = []
values: list[tuple[str, int]] = []
for item in key_or_list:
if isinstance(item, str):
item = (item, ASCENDING)
@ -223,7 +221,7 @@ def _check_command_response(
raise OperationFailure(errmsg, code, response, max_wire_version)
def _raise_last_write_error(write_errors: List[Any]) -> NoReturn:
def _raise_last_write_error(write_errors: list[Any]) -> NoReturn:
# If the last batch had multiple errors only report
# the last error to emulate continue_on_error.
error = write_errors[-1]

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import threading

View File

@ -29,14 +29,11 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Tuple,
Union,
cast,
)
@ -137,14 +134,14 @@ def _maybe_add_read_preference(
return spec
def _convert_exception(exception: Exception) -> Dict[str, Any]:
def _convert_exception(exception: Exception) -> dict[str, Any]:
"""Convert an Exception into a failure document for publishing."""
return {"errmsg": str(exception), "errtype": exception.__class__.__name__}
def _convert_write_result(
operation: str, command: Mapping[str, Any], result: Mapping[str, Any]
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Convert a legacy write result to write command format."""
# Based on _merge_legacy from bulk.py
affected = result.get("n", 0)
@ -340,7 +337,7 @@ class _Query:
self.client = client
self.allow_disk_use = allow_disk_use
self.name = "find"
self._as_command: Optional[Tuple[SON[str, Any], str]] = None
self._as_command: Optional[tuple[SON[str, Any], str]] = None
self.exhaust = exhaust
def reset(self) -> None:
@ -367,7 +364,7 @@ class _Query:
def as_command(
self, conn: Connection, apply_timeout: bool = False
) -> Tuple[SON[str, Any], str]:
) -> tuple[SON[str, Any], str]:
"""Return a find command document for this query."""
# We use the command twice: on the wire and for command monitoring.
# Generate it once, for speed and to avoid repeating side-effects.
@ -411,7 +408,7 @@ class _Query:
def get_message(
self, read_preference: _ServerMode, conn: Connection, use_cmd: bool = False
) -> Tuple[int, bytes, int]:
) -> tuple[int, bytes, int]:
"""Get a query message, possibly setting the secondaryOk bit."""
# Use the read_preference decided by _socket_from_server.
self.read_preference = read_preference
@ -508,7 +505,7 @@ class _GetMore:
self.client = client
self.max_await_time_ms = max_await_time_ms
self.conn_mgr = conn_mgr
self._as_command: Optional[Tuple[SON[str, Any], str]] = None
self._as_command: Optional[tuple[SON[str, Any], str]] = None
self.exhaust = exhaust
self.comment = comment
@ -531,7 +528,7 @@ class _GetMore:
def as_command(
self, conn: Connection, apply_timeout: bool = False
) -> Tuple[SON[str, Any], str]:
) -> tuple[SON[str, Any], str]:
"""Return a getMore command document for this query."""
# See _Query.as_command for an explanation of this caching.
if self._as_command is not None:
@ -561,7 +558,7 @@ class _GetMore:
def get_message(
self, dummy0: Any, conn: Connection, use_cmd: bool = False
) -> Union[Tuple[int, bytes, int], Tuple[int, bytes]]:
) -> Union[tuple[int, bytes, int], tuple[int, bytes]]:
"""Get a getmore message."""
ns = self.namespace()
ctx = conn.compression_context
@ -639,7 +636,7 @@ _COMPRESSION_HEADER_SIZE = 25
def _compress(
operation: int, data: bytes, ctx: Union[SnappyContext, ZlibContext, ZstdContext]
) -> Tuple[int, bytes]:
) -> tuple[int, bytes]:
"""Takes message data, compresses it, and adds an OP_COMPRESSED header."""
compressed = ctx.compress(data)
request_id = _randint()
@ -659,7 +656,7 @@ def _compress(
_pack_header = struct.Struct("<iiii").pack
def __pack_message(operation: int, data: bytes) -> Tuple[int, bytes]:
def __pack_message(operation: int, data: bytes) -> tuple[int, bytes]:
"""Takes message data and adds a message header based on the operation.
Returns the resultant message string.
@ -678,9 +675,9 @@ def _op_msg_no_header(
flags: int,
command: Mapping[str, Any],
identifier: str,
docs: Optional[List[Mapping[str, Any]]],
docs: Optional[list[Mapping[str, Any]]],
opts: CodecOptions,
) -> Tuple[bytes, int, int]:
) -> tuple[bytes, int, int]:
"""Get a OP_MSG message.
Note: this method handles multiple documents in a type one payload but
@ -710,10 +707,10 @@ def _op_msg_compressed(
flags: int,
command: Mapping[str, Any],
identifier: str,
docs: Optional[List[Mapping[str, Any]]],
docs: Optional[list[Mapping[str, Any]]],
opts: CodecOptions,
ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> Tuple[int, bytes, int, int]:
) -> tuple[int, bytes, int, int]:
"""Internal OP_MSG message helper."""
msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
rid, msg = _compress(2013, msg, ctx)
@ -724,9 +721,9 @@ def _op_msg_uncompressed(
flags: int,
command: Mapping[str, Any],
identifier: str,
docs: Optional[List[Mapping[str, Any]]],
docs: Optional[list[Mapping[str, Any]]],
opts: CodecOptions,
) -> Tuple[int, bytes, int, int]:
) -> tuple[int, bytes, int, int]:
"""Internal compressed OP_MSG message helper."""
data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
request_id, op_message = __pack_message(2013, data)
@ -744,7 +741,7 @@ def _op_msg(
read_preference: Optional[_ServerMode],
opts: CodecOptions,
ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> Tuple[int, bytes, int, int]:
) -> tuple[int, bytes, int, int]:
"""Get a OP_MSG message."""
command["$db"] = dbname
# getMore commands do not send $readPreference.
@ -777,7 +774,7 @@ def _query_impl(
query: Mapping[str, Any],
field_selector: Optional[Mapping[str, Any]],
opts: CodecOptions,
) -> Tuple[bytes, int]:
) -> tuple[bytes, int]:
"""Get an OP_QUERY message."""
encoded = _dict_to_bson(query, False, opts)
if field_selector:
@ -809,7 +806,7 @@ def _query_compressed(
field_selector: Optional[Mapping[str, Any]],
opts: CodecOptions,
ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> Tuple[int, bytes, int]:
) -> tuple[int, bytes, int]:
"""Internal compressed query message helper."""
op_query, max_bson_size = _query_impl(
options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
@ -826,7 +823,7 @@ def _query_uncompressed(
query: Mapping[str, Any],
field_selector: Optional[Mapping[str, Any]],
opts: CodecOptions,
) -> Tuple[int, bytes, int]:
) -> tuple[int, bytes, int]:
"""Internal query message helper."""
op_query, max_bson_size = _query_impl(
options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
@ -848,7 +845,7 @@ def _query(
field_selector: Optional[Mapping[str, Any]],
opts: CodecOptions,
ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> Tuple[int, bytes, int]:
) -> tuple[int, bytes, int]:
"""Get a **query** message."""
if ctx:
return _query_compressed(
@ -879,14 +876,14 @@ def _get_more_compressed(
num_to_return: int,
cursor_id: int,
ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> Tuple[int, bytes]:
) -> tuple[int, bytes]:
"""Internal compressed getMore message helper."""
return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx)
def _get_more_uncompressed(
collection_name: str, num_to_return: int, cursor_id: int
) -> Tuple[int, bytes]:
) -> tuple[int, bytes]:
"""Internal getMore message helper."""
return __pack_message(2005, _get_more_impl(collection_name, num_to_return, cursor_id))
@ -900,7 +897,7 @@ def _get_more(
num_to_return: int,
cursor_id: int,
ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> Tuple[int, bytes]:
) -> tuple[int, bytes]:
"""Get a **getMore** message."""
if ctx:
return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx)
@ -950,8 +947,8 @@ class _BulkWriteContext:
self.codec = codec
def __batch_command(
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]]
) -> Tuple[int, bytes, List[Mapping[str, Any]]]:
self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]]
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
namespace = self.db_name + ".$cmd"
request_id, msg, to_send = _do_batched_op_msg(
namespace, self.op_type, cmd, docs, self.codec, self
@ -961,16 +958,16 @@ class _BulkWriteContext:
return request_id, msg, to_send
def execute(
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient
) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]:
self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient
) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]:
request_id, msg, to_send = self.__batch_command(cmd, docs)
result = self.write_command(cmd, request_id, msg, to_send)
client._process_response(result, self.session)
return result, to_send
def execute_unack(
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient
) -> List[Mapping[str, Any]]:
self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient
) -> list[Mapping[str, Any]]:
request_id, msg, to_send = self.__batch_command(cmd, docs)
# Though this isn't strictly a "legacy" write, the helper
# handles publishing commands and sending our message
@ -1009,7 +1006,7 @@ class _BulkWriteContext:
request_id: int,
msg: bytes,
max_doc_size: int,
docs: List[Mapping[str, Any]],
docs: list[Mapping[str, Any]],
) -> Optional[Mapping[str, Any]]:
"""A proxy for Connection.unack_write that handles event publishing."""
if self.publish:
@ -1049,8 +1046,8 @@ class _BulkWriteContext:
cmd: MutableMapping[str, Any],
request_id: int,
msg: bytes,
docs: List[Mapping[str, Any]],
) -> Dict[str, Any]:
docs: list[Mapping[str, Any]],
) -> dict[str, Any]:
"""A proxy for SocketInfo.write_command that handles event publishing."""
if self.publish:
assert self.start_time is not None
@ -1076,7 +1073,7 @@ class _BulkWriteContext:
return reply
def _start(
self, cmd: MutableMapping[str, Any], request_id: int, docs: List[Mapping[str, Any]]
self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]]
) -> MutableMapping[str, Any]:
"""Publish a CommandStartedEvent."""
cmd[self.field] = docs
@ -1126,8 +1123,8 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
__slots__ = ()
def __batch_command(
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]]
) -> Tuple[MutableMapping[str, Any], List[Mapping[str, Any]]]:
self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]]
) -> tuple[MutableMapping[str, Any], list[Mapping[str, Any]]]:
namespace = self.db_name + ".$cmd"
msg, to_send = _encode_batched_write_command(
namespace, self.op_type, cmd, docs, self.codec, self
@ -1141,8 +1138,8 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
return outgoing, to_send
def execute(
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient
) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]:
self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient
) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]:
batched_cmd, to_send = self.__batch_command(cmd, docs)
result: Mapping[str, Any] = self.conn.command(
self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client
@ -1150,8 +1147,8 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
return result, to_send
def execute_unack(
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient
) -> List[Mapping[str, Any]]:
self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient
) -> list[Mapping[str, Any]]:
batched_cmd, to_send = self.__batch_command(cmd, docs)
self.conn.command(
self.db_name,
@ -1196,12 +1193,12 @@ _OP_MSG_MAP = {
def _batched_op_msg_impl(
operation: int,
command: Mapping[str, Any],
docs: List[Mapping[str, Any]],
docs: list[Mapping[str, Any]],
ack: bool,
opts: CodecOptions,
ctx: _BulkWriteContext,
buf: _BytesIO,
) -> Tuple[List[Mapping[str, Any]], int]:
) -> tuple[list[Mapping[str, Any]], int]:
"""Create a batched OP_MSG write."""
max_bson_size = ctx.max_bson_size
max_write_batch_size = ctx.max_write_batch_size
@ -1264,11 +1261,11 @@ def _batched_op_msg_impl(
def _encode_batched_op_msg(
operation: int,
command: Mapping[str, Any],
docs: List[Mapping[str, Any]],
docs: list[Mapping[str, Any]],
ack: bool,
opts: CodecOptions,
ctx: _BulkWriteContext,
) -> Tuple[bytes, List[Mapping[str, Any]]]:
) -> tuple[bytes, list[Mapping[str, Any]]]:
"""Encode the next batched insert, update, or delete operation
as OP_MSG.
"""
@ -1285,11 +1282,11 @@ if _use_c:
def _batched_op_msg_compressed(
operation: int,
command: Mapping[str, Any],
docs: List[Mapping[str, Any]],
docs: list[Mapping[str, Any]],
ack: bool,
opts: CodecOptions,
ctx: _BulkWriteContext,
) -> Tuple[int, bytes, List[Mapping[str, Any]]]:
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
"""Create the next batched insert, update, or delete operation
with OP_MSG, compressed.
"""
@ -1303,11 +1300,11 @@ def _batched_op_msg_compressed(
def _batched_op_msg(
operation: int,
command: Mapping[str, Any],
docs: List[Mapping[str, Any]],
docs: list[Mapping[str, Any]],
ack: bool,
opts: CodecOptions,
ctx: _BulkWriteContext,
) -> Tuple[int, bytes, List[Mapping[str, Any]]]:
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
"""OP_MSG implementation entry point."""
buf = _BytesIO()
@ -1336,10 +1333,10 @@ def _do_batched_op_msg(
namespace: str,
operation: int,
command: MutableMapping[str, Any],
docs: List[Mapping[str, Any]],
docs: list[Mapping[str, Any]],
opts: CodecOptions,
ctx: _BulkWriteContext,
) -> Tuple[int, bytes, List[Mapping[str, Any]]]:
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
"""Create the next batched insert, update, or delete operation
using OP_MSG.
"""
@ -1360,10 +1357,10 @@ def _encode_batched_write_command(
namespace: str,
operation: int,
command: MutableMapping[str, Any],
docs: List[Mapping[str, Any]],
docs: list[Mapping[str, Any]],
opts: CodecOptions,
ctx: _BulkWriteContext,
) -> Tuple[bytes, List[Mapping[str, Any]]]:
) -> tuple[bytes, list[Mapping[str, Any]]]:
"""Encode the next batched insert, update, or delete command."""
buf = _BytesIO()
@ -1379,11 +1376,11 @@ def _batched_write_command_impl(
namespace: str,
operation: int,
command: MutableMapping[str, Any],
docs: List[Mapping[str, Any]],
docs: list[Mapping[str, Any]],
opts: CodecOptions,
ctx: _BulkWriteContext,
buf: _BytesIO,
) -> Tuple[List[Mapping[str, Any]], int]:
) -> tuple[list[Mapping[str, Any]], int]:
"""Create a batched OP_QUERY write command."""
max_bson_size = ctx.max_bson_size
max_write_batch_size = ctx.max_write_batch_size
@ -1468,7 +1465,7 @@ class _OpReply:
def raw_response(
self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = None
) -> List[bytes]:
) -> list[bytes]:
"""Check the response header from the database, without decoding BSON.
Check the response for errors and unpack.
@ -1517,7 +1514,7 @@ class _OpReply:
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Unpack a response from the database and decode the BSON document(s).
Check the response for errors and unpack, returning a dictionary
@ -1541,7 +1538,7 @@ class _OpReply:
return bson.decode_all(self.documents, codec_options)
return bson._decode_all_selective(self.documents, codec_options, user_fields)
def command_response(self, codec_options: CodecOptions) -> Dict[str, Any]:
def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
"""Unpack a command response."""
docs = self.unpack_response(codec_options=codec_options)
assert self.number_returned == 1
@ -1588,7 +1585,7 @@ class _OpMsg:
self,
cursor_id: Optional[int] = None,
user_fields: Optional[Mapping[str, Any]] = {}, # noqa: B006
) -> List[Mapping[str, Any]]:
) -> list[Mapping[str, Any]]:
"""
cursor_id is ignored
user_fields is used to determine which fields must not be decoded
@ -1604,7 +1601,7 @@ class _OpMsg:
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Unpack a OP_MSG command response.
:Parameters:
@ -1619,7 +1616,7 @@ class _OpMsg:
assert not legacy_response
return bson._decode_all_selective(self.payload_document, codec_options, user_fields)
def command_response(self, codec_options: CodecOptions) -> Dict[str, Any]:
def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
"""Unpack a command response."""
return self.unpack_response(codec_options=codec_options)[0]
@ -1652,7 +1649,7 @@ class _OpMsg:
return cls(flags, payload_document)
_UNPACK_REPLY: Dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = {
_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = {
_OpReply.OP_CODE: _OpReply.unpack,
_OpMsg.OP_CODE: _OpMsg.unpack,
}

View File

@ -41,18 +41,14 @@ from typing import (
Any,
Callable,
ContextManager,
Dict,
FrozenSet,
Generic,
Iterator,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
@ -716,7 +712,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
client.__my_database__
"""
doc_class = document_class or dict
self.__init_kwargs: Dict[str, Any] = {
self.__init_kwargs: dict[str, Any] = {
"host": host,
"port": port,
"document_class": doc_class,
@ -824,7 +820,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self.__default_database_name = dbase
self.__lock = _create_lock()
self.__kill_cursors_queue: List = []
self.__kill_cursors_queue: list = []
self._event_listeners = options.pool_options._event_listeners
super().__init__(
@ -1068,7 +1064,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
return self._topology.description
@property
def address(self) -> Optional[Tuple[str, int]]:
def address(self) -> Optional[tuple[str, int]]:
"""(host, port) of the current standalone, primary, or mongos, or None.
Accessing :attr:`address` raises :exc:`~.errors.InvalidOperation` if
@ -1100,7 +1096,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
return self._server_property("address")
@property
def primary(self) -> Optional[Tuple[str, int]]:
def primary(self) -> Optional[tuple[str, int]]:
"""The (host, port) of the current primary of the replica set.
Returns ``None`` if this client is not connected to a replica set,
@ -1113,7 +1109,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
return self._topology.get_primary() # type: ignore[return-value]
@property
def secondaries(self) -> Set[_Address]:
def secondaries(self) -> set[_Address]:
"""The secondary members known to this client.
A sequence of (host, port) pairs. Empty if this client is not
@ -1126,7 +1122,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
return self._topology.get_secondaries()
@property
def arbiters(self) -> Set[_Address]:
def arbiters(self) -> set[_Address]:
"""Arbiters in the replica set.
A sequence of (host, port) pairs. Empty if this client is not
@ -1179,7 +1175,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
"""
return self.__options
def _end_sessions(self, session_ids: List[_ServerSession]) -> None:
def _end_sessions(self, session_ids: list[_ServerSession]) -> None:
"""Send endSessions command(s) with the given session ids."""
try:
# Use Connection.command directly to avoid implicitly creating
@ -1313,7 +1309,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
@contextlib.contextmanager
def _conn_from_server(
self, read_preference: _ServerMode, server: Server, session: Optional[ClientSession]
) -> Iterator[Tuple[Connection, _ServerMode]]:
) -> Iterator[tuple[Connection, _ServerMode]]:
assert read_preference is not None, "read_preference must not be None"
# Get a connection for a server matching the read preference, and yield
# conn with the effective read preference. The Server Selection
@ -1337,7 +1333,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
def _conn_for_reads(
self, read_preference: _ServerMode, session: Optional[ClientSession]
) -> ContextManager[Tuple[Connection, _ServerMode]]:
) -> ContextManager[tuple[Connection, _ServerMode]]:
assert read_preference is not None, "read_preference must not be None"
_ = self._get_topology()
server = self._select_server(read_preference, session)
@ -1872,7 +1868,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
if session is not None:
session._process_response(reply)
def server_info(self, session: Optional[client_session.ClientSession] = None) -> Dict[str, Any]:
def server_info(self, session: Optional[client_session.ClientSession] = None) -> dict[str, Any]:
"""Get information about the MongoDB server we're connected to.
:Parameters:
@ -1894,7 +1890,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
session: Optional[client_session.ClientSession] = None,
comment: Optional[Any] = None,
**kwargs: Any,
) -> CommandCursor[Dict[str, Any]]:
) -> CommandCursor[dict[str, Any]]:
"""Get a cursor over the databases of the connected server.
:Parameters:
@ -1932,7 +1928,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self,
session: Optional[client_session.ClientSession] = None,
comment: Optional[Any] = None,
) -> List[str]:
) -> list[str]:
"""Get a list of the names of all databases on the connected server.
:Parameters:

View File

@ -19,7 +19,7 @@ from __future__ import annotations
import atexit
import time
import weakref
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Tuple, cast
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
from pymongo import common, periodic_executor
from pymongo._csot import MovingMinimum
@ -274,7 +274,7 @@ class Monitor(MonitorBase):
)
return sd
def _check_with_socket(self, conn: Connection) -> Tuple[Hello, float]:
def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]:
"""Return (Hello, round_trip_time).
Can raise ConnectionFailure or OperationFailure.
@ -326,7 +326,7 @@ class SrvMonitor(MonitorBase):
# Topology was garbage-collected.
self.close()
def _get_seedlist(self) -> Optional[List[Tuple[str, Any]]]:
def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
"""Poll SRV records for a seedlist.
Returns a list of ServerDescriptions.
@ -383,7 +383,7 @@ class _RttMonitor(MonitorBase):
self._moving_average.add_sample(sample)
self._moving_min.add_sample(sample)
def get(self) -> Tuple[Optional[float], float]:
def get(self) -> tuple[Optional[float], float]:
"""Get the calculated average, or None if no samples yet and the min."""
with self._lock:
return self._moving_average.get(), self._moving_min.get()

View File

@ -187,7 +187,7 @@ from __future__ import annotations
import datetime
from collections import abc, namedtuple
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
from bson.objectid import ObjectId
from pymongo.hello import Hello, HelloCompat
@ -812,12 +812,12 @@ class PoolCreatedEvent(_PoolEvent):
__slots__ = ("__options",)
def __init__(self, address: _Address, options: Dict[str, Any]) -> None:
def __init__(self, address: _Address, options: dict[str, Any]) -> None:
super().__init__(address)
self.__options = options
@property
def options(self) -> Dict[str, Any]:
def options(self) -> dict[str, Any]:
"""Any non-default pool options that were set on this Connection Pool."""
return self.__options
@ -1439,7 +1439,7 @@ class _EventListeners:
"""Are any ConnectionPoolListener instances registered?"""
return self.__enabled_for_cmap
def event_listeners(self) -> List[_EventListeners]:
def event_listeners(self) -> list[_EventListeners]:
"""List of registered event listeners."""
return (
self.__command_listeners
@ -1712,7 +1712,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_pool_created(self, address: _Address, options: Dict[str, Any]) -> None:
def publish_pool_created(self, address: _Address, options: dict[str, Any]) -> None:
"""Publish a :class:`PoolCreatedEvent` to all pool listeners."""
event = PoolCreatedEvent(address, options)
for subscriber in self.__cmap_listeners:

View File

@ -19,7 +19,7 @@ from __future__ import annotations
from collections import namedtuple
from datetime import datetime as _datetime
from datetime import timezone
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any
from pymongo.lock import _create_lock
@ -36,7 +36,7 @@ class _OCSPCache:
)
def __init__(self) -> None:
self._data: Dict[Any, OCSPResponse] = {}
self._data: dict[Any, OCSPResponse] = {}
# Hold this lock when accessing _data.
self._lock = _create_lock()

View File

@ -19,7 +19,7 @@ import logging as _logging
import re as _re
from datetime import datetime as _datetime
from datetime import timezone
from typing import TYPE_CHECKING, Iterable, List, Optional, Type, Union
from typing import TYPE_CHECKING, Iterable, Optional, Type, Union
from cryptography.exceptions import InvalidSignature as _InvalidSignature
from cryptography.hazmat.backends import default_backend as _default_backend
@ -101,7 +101,7 @@ _CERT_REGEX = _re.compile(
)
def _load_trusted_ca_certs(cafile: str) -> List[Certificate]:
def _load_trusted_ca_certs(cafile: str) -> list[Certificate]:
"""Parse the tlsCAFile into a list of certificates."""
with open(cafile, "rb") as f:
data = f.read()
@ -115,7 +115,7 @@ def _load_trusted_ca_certs(cafile: str) -> List[Certificate]:
def _get_issuer_cert(
cert: Certificate, chain: Iterable[Certificate], trusted_ca_certs: Optional[List[Certificate]]
cert: Certificate, chain: Iterable[Certificate], trusted_ca_certs: Optional[list[Certificate]]
) -> Optional[Certificate]:
issuer_name = cert.issuer
for candidate in chain:
@ -187,7 +187,7 @@ def _public_key_hash(cert: Certificate) -> bytes:
def _get_certs_by_key_hash(
certificates: Iterable[Certificate], issuer: Certificate, responder_key_hash: Optional[bytes]
) -> List[Certificate]:
) -> list[Certificate]:
return [
cert
for cert in certificates
@ -197,7 +197,7 @@ def _get_certs_by_key_hash(
def _get_certs_by_name(
certificates: Iterable[Certificate], issuer: Certificate, responder_name: Optional[Name]
) -> List[Certificate]:
) -> list[Certificate]:
return [
cert
for cert in certificates

View File

@ -18,9 +18,7 @@ from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Mapping,
Optional,
Sequence,
@ -293,7 +291,7 @@ class _UpdateOp:
doc: Union[Mapping[str, Any], _Pipeline],
upsert: bool,
collation: Optional[_CollationIn],
array_filters: Optional[List[Mapping[str, Any]]],
array_filters: Optional[list[Mapping[str, Any]]],
hint: Optional[_IndexKeyHint],
):
if filter is not None:
@ -355,7 +353,7 @@ class UpdateOne(_UpdateOp):
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
array_filters: Optional[List[Mapping[str, Any]]] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Represents an update_one operation.
@ -413,7 +411,7 @@ class UpdateMany(_UpdateOp):
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
array_filters: Optional[List[Mapping[str, Any]]] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
) -> None:
"""Create an UpdateMany instance.
@ -537,7 +535,7 @@ class IndexModel:
self.__document["collation"] = collation
@property
def document(self) -> Dict[str, Any]:
def document(self) -> dict[str, Any]:
"""An index document suitable for passing to the createIndexes
command.
"""

View File

@ -28,16 +28,12 @@ import weakref
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Sequence,
Set,
Tuple,
Union,
)
@ -305,8 +301,8 @@ def _getenv_int(key: str) -> Optional[int]:
return None
def _metadata_env() -> Dict[str, Any]:
env: Dict[str, Any] = {}
def _metadata_env() -> dict[str, Any]:
env: dict[str, Any] = {}
# Skip if multiple (or no) envs are matched.
if (_is_lambda(), _is_azure_func(), _is_gcp_func(), _is_vercel()).count(True) != 1:
return env
@ -520,7 +516,7 @@ class PoolOptions:
return self.__credentials
@property
def non_default_options(self) -> Dict[str, Any]:
def non_default_options(self) -> dict[str, Any]:
"""The non-default options this pool was created with.
Added for CMAP's :class:`PoolCreatedEvent`.
@ -668,7 +664,7 @@ class Connection:
"""
def __init__(
self, conn: Union[socket.socket, _sslConn], pool: Pool, address: Tuple[str, int], id: int
self, conn: Union[socket.socket, _sslConn], pool: Pool, address: tuple[str, int], id: int
):
self.pool_ref = weakref.ref(pool)
self.conn = conn
@ -693,7 +689,7 @@ class Connection:
self.socket_checker: SocketChecker = SocketChecker()
self.oidc_token_gen_id: Optional[int] = None
# Support for mechanism negotiation on the initial handshake.
self.negotiated_mechs: Optional[List[str]] = None
self.negotiated_mechs: Optional[list[str]] = None
self.auth_ctx: Optional[_AuthContext] = None
# The pool's generation changes with each reset() so we can close
@ -774,7 +770,7 @@ class Connection:
else:
return SON([(HelloCompat.LEGACY_CMD, 1), ("helloOk", True)])
def hello(self) -> Hello[Dict[str, Any]]:
def hello(self) -> Hello[dict[str, Any]]:
return self._hello(None, None, None)
def _hello(
@ -782,7 +778,7 @@ class Connection:
cluster_time: Optional[ClusterTime],
topology_version: Optional[Any],
heartbeat_frequency: Optional[int],
) -> Hello[Dict[str, Any]]:
) -> Hello[dict[str, Any]]:
cmd = self.hello_cmd()
performing_handshake = not self.performed_handshake
awaitable = False
@ -860,7 +856,7 @@ class Connection:
self.generation = self.pool_gen.get(self.service_id)
return hello
def _next_reply(self) -> Dict[str, Any]:
def _next_reply(self) -> dict[str, Any]:
reply = self.receive_message(None)
self.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response()
@ -887,7 +883,7 @@ class Connection:
publish_events: bool = True,
user_fields: Optional[Mapping[str, Any]] = None,
exhaust_allowed: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Execute a command or raise an error.
:Parameters:
@ -1007,7 +1003,7 @@ class Connection:
def write_command(
self, request_id: int, msg: bytes, codec_options: CodecOptions
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Send "insert" etc. command, returning response as a dict.
Can raise ConnectionFailure or OperationFailure.
@ -1280,7 +1276,7 @@ class _PoolClosedError(PyMongoError):
class _PoolGeneration:
def __init__(self) -> None:
# Maps service_id to generation.
self._generations: Dict[ObjectId, int] = collections.defaultdict(int)
self._generations: dict[ObjectId, int] = collections.defaultdict(int)
# Overall pool generation.
self._generation = 0
@ -1381,7 +1377,7 @@ class Pool:
# Retain references to pinned connections to prevent the CPython GC
# from thinking that a cursor's pinned connection can be GC'd when the
# cursor is GC'd (see PYTHON-2751).
self.__pinned_sockets: Set[Connection] = set()
self.__pinned_sockets: set[Connection] = set()
self.ncursors = 0
self.ntxns = 0

View File

@ -23,7 +23,7 @@ import sys as _sys
import time as _time
from errno import EINTR as _EINTR
from ipaddress import ip_address as _ip_address
from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
from cryptography.x509 import load_der_x509_certificate as _load_der_x509_certificate
from OpenSSL import SSL as _SSL
@ -184,7 +184,7 @@ class _CallbackData:
"""Data class which is passed to the OCSP callback."""
def __init__(self) -> None:
self.trusted_ca_certs: Optional[List[Certificate]] = None
self.trusted_ca_certs: Optional[list[Certificate]] = None
self.check_ocsp_endpoint: Optional[bool] = None
self.ocsp_response_cache = _OCSPCache()

View File

@ -13,8 +13,9 @@
# limitations under the License.
"""Tools for working with read concerns."""
from __future__ import annotations
from typing import Any, Dict, Optional
from typing import Any, Optional
class ReadConcern:
@ -50,7 +51,7 @@ class ReadConcern:
return self.level is None or self.level == "local"
@property
def document(self) -> Dict[str, Any]:
def document(self) -> dict[str, Any]:
"""The document representation of this read concern.
.. note::

View File

@ -17,7 +17,7 @@
from __future__ import annotations
from collections import abc
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
from pymongo import max_staleness_selectors
from pymongo.errors import ConfigurationError
@ -131,9 +131,9 @@ class _ServerMode:
return self.__mongos_mode
@property
def document(self) -> Dict[str, Any]:
def document(self) -> dict[str, Any]:
"""Read preference as a document."""
doc: Dict[str, Any] = {"mode": self.__mongos_mode}
doc: dict[str, Any] = {"mode": self.__mongos_mode}
if self.__tag_sets not in (None, [{}]):
doc["tags"] = self.__tag_sets
if self.__max_staleness != -1:
@ -235,7 +235,7 @@ class _ServerMode:
def __ne__(self, other: Any) -> bool:
return not self == other
def __getstate__(self) -> Dict[str, Any]:
def __getstate__(self) -> dict[str, Any]:
"""Return value of object for pickling.
Needed explicitly because __slots__() defined.

View File

@ -15,7 +15,7 @@
"""Represent a response from the server."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union
if TYPE_CHECKING:
from datetime import timedelta
@ -95,7 +95,7 @@ class PinnedResponse(Response):
request_id: int,
duration: Optional[timedelta],
from_command: bool,
docs: List[_DocumentOut],
docs: list[_DocumentOut],
more_to_come: bool,
):
"""Represent a response to an exhaust cursor's initial query.

View File

@ -13,7 +13,9 @@
# limitations under the License.
"""Result class definitions."""
from typing import Any, Dict, List, Mapping, Optional, cast
from __future__ import annotations
from typing import Any, Mapping, Optional, cast
from pymongo.errors import InvalidOperation
@ -76,12 +78,12 @@ class InsertManyResult(_WriteResult):
__slots__ = ("__inserted_ids",)
def __init__(self, inserted_ids: List[Any], acknowledged: bool) -> None:
def __init__(self, inserted_ids: list[Any], acknowledged: bool) -> None:
self.__inserted_ids = inserted_ids
super().__init__(acknowledged)
@property
def inserted_ids(self) -> List:
def inserted_ids(self) -> list:
"""A list of _ids of the inserted documents, in the order provided.
.. note:: If ``False`` is passed for the `ordered` parameter to
@ -163,7 +165,7 @@ class BulkWriteResult(_WriteResult):
__slots__ = ("__bulk_api_result",)
def __init__(self, bulk_api_result: Dict[str, Any], acknowledged: bool) -> None:
def __init__(self, bulk_api_result: dict[str, Any], acknowledged: bool) -> None:
"""Create a BulkWriteResult instance.
:Parameters:
@ -176,7 +178,7 @@ class BulkWriteResult(_WriteResult):
super().__init__(acknowledged)
@property
def bulk_api_result(self) -> Dict[str, Any]:
def bulk_api_result(self) -> dict[str, Any]:
"""The raw bulk API result."""
return self.__bulk_api_result
@ -211,7 +213,7 @@ class BulkWriteResult(_WriteResult):
return cast(int, self.__bulk_api_result.get("nUpserted"))
@property
def upserted_ids(self) -> Optional[Dict[int, Any]]:
def upserted_ids(self) -> Optional[dict[int, Any]]:
"""A map of operation index to the _id of the upserted document."""
self._raise_if_unacknowledged("upserted_ids")
if self.__bulk_api_result:

View File

@ -13,6 +13,8 @@
# limitations under the License.
"""An implementation of RFC4013 SASLprep."""
from __future__ import annotations
from typing import Any, Optional
try:

View File

@ -16,16 +16,7 @@
from __future__ import annotations
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Callable,
ContextManager,
List,
Optional,
Tuple,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union
from bson import _decode_all_selective
from pymongo.errors import NotPrimaryError, OperationFailure
@ -110,7 +101,7 @@ class Server:
operation: Union[_Query, _GetMore],
read_preference: _ServerMode,
listeners: Optional[_EventListeners],
unpack_res: Callable[..., List[_DocumentOut]],
unpack_res: Callable[..., list[_DocumentOut]],
) -> Response:
"""Run a _Query or _GetMore operation and return a Response object.
@ -275,8 +266,8 @@ class Server:
return self._pool
def _split_message(
self, message: Union[Tuple[int, Any], Tuple[int, Any, int]]
) -> Tuple[int, Any, int]:
self, message: Union[tuple[int, Any], tuple[int, Any, int]]
) -> tuple[int, Any, int]:
"""Return request_id, data, max_doc_size.
:Parameters:

View File

@ -13,10 +13,11 @@
# limitations under the License.
"""Represent one server the driver is connected to."""
from __future__ import annotations
import time
import warnings
from typing import Any, Dict, Mapping, Optional, Set, Tuple
from typing import Any, Mapping, Optional
from bson import EPOCH_NAIVE
from bson.objectid import ObjectId
@ -129,7 +130,7 @@ class ServerDescription:
return SERVER_TYPE._fields[self._server_type]
@property
def all_hosts(self) -> Set[Tuple[str, int]]:
def all_hosts(self) -> set[tuple[str, int]]:
"""List of hosts, passives, and arbiters known to this server."""
return self._all_hosts
@ -143,7 +144,7 @@ class ServerDescription:
return self._replica_set_name
@property
def primary(self) -> Optional[Tuple[str, int]]:
def primary(self) -> Optional[tuple[str, int]]:
"""This server's opinion about who the primary is, or None."""
return self._primary
@ -180,7 +181,7 @@ class ServerDescription:
return self._cluster_time
@property
def election_tuple(self) -> Tuple[Optional[int], Optional[ObjectId]]:
def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]:
warnings.warn(
"'election_tuple' is deprecated, use 'set_version' and 'election_id' instead",
DeprecationWarning,
@ -189,7 +190,7 @@ class ServerDescription:
return self._set_version, self._election_id
@property
def me(self) -> Optional[Tuple[str, int]]:
def me(self) -> Optional[tuple[str, int]]:
return self._me
@property
@ -297,4 +298,4 @@ class ServerDescription:
)
# For unittesting only. Use under no circumstances!
_host_to_round_trip_time: Dict = {}
_host_to_round_trip_time: dict = {}

View File

@ -15,7 +15,7 @@
"""Criteria to select some ServerDescriptions from a TopologyDescription."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, TypeVar, cast
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, TypeVar, cast
from pymongo.server_type import SERVER_TYPE
@ -51,7 +51,7 @@ class Selection:
def __init__(
self,
topology_description: TopologyDescription,
server_descriptions: List[ServerDescription],
server_descriptions: list[ServerDescription],
common_wire_version: Optional[int],
primary: Optional[ServerDescription],
):
@ -60,7 +60,7 @@ class Selection:
self.primary = primary
self.common_wire_version = common_wire_version
def with_server_descriptions(self, server_descriptions: List[ServerDescription]) -> Selection:
def with_server_descriptions(self, server_descriptions: list[ServerDescription]) -> Selection:
return Selection(
self.topology_description, server_descriptions, self.common_wire_version, self.primary
)

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Type codes for MongoDB servers."""
from __future__ import annotations
from typing import NamedTuple

View File

@ -13,10 +13,11 @@
# permissions and limitations under the License.
"""Represent MongoClient's configuration."""
from __future__ import annotations
import threading
import traceback
from typing import Any, Collection, Dict, Optional, Tuple, Type, Union
from typing import Any, Collection, Optional, Type, Union
from bson.objectid import ObjectId
from pymongo import common, monitor, pool
@ -30,7 +31,7 @@ from pymongo.topology_description import TOPOLOGY_TYPE, _ServerSelector
class TopologySettings:
def __init__(
self,
seeds: Optional[Collection[Tuple[str, int]]] = None,
seeds: Optional[Collection[tuple[str, int]]] = None,
replica_set_name: Optional[str] = None,
pool_class: Optional[Type[Pool]] = None,
pool_options: Optional[PoolOptions] = None,
@ -56,7 +57,7 @@ class TopologySettings:
% (common.MIN_HEARTBEAT_INTERVAL * 1000,)
)
self._seeds: Collection[Tuple[str, int]] = seeds or [("localhost", 27017)]
self._seeds: Collection[tuple[str, int]] = seeds or [("localhost", 27017)]
self._replica_set_name = replica_set_name
self._pool_class: Type[Pool] = pool_class or pool.Pool
self._pool_options: PoolOptions = pool_options or PoolOptions()
@ -78,7 +79,7 @@ class TopologySettings:
self._stack = "".join(traceback.format_stack())
@property
def seeds(self) -> Collection[Tuple[str, int]]:
def seeds(self) -> Collection[tuple[str, int]]:
"""List of server addresses."""
return self._seeds
@ -155,6 +156,6 @@ class TopologySettings:
else:
return TOPOLOGY_TYPE.Unknown
def get_server_descriptions(self) -> Dict[Union[Tuple[str, int], Any], ServerDescription]:
def get_server_descriptions(self) -> dict[Union[tuple[str, int], Any], ServerDescription]:
"""Initial dict of (address, ServerDescription) for all seeds."""
return {address: ServerDescription(address) for address in self.seeds}

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Select / poll helper"""
from __future__ import annotations
import errno
import select

View File

@ -17,7 +17,7 @@ from __future__ import annotations
import ipaddress
import random
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Union
try:
from dns import resolver
@ -107,7 +107,7 @@ class _SrvResolver:
def _get_srv_response_and_hosts(
self, encapsulate_errors: bool
) -> Tuple[resolver.Answer, List[Tuple[str, Any]]]:
) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
results = self._resolve_uri(encapsulate_errors)
# Construct address tuples
@ -127,11 +127,11 @@ class _SrvResolver:
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
return results, nodes
def get_hosts(self) -> List[Tuple[str, Any]]:
def get_hosts(self) -> list[tuple[str, Any]]:
_, nodes = self._get_srv_response_and_hosts(True)
return nodes
def get_hosts_and_min_ttl(self) -> Tuple[List[Tuple[str, Any]], int]:
def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]:
results, nodes = self._get_srv_response_and_hosts(False)
rrset = results.rrset
ttl = rrset.ttl if rrset else 0

View File

@ -13,6 +13,7 @@
# permissions and limitations under the License.
"""A fake SSLContext implementation."""
from __future__ import annotations
import ssl as _ssl

View File

@ -13,6 +13,7 @@
# permissions and limitations under the License.
"""Support for SSL in PyMongo."""
from __future__ import annotations
from typing import Optional

View File

@ -22,18 +22,7 @@ import random
import time
import warnings
import weakref
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Set,
Tuple,
cast,
)
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
from pymongo import _csot, common, helpers, periodic_executor
from pymongo.client_session import _ServerSession, _ServerSessionPool
@ -146,7 +135,7 @@ class Topology:
self._closed = False
self._lock = _create_lock()
self._condition = self._settings.condition_class(self._lock)
self._servers: Dict[_Address, Server] = {}
self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None
self._session_pool = _ServerSessionPool()
@ -222,7 +211,7 @@ class Topology:
selector: Callable[[Selection], Selection],
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
) -> List[Server]:
) -> list[Server]:
"""Return a list of Servers matching selector, or time out.
:Parameters:
@ -255,7 +244,7 @@ class Topology:
selector: Callable[[Selection], Selection],
timeout: float,
address: Optional[_Address],
) -> List[ServerDescription]:
) -> list[ServerDescription]:
"""select_servers() guts. Hold the lock when calling this."""
now = time.monotonic()
end_time = now + timeout
@ -414,7 +403,7 @@ class Topology:
if self._opened and self._description.has_server(server_description.address):
self._process_change(server_description, reset_pool)
def _process_srv_update(self, seedlist: List[Tuple[str, Any]]) -> None:
def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
"""Process a new seedlist on an opened topology.
Hold the lock when calling this.
"""
@ -434,7 +423,7 @@ class Topology:
)
)
def on_srv_update(self, seedlist: List[Tuple[str, Any]]) -> None:
def on_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
"""Process a new list of nodes obtained from scanning SRV records."""
# We do no I/O holding the lock.
with self._lock:
@ -464,7 +453,7 @@ class Topology:
return writable_server_selector(self._new_selection())[0].address
def _get_replica_set_members(self, selector: Callable[[Selection], Selection]) -> Set[_Address]:
def _get_replica_set_members(self, selector: Callable[[Selection], Selection]) -> set[_Address]:
"""Return set of replica set member addresses."""
# Implemented here in Topology instead of MongoClient, so it can lock.
with self._lock:
@ -477,11 +466,11 @@ class Topology:
return {sd.address for sd in iter(selector(self._new_selection()))}
def get_secondaries(self) -> Set[_Address]:
def get_secondaries(self) -> set[_Address]:
"""Return set of secondary addresses."""
return self._get_replica_set_members(secondary_server_selector)
def get_arbiters(self) -> Set[_Address]:
def get_arbiters(self) -> set[_Address]:
"""Return set of arbiter addresses."""
return self._get_replica_set_members(arbiter_server_selector)
@ -514,7 +503,7 @@ class Topology:
self._request_check_all()
self._condition.wait(wait_time)
def data_bearing_servers(self) -> List[ServerDescription]:
def data_bearing_servers(self) -> list[ServerDescription]:
"""Return a list of all data-bearing servers.
This includes any server that might be selected for an operation.
@ -573,7 +562,7 @@ class Topology:
def description(self) -> TopologyDescription:
return self._description
def pop_all_sessions(self) -> List[_ServerSession]:
def pop_all_sessions(self) -> list[_ServerSession]:
"""Pop all session ids from the pool."""
with self._lock:
return self._session_pool.pop_all()
@ -890,7 +879,7 @@ class Topology:
msg = "CLOSED "
return f"<{self.__class__.__name__} {msg}{self._description!r}>"
def eq_props(self) -> Tuple[Tuple[_Address, ...], Optional[str], Optional[str], str]:
def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]:
"""The properties to use for MongoClient/Topology equality checks."""
ts = self._settings
return (tuple(sorted(ts.seeds)), ts.replica_set_name, ts.fqdn, ts.srv_service_name)

View File

@ -13,18 +13,17 @@
# permissions and limitations under the License.
"""Represent a deployment of MongoDB servers."""
from __future__ import annotations
from random import sample
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
MutableMapping,
NamedTuple,
Optional,
Tuple,
cast,
)
@ -52,7 +51,7 @@ class _TopologyType(NamedTuple):
TOPOLOGY_TYPE = _TopologyType(*range(6))
# Topologies compatible with SRV record polling.
SRV_POLLING_TOPOLOGIES: Tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
SRV_POLLING_TOPOLOGIES: tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]]
@ -62,7 +61,7 @@ class TopologyDescription:
def __init__(
self,
topology_type: int,
server_descriptions: Dict[_Address, ServerDescription],
server_descriptions: dict[_Address, ServerDescription],
replica_set_name: Optional[str],
max_set_version: Optional[int],
max_election_id: Optional[ObjectId],
@ -193,8 +192,8 @@ class TopologyDescription:
self._topology_settings,
)
def server_descriptions(self) -> Dict[_Address, ServerDescription]:
"""Dict of (address,
def server_descriptions(self) -> dict[_Address, ServerDescription]:
"""dict of (address,
:class:`~pymongo.server_description.ServerDescription`).
"""
return self._server_descriptions.copy()
@ -233,7 +232,7 @@ class TopologyDescription:
return self._ls_timeout_minutes
@property
def known_servers(self) -> List[ServerDescription]:
def known_servers(self) -> list[ServerDescription]:
"""List of Servers of types besides Unknown."""
return [s for s in self._server_descriptions.values() if s.is_server_type_known]
@ -243,7 +242,7 @@ class TopologyDescription:
return any(s for s in self._server_descriptions.values() if s.is_server_type_known)
@property
def readable_servers(self) -> List[ServerDescription]:
def readable_servers(self) -> list[ServerDescription]:
"""List of readable Servers."""
return [s for s in self._server_descriptions.values() if s.is_readable]
@ -264,7 +263,7 @@ class TopologyDescription:
def srv_max_hosts(self) -> int:
return self._topology_settings._srv_max_hosts
def _apply_local_threshold(self, selection: Optional[Selection]) -> List[ServerDescription]:
def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]:
if not selection:
return []
# Round trip time in seconds.
@ -281,7 +280,7 @@ class TopologyDescription:
selector: Any,
address: Optional[_Address] = None,
custom_selector: Optional[_ServerSelector] = None,
) -> List[ServerDescription]:
) -> list[ServerDescription]:
"""List of servers matching the provided selector(s).
:Parameters:
@ -486,7 +485,7 @@ def updated_topology_description(
def _updated_topology_description_srv_polling(
topology_description: TopologyDescription, seedlist: List[Tuple[str, Any]]
topology_description: TopologyDescription, seedlist: list[tuple[str, Any]]
) -> TopologyDescription:
"""Return an updated copy of a TopologyDescription.
@ -535,7 +534,7 @@ def _update_rs_from_primary(
server_description: ServerDescription,
max_set_version: Optional[int],
max_election_id: Optional[ObjectId],
) -> Tuple[int, Optional[str], Optional[int], Optional[ObjectId]]:
) -> tuple[int, Optional[str], Optional[int], Optional[ObjectId]]:
"""Update topology description from a primary's hello response.
Pass in a dict of ServerDescriptions, current replica set name, the
@ -555,8 +554,8 @@ def _update_rs_from_primary(
return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
if server_description.max_wire_version is None or server_description.max_wire_version < 17:
new_election_tuple: Tuple = (server_description.set_version, server_description.election_id)
max_election_tuple: Tuple = (max_set_version, max_election_id)
new_election_tuple: tuple = (server_description.set_version, server_description.election_id)
max_election_tuple: tuple = (max_set_version, max_election_id)
if None not in new_election_tuple:
if None not in max_election_tuple and new_election_tuple < max_election_tuple:
# Stale primary, set to type Unknown.
@ -635,7 +634,7 @@ def _update_rs_no_primary_from_member(
sds: MutableMapping[_Address, ServerDescription],
replica_set_name: Optional[str],
server_description: ServerDescription,
) -> Tuple[int, Optional[str]]:
) -> tuple[int, Optional[str]]:
"""RS without known primary. Update from a non-primary's response.
Pass in a dict of ServerDescriptions, current replica set name, and the

View File

@ -13,6 +13,8 @@
# limitations under the License.
"""Type aliases used by PyMongo"""
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,

View File

@ -22,13 +22,10 @@ import warnings
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Mapping,
MutableMapping,
Optional,
Sized,
Tuple,
Union,
cast,
)
@ -73,7 +70,7 @@ def _unquoted_percent(s: str) -> bool:
return False
def parse_userinfo(userinfo: str) -> Tuple[str, str]:
def parse_userinfo(userinfo: str) -> tuple[str, str]:
"""Validates the format of user information in a MongoDB URI.
Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
"]", "@") as per RFC 3986 must be escaped.
@ -100,7 +97,7 @@ def parse_userinfo(userinfo: str) -> Tuple[str, str]:
def parse_ipv6_literal_host(
entity: str, default_port: Optional[int]
) -> Tuple[str, Optional[Union[str, int]]]:
) -> tuple[str, Optional[Union[str, int]]]:
"""Validates an IPv6 literal host:port string.
Returns a 2-tuple of IPv6 literal followed by port where
@ -370,7 +367,7 @@ def split_options(
return options
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> List[_Address]:
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
"""Takes a string of the form host1[:port],host2[:port]... and
splits it into (host, port) tuples. If [:port] isn't present the
default_port is used.
@ -427,7 +424,7 @@ def parse_uri(
connect_timeout: Optional[float] = None,
srv_service_name: Optional[str] = None,
srv_max_hosts: Optional[int] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Parse and validate a MongoDB URI.
Returns a dict of the form::
@ -594,7 +591,7 @@ def parse_uri(
}
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> Dict[str, SSLContext]:
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
"""Parse KMS TLS connection options."""
if not kms_tls_options:
return {}

View File

@ -13,8 +13,9 @@
# limitations under the License.
"""Tools for working with write concerns."""
from __future__ import annotations
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
from pymongo.errors import ConfigurationError
@ -62,7 +63,7 @@ class WriteConcern:
j: Optional[bool] = None,
fsync: Optional[bool] = None,
) -> None:
self.__document: Dict[str, Any] = {}
self.__document: dict[str, Any] = {}
self.__acknowledged = True
if wtimeout is not None:
@ -102,7 +103,7 @@ class WriteConcern:
return self.__server_default
@property
def document(self) -> Dict[str, Any]:
def document(self) -> dict[str, Any]:
"""The document representation of this write concern.
.. note::

View File

@ -0,0 +1,41 @@
# Copyright 2023-Present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ensure that 'from __future__ import annotations' is used in all package files
"""
import glob
import os
import sys
pattern = "from __future__ import annotations"
missing = []
for dirname in ["pymongo", "bson", "gridfs"]:
for path in glob.glob(f"{dirname}/*.py"):
if os.path.basename(path) in ["_version.py", "errors.py"]:
continue
found = False
with open(path) as fid:
for line in fid.readlines():
if line.strip() == pattern:
found = True
break
if not found:
missing.append(path)
if missing:
print(f"Missing '{pattern}' import in:")
for item in missing:
print(item)
sys.exit(1)

View File

@ -117,6 +117,7 @@ deps =
{[testenv:typecheck-pyright]deps}
allowlist_externals=echo
commands =
python tools/ensure_future_annotations_import.py
{[testenv:typecheck-mypy]commands}
{[testenv:typecheck-pyright]commands}
{[testenv:typecheck-pyright-strict]commands}