diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 3a8ce635c..13cbd282b 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -84,15 +84,9 @@ jobs: - name: Install dependencies run: | pip install -q tox - - name: Run mypy + - name: Run typecheck run: | - tox -m typecheck-mypy - - name: Run pyright - run: | - tox -m typecheck-pyright - - name: Run pyright strict - run: | - tox -m typecheck-pyright-strict + tox -m typecheck docs: name: Docs Checks diff --git a/gridfs/__init__.py b/gridfs/__init__.py index 9a4cda552..549334899 100644 --- a/gridfs/__init__.py +++ b/gridfs/__init__.py @@ -19,9 +19,10 @@ The :mod:`gridfs` package is an implementation of GridFS on top of .. seealso:: The MongoDB documentation on `gridfs `_. """ +from __future__ import annotations from collections import abc -from typing import Any, List, Mapping, Optional, cast +from typing import Any, Mapping, Optional, cast from bson.objectid import ObjectId from gridfs.errors import NoFile @@ -170,7 +171,7 @@ class GridFS: filename: Optional[str] = None, version: Optional[int] = -1, session: Optional[ClientSession] = None, - **kwargs: Any + **kwargs: Any, ) -> GridOut: """Get a file from GridFS by ``"filename"`` or metadata fields. @@ -275,7 +276,7 @@ class GridFS: self.__files.delete_one({"_id": file_id}, session=session) self.__chunks.delete_many({"files_id": file_id}, session=session) - def list(self, session: Optional[ClientSession] = None) -> List[str]: + def list(self, session: Optional[ClientSession] = None) -> list[str]: """List the names of all files stored in this instance of :class:`GridFS`. @@ -301,7 +302,7 @@ class GridFS: filter: Optional[Any] = None, session: Optional[ClientSession] = None, *args: Any, - **kwargs: Any + **kwargs: Any, ) -> Optional[GridOut]: """Get a single file from gridfs. @@ -400,7 +401,7 @@ class GridFS: self, document_or_id: Optional[Any] = None, session: Optional[ClientSession] = None, - **kwargs: Any + **kwargs: Any, ) -> bool: """Check if a file exists in this instance of :class:`GridFS`. diff --git a/gridfs/grid_file.py b/gridfs/grid_file.py index 7a09f35ef..b43303a0a 100644 --- a/gridfs/grid_file.py +++ b/gridfs/grid_file.py @@ -13,11 +13,13 @@ # limitations under the License. """Tools for representing files stored in GridFS.""" +from __future__ import annotations + import datetime import io import math import os -from typing import Any, Iterable, List, Mapping, NoReturn, Optional +from typing import Any, Iterable, Mapping, NoReturn, Optional from bson.binary import Binary from bson.int64 import Int64 @@ -480,7 +482,7 @@ class GridOut(io.IOBase): upload_date: datetime.datetime = _grid_out_property( "uploadDate", "Date that this file was first uploaded." ) - aliases: Optional[List[str]] = _grid_out_property("aliases", "List of aliases for this file.") + aliases: Optional[list[str]] = _grid_out_property("aliases", "List of aliases for this file.") metadata: Optional[Mapping[str, Any]] = _grid_out_property( "metadata", "Metadata attached to this file." ) diff --git a/pymongo/__init__.py b/pymongo/__init__.py index a3bdb4c16..f568c90f5 100644 --- a/pymongo/__init__.py +++ b/pymongo/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. """Python driver for MongoDB.""" +from __future__ import annotations from typing import ContextManager, Optional diff --git a/pymongo/_csot.py b/pymongo/_csot.py index 9ad477a24..d8f06eedc 100644 --- a/pymongo/_csot.py +++ b/pymongo/_csot.py @@ -20,7 +20,7 @@ import functools import time from collections import deque from contextvars import ContextVar, Token -from typing import Any, Callable, Deque, MutableMapping, Optional, Tuple, TypeVar, cast +from typing import Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast from pymongo.write_concern import WriteConcern @@ -72,7 +72,7 @@ class _TimeoutContext: def __init__(self, timeout: Optional[float]): self._timeout = timeout - self._tokens: Optional[Tuple[Token, Token, Token]] = None + self._tokens: Optional[tuple[Token, Token, Token]] = None def __enter__(self) -> _TimeoutContext: timeout_token = TIMEOUT.set(self._timeout) diff --git a/pymongo/bulk.py b/pymongo/bulk.py index 04f218b0e..1ed7a117e 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -24,13 +24,10 @@ from itertools import islice from typing import ( TYPE_CHECKING, Any, - Dict, Iterator, - List, Mapping, NoReturn, Optional, - Tuple, Type, Union, ) @@ -76,7 +73,7 @@ _BAD_VALUE: int = 2 _UNKNOWN_ERROR: int = 8 _WRITE_CONCERN_ERROR: int = 64 -_COMMANDS: Tuple[str, str, str] = ("insert", "update", "delete") +_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete") class _Run: @@ -85,8 +82,8 @@ class _Run: def __init__(self, op_type: int) -> None: """Initialize a new Run object.""" self.op_type: int = op_type - self.index_map: List[int] = [] - self.ops: List[Any] = [] + self.index_map: list[int] = [] + self.ops: list[Any] = [] self.idx_offset: int = 0 def index(self, idx: int) -> int: @@ -182,7 +179,7 @@ class _Bulk: common.validate_is_document_type("let", self.let) self.comment: Optional[str] = comment self.ordered = ordered - self.ops: List[Tuple[int, Mapping[str, Any]]] = [] + self.ops: list[tuple[int, Mapping[str, Any]]] = [] self.executed = False self.bypass_doc_val = bypass_document_validation self.uses_collation = False @@ -219,12 +216,12 @@ class _Bulk: multi: bool = False, upsert: bool = False, collation: Optional[Mapping[str, Any]] = None, - array_filters: Optional[List[Mapping[str, Any]]] = None, + array_filters: Optional[list[Mapping[str, Any]]] = None, hint: Union[str, SON[str, Any], None] = None, ) -> None: """Create an update document and add it to the list of ops.""" validate_ok_for_update(update) - cmd: Dict[str, Any] = dict( + cmd: dict[str, Any] = dict( [("q", selector), ("u", update), ("multi", multi), ("upsert", upsert)] ) if collation is not None: @@ -414,7 +411,7 @@ class _Bulk: generator: Iterator[Any], write_concern: WriteConcern, session: Optional[ClientSession], - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Execute using write commands.""" # nModified is only reported for write commands, not legacy ops. full_result = { diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index ed3031ef5..ba8744c5c 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -16,17 +16,7 @@ from __future__ import annotations import copy -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Generic, - List, - Mapping, - Optional, - Type, - Union, -) +from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union from bson import _bson_to_dict from bson.raw_bson import RawBSONDocument @@ -173,9 +163,9 @@ class ChangeStream(Generic[_DocumentType]): """ raise NotImplementedError - def _change_stream_options(self) -> Dict[str, Any]: + def _change_stream_options(self) -> dict[str, Any]: """Return the options dict for the $changeStream pipeline stage.""" - options: Dict[str, Any] = {} + options: dict[str, Any] = {} if self._full_document is not None: options["fullDocument"] = self._full_document @@ -197,7 +187,7 @@ class ChangeStream(Generic[_DocumentType]): return options - def _command_options(self) -> Dict[str, Any]: + def _command_options(self) -> dict[str, Any]: """Return the options dict for the aggregation command.""" options = {} if self._max_await_time_ms is not None: @@ -206,7 +196,7 @@ class ChangeStream(Generic[_DocumentType]): options["batchSize"] = self._batch_size return options - def _aggregation_pipeline(self) -> List[Dict[str, Any]]: + def _aggregation_pipeline(self) -> list[dict[str, Any]]: """Return the full aggregation pipeline for this ChangeStream.""" options = self._change_stream_options() full_pipeline: list = [{"$changeStream": options}] @@ -491,7 +481,7 @@ class ClusterChangeStream(DatabaseChangeStream, Generic[_DocumentType]): .. versionadded:: 3.7 """ - def _change_stream_options(self) -> Dict[str, Any]: + def _change_stream_options(self) -> dict[str, Any]: options = super()._change_stream_options() options["allChangesForCluster"] = True return options diff --git a/pymongo/client_options.py b/pymongo/client_options.py index a83216e9d..d1342652a 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -15,7 +15,7 @@ """Tools to parse mongo client options.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Tuple, cast +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast from bson.codec_options import _parse_codec_options from pymongo import common @@ -80,7 +80,7 @@ def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern: return ReadConcern(concern) -def _parse_ssl_options(options: Mapping[str, Any]) -> Tuple[Optional[SSLContext], bool]: +def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]: """Parse ssl options.""" use_tls = options.get("tls") if use_tls is not None: @@ -309,7 +309,7 @@ class ClientOptions: return self.__load_balanced @property - def event_listeners(self) -> List[_EventListeners]: + def event_listeners(self) -> list[_EventListeners]: """The event listeners registered for this client. See :mod:`~pymongo.monitoring` for details. diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 2ee13e2e9..fa9c62b59 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -144,8 +144,6 @@ from typing import ( Any, Callable, ContextManager, - Dict, - List, Mapping, MutableMapping, NoReturn, @@ -832,7 +830,7 @@ class ClientSession: self._transaction.state = _TxnState.ABORTED self._unpin() - def _finish_transaction_with_retry(self, command_name: str) -> Dict[str, Any]: + def _finish_transaction_with_retry(self, command_name: str) -> dict[str, Any]: """Run commit or abort with one retry after any retryable error. :Parameters: @@ -841,12 +839,12 @@ class ClientSession: def func( session: Optional[ClientSession], conn: Connection, retryable: bool - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return self._finish_transaction(conn, command_name) return self._client._retry_internal(func, self, None, retryable=True) - def _finish_transaction(self, conn: Connection, command_name: str) -> Dict[str, Any]: + def _finish_transaction(self, conn: Connection, command_name: str) -> dict[str, Any]: self._transaction.attempt += 1 opts = self._transaction.opts assert opts @@ -1102,7 +1100,7 @@ class _ServerSessionPool(collections.deque): self.generation += 1 self.clear() - def pop_all(self) -> List[_ServerSession]: + def pop_all(self) -> list[_ServerSession]: ids = [] while self: ids.append(self.pop().session_id) diff --git a/pymongo/collation.py b/pymongo/collation.py index bada2d941..e940868e5 100644 --- a/pymongo/collation.py +++ b/pymongo/collation.py @@ -16,7 +16,9 @@ .. _collations: https://www.mongodb.com/docs/manual/reference/collation/ """ -from typing import Any, Dict, Mapping, Optional, Union +from __future__ import annotations + +from typing import Any, Mapping, Optional, Union from pymongo import common @@ -166,7 +168,7 @@ class Collation: **kwargs: Any, ) -> None: locale = common.validate_string("locale", locale) - self.__document: Dict[str, Any] = {"locale": locale} + self.__document: dict[str, Any] = {"locale": locale} if caseLevel is not None: self.__document["caseLevel"] = common.validate_boolean("caseLevel", caseLevel) if caseFirst is not None: @@ -190,7 +192,7 @@ class Collation: self.__document.update(kwargs) @property - def document(self) -> Dict[str, Any]: + def document(self) -> dict[str, Any]: """The document representation of this collation. .. note:: @@ -214,7 +216,7 @@ class Collation: def validate_collation_or_none( value: Optional[Union[Mapping[str, Any], Collation]] -) -> Optional[Dict[str, Any]]: +) -> Optional[dict[str, Any]]: if value is None: return None if isinstance(value, Collation): diff --git a/pymongo/collection.py b/pymongo/collection.py index 772e43e95..16b4f9b4b 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -24,13 +24,11 @@ from typing import ( Generic, Iterable, Iterator, - List, Mapping, MutableMapping, NoReturn, Optional, Sequence, - Tuple, Type, TypeVar, Union, @@ -258,7 +256,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _conn_for_reads( self, session: ClientSession - ) -> ContextManager[Tuple[Connection, _ServerMode]]: + ) -> ContextManager[tuple[Connection, _ServerMode]]: return self.__database.client._conn_for_reads(self._read_preference_for(session), session) def _conn_for_writes(self, session: Optional[ClientSession]) -> ContextManager[Connection]: @@ -739,9 +737,9 @@ class Collection(common.BaseObject, Generic[_DocumentType]): or not documents ): raise TypeError("documents must be a non-empty list") - inserted_ids: List[ObjectId] = [] + inserted_ids: list[ObjectId] = [] - def gen() -> Iterator[Tuple[int, Mapping[str, Any]]]: + def gen() -> Iterator[tuple[int, Mapping[str, Any]]]: """A generator that validates documents and handles _ids.""" for document in documents: common.validate_is_document_type("document", document) @@ -1930,7 +1928,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): session: Optional[ClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Create one or more indexes on this collection. >>> from pymongo import IndexModel, ASCENDING, DESCENDING @@ -1975,7 +1973,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): @_csot.apply def __create_indexes( self, indexes: Sequence[IndexModel], session: Optional[ClientSession], **kwargs: Any - ) -> List[str]: + ) -> list[str]: """Internal createIndexes helper. :Parameters: @@ -2438,11 +2436,11 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def create_search_indexes( self, - models: List[SearchIndexModel], + models: list[SearchIndexModel], session: Optional[ClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Create multiple search indexes for the current collection. :Parameters: @@ -2990,7 +2988,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): session: Optional[ClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, - ) -> List: + ) -> list: """Get a list of distinct values for `key` among all documents in this collection. @@ -3043,7 +3041,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): server: Server, conn: Connection, read_preference: Optional[_ServerMode], - ) -> List: + ) -> list: return self._command( conn, cmd, diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index 777b88b08..ea58e46f7 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -21,7 +21,6 @@ from typing import ( Any, Generic, Iterator, - List, Mapping, NoReturn, Optional, @@ -389,7 +388,7 @@ class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]): codec_options: CodecOptions, user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, - ) -> List[Mapping[str, Any]]: + ) -> list[Mapping[str, Any]]: raw_response = response.raw_response(cursor_id, user_fields=user_fields) if not legacy_response: # OP_MSG returns firstBatch/nextBatch documents as a BSON array diff --git a/pymongo/common.py b/pymongo/common.py index a791a5e44..f5c4da71e 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -24,15 +24,12 @@ from typing import ( TYPE_CHECKING, Any, Callable, - Dict, Iterator, - List, Mapping, MutableMapping, NoReturn, Optional, Sequence, - Tuple, Type, Union, overload, @@ -140,7 +137,7 @@ _MAX_END_SESSIONS = 10000 SRV_SERVICE_NAME = "mongodb" -def partition_node(node: str) -> Tuple[str, int]: +def partition_node(node: str) -> tuple[str, int]: """Split a host:port string into (host, int(port)) pair.""" host = node port = 27017 @@ -152,7 +149,7 @@ def partition_node(node: str) -> Tuple[str, int]: return host, port -def clean_node(node: str) -> Tuple[str, int]: +def clean_node(node: str) -> tuple[str, int]: """Split and normalize a node name from a hello response.""" host, port = partition_node(node) @@ -394,12 +391,12 @@ def validate_uuid_representation(dummy: Any, value: Any) -> int: ) -def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]]: +def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]]: """Parse readPreferenceTags if passed as a client kwarg.""" if not isinstance(value, list): value = [value] - tag_sets: List = [] + tag_sets: list = [] for tag_set in value: if tag_set == "": tag_sets.append({}) @@ -426,9 +423,9 @@ _MECHANISM_PROPS = frozenset( ) -def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Union[bool, str]]: +def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Union[bool, str]]: """Validate authMechanismProperties.""" - props: Dict[str, Any] = {} + props: dict[str, Any] = {} if not isinstance(value, str): if not isinstance(value, dict): raise ValueError("Auth mechanism properties must be given as a string or a dictionary") @@ -515,14 +512,14 @@ def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]: return value -def validate_list(option: str, value: Any) -> List: +def validate_list(option: str, value: Any) -> list: """Validates that 'value' is a list.""" if not isinstance(value, list): raise TypeError(f"{option} must be a list") return value -def validate_list_or_none(option: Any, value: Any) -> Optional[List]: +def validate_list_or_none(option: Any, value: Any) -> Optional[list]: """Validates that 'value' is a list or None.""" if value is None: return value @@ -671,7 +668,7 @@ def validate_datetime_conversion(option: Any, value: Any) -> Optional[DatetimeCo # Dictionary where keys are the names of public URI options, and values # are lists of aliases for that option. -URI_OPTIONS_ALIAS_MAP: Dict[str, List[str]] = { +URI_OPTIONS_ALIAS_MAP: dict[str, list[str]] = { "tls": ["ssl"], } @@ -679,7 +676,7 @@ URI_OPTIONS_ALIAS_MAP: Dict[str, List[str]] = { # are functions that validate user-input values for that option. If an option # alias uses a different validator than its public counterpart, it should be # included here as a key, value pair. -URI_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = { +URI_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = { "appname": validate_appname_or_none, "authmechanism": validate_auth_mechanism, "authmechanismproperties": validate_auth_mechanism_properties, @@ -721,7 +718,7 @@ URI_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = { # Dictionary where keys are the names of URI options specific to pymongo, # and values are functions that validate user-input values for those options. -NONSPEC_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = { +NONSPEC_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = { "connect": validate_boolean_or_string, "driver": validate_driver_or_none, "server_api": validate_server_api_or_none, @@ -739,7 +736,7 @@ NONSPEC_OPTIONS_VALIDATOR_MAP: Dict[str, Callable[[Any, Any], Any]] = { # Dictionary where keys are the names of keyword-only options for the # MongoClient constructor, and values are functions that validate user-input # values for those options. -KW_VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = { +KW_VALIDATORS: dict[str, Callable[[Any, Any], Any]] = { "document_class": validate_document_class, "type_registry": validate_type_registry, "read_preference": validate_read_preference, @@ -756,14 +753,14 @@ KW_VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = { # internally-used names of that URI option. Options with only one name # variant need not be included here. Options whose public and internal # names are the same need not be included here. -INTERNAL_URI_OPTION_NAME_MAP: Dict[str, str] = { +INTERNAL_URI_OPTION_NAME_MAP: dict[str, str] = { "ssl": "tls", } # Map from deprecated URI option names to a tuple indicating the method of # their deprecation and any additional information that may be needed to # construct the warning message. -URI_OPTIONS_DEPRECATION_MAP: Dict[str, Tuple[str, str]] = { +URI_OPTIONS_DEPRECATION_MAP: dict[str, tuple[str, str]] = { # format: : (, ), # Supported values: # - 'renamed': should be the new option name. Note that case is @@ -782,11 +779,11 @@ for optname, aliases in URI_OPTIONS_ALIAS_MAP.items(): URI_OPTIONS_VALIDATOR_MAP[alias] = URI_OPTIONS_VALIDATOR_MAP[optname] # Map containing all URI option and keyword argument validators. -VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = URI_OPTIONS_VALIDATOR_MAP.copy() +VALIDATORS: dict[str, Callable[[Any, Any], Any]] = URI_OPTIONS_VALIDATOR_MAP.copy() VALIDATORS.update(KW_VALIDATORS) # List of timeout-related options. -TIMEOUT_OPTIONS: List[str] = [ +TIMEOUT_OPTIONS: list[str] = [ "connecttimeoutms", "heartbeatfrequencyms", "maxidletimems", @@ -800,7 +797,7 @@ TIMEOUT_OPTIONS: List[str] = [ _AUTH_OPTIONS = frozenset(["authmechanismproperties"]) -def validate_auth_option(option: str, value: Any) -> Tuple[str, Any]: +def validate_auth_option(option: str, value: Any) -> tuple[str, Any]: """Validate optional authentication parameters.""" lower, value = validate(option, value) if lower not in _AUTH_OPTIONS: @@ -808,7 +805,7 @@ def validate_auth_option(option: str, value: Any) -> Tuple[str, Any]: return option, value -def validate(option: str, value: Any) -> Tuple[str, Any]: +def validate(option: str, value: Any) -> tuple[str, Any]: """Generic validation function.""" lower = option.lower() validator = VALIDATORS.get(lower, raise_config_error) @@ -962,8 +959,8 @@ class BaseObject: class _CaseInsensitiveDictionary(MutableMapping[str, Any]): def __init__(self, *args: Any, **kwargs: Any): - self.__casedkeys: Dict[str, Any] = {} - self.__data: Dict[str, Any] = {} + self.__casedkeys: dict[str, Any] = {} + self.__data: dict[str, Any] = {} self.update(dict(*args, **kwargs)) def __contains__(self, key: str) -> bool: # type: ignore[override] @@ -1010,7 +1007,7 @@ class _CaseInsensitiveDictionary(MutableMapping[str, Any]): self.__casedkeys.pop(lc_key, None) return self.__data.pop(lc_key, *args, **kwargs) - def popitem(self) -> Tuple[str, Any]: + def popitem(self) -> tuple[str, Any]: lc_key, cased_key = self.__casedkeys.popitem() value = self.__data.pop(lc_key) return cased_key, value diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index 27fc3cdf2..a29bf826d 100644 --- a/pymongo/compression_support.py +++ b/pymongo/compression_support.py @@ -14,7 +14,7 @@ from __future__ import annotations import warnings -from typing import Any, Iterable, List, Optional, Union +from typing import Any, Iterable, Optional, Union try: import snappy @@ -47,7 +47,7 @@ _NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD} _NO_COMPRESSION.update(_SENSITIVE_COMMANDS) -def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> List[str]: +def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[str]: try: # `value` is string. compressors = value.split(",") # type: ignore[union-attr] @@ -91,12 +91,12 @@ def validate_zlib_compression_level(option: str, value: Any) -> int: class CompressionSettings: - def __init__(self, compressors: List[str], zlib_compression_level: int): + def __init__(self, compressors: list[str], zlib_compression_level: int): self.compressors = compressors self.zlib_compression_level = zlib_compression_level def get_compression_context( - self, compressors: Optional[List[str]] + self, compressors: Optional[list[str]] ) -> Union[SnappyContext, ZlibContext, ZstdContext, None]: if compressors: chosen = compressors[0] diff --git a/pymongo/cursor.py b/pymongo/cursor.py index 2bf420ac1..be4f6e2c1 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -21,7 +21,6 @@ from collections import deque from typing import ( TYPE_CHECKING, Any, - Dict, Generic, Iterable, List, @@ -444,7 +443,7 @@ class Cursor(Generic[_DocumentType]): def __query_spec(self) -> Mapping[str, Any]: """Get the spec to use for a query.""" - operators: Dict[str, Any] = {} + operators: dict[str, Any] = {} if self.__ordering: operators["$orderby"] = self.__ordering if self.__explain: @@ -884,7 +883,7 @@ class Cursor(Generic[_DocumentType]): self.__ordering = helpers._index_document(keys) return self - def distinct(self, key: str) -> List: + def distinct(self, key: str) -> list: """Get a list of distinct values for `key` among all documents in the result set of this query. @@ -901,7 +900,7 @@ class Cursor(Generic[_DocumentType]): .. seealso:: :meth:`pymongo.collection.Collection.distinct` """ - options: Dict[str, Any] = {} + options: dict[str, Any] = {} if self.__spec: options["query"] = self.__spec if self.__max_time_ms is not None: @@ -1017,11 +1016,11 @@ class Cursor(Generic[_DocumentType]): # Avoid overwriting a filter argument that was given by the user # when updating the spec. - spec: Dict[str, Any] + spec: dict[str, Any] if self.__has_filter: spec = dict(self.__spec) else: - spec = cast(Dict, self.__spec) + spec = cast(dict, self.__spec) spec["$where"] = code self.__spec = spec return self @@ -1234,7 +1233,7 @@ class Cursor(Generic[_DocumentType]): return self.__id @property - def address(self) -> Optional[Tuple[str, Any]]: + def address(self) -> Optional[tuple[str, Any]]: """The (host, port) of the server used, or None. .. versionchanged:: 3.0 @@ -1287,25 +1286,25 @@ class Cursor(Generic[_DocumentType]): return self._clone(deepcopy=True) @overload - def _deepcopy(self, x: Iterable, memo: Optional[Dict[int, Union[List, Dict]]] = None) -> List: + def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list: ... @overload def _deepcopy( - self, x: SupportsItems, memo: Optional[Dict[int, Union[List, Dict]]] = None - ) -> Dict: + self, x: SupportsItems, memo: Optional[dict[int, Union[list, dict]]] = None + ) -> dict: ... def _deepcopy( - self, x: Union[Iterable, SupportsItems], memo: Optional[Dict[int, Union[List, Dict]]] = None - ) -> Union[List, Dict]: + self, x: Union[Iterable, SupportsItems], memo: Optional[dict[int, Union[list, dict]]] = None + ) -> Union[list, dict]: """Deepcopy helper for the data dictionary or list. Regular expressions cannot be deep copied but as they are immutable we don't have to copy them when cloning. """ - y: Union[List, Dict] - iterator: Iterable[Tuple[Any, Any]] + y: Union[list, dict] + iterator: Iterable[tuple[Any, Any]] if not hasattr(x, "items"): y, is_list, iterator = [], True, enumerate(x) else: @@ -1356,7 +1355,7 @@ class RawBatchCursor(Cursor, Generic[_DocumentType]): codec_options: CodecOptions[Mapping[str, Any]], user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, - ) -> List[_DocumentOut]: + ) -> list[_DocumentOut]: raw_response = response.raw_response(cursor_id, user_fields=user_fields) if not legacy_response: # OP_MSG returns firstBatch/nextBatch documents as a BSON array diff --git a/pymongo/daemon.py b/pymongo/daemon.py index 643eb58b6..46c5b3984 100644 --- a/pymongo/daemon.py +++ b/pymongo/daemon.py @@ -18,6 +18,7 @@ PyMongo only attempts to spawn the mongocryptd daemon process when automatic client-side field level encryption is enabled. See :ref:`automatic-client-side-encryption` for more info. """ +from __future__ import annotations import os import subprocess diff --git a/pymongo/database.py b/pymongo/database.py index 133061424..ab11f4bdf 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -19,9 +19,7 @@ from copy import deepcopy from typing import ( TYPE_CHECKING, Any, - Dict, Generic, - List, Mapping, MutableMapping, NoReturn, @@ -695,12 +693,12 @@ class Database(common.BaseObject, Generic[_DocumentType]): check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_preference: _ServerMode = ReadPreference.PRIMARY, - codec_options: CodecOptions[Dict[str, Any]] = DEFAULT_CODEC_OPTIONS, + codec_options: CodecOptions[dict[str, Any]] = DEFAULT_CODEC_OPTIONS, write_concern: Optional[WriteConcern] = None, parse_write_concern_error: bool = False, session: Optional[ClientSession] = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: ... @overload @@ -729,13 +727,13 @@ class Database(common.BaseObject, Generic[_DocumentType]): allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_preference: _ServerMode = ReadPreference.PRIMARY, codec_options: Union[ - CodecOptions[Dict[str, Any]], CodecOptions[_CodecDocumentType] + CodecOptions[dict[str, Any]], CodecOptions[_CodecDocumentType] ] = DEFAULT_CODEC_OPTIONS, write_concern: Optional[WriteConcern] = None, parse_write_concern_error: bool = False, session: Optional[ClientSession] = None, **kwargs: Any, - ) -> Union[Dict[str, Any], _CodecDocumentType]: + ) -> Union[dict[str, Any], _CodecDocumentType]: """Internal command helper.""" if isinstance(command, str): command = SON([(command, value)]) @@ -767,7 +765,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): session: Optional[ClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: ... @overload @@ -797,7 +795,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): session: Optional[ClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, - ) -> Union[Dict[str, Any], _CodecDocumentType]: + ) -> Union[dict[str, Any], _CodecDocumentType]: """Issue a MongoDB command. Send command `command` to the database and return the @@ -1008,7 +1006,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): self, command: Union[str, MutableMapping[str, Any]], session: Optional[ClientSession] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Same as command but used for retryable read commands.""" read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY @@ -1017,7 +1015,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): server: Server, conn: Connection, read_preference: _ServerMode, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return self._command( conn, command, @@ -1106,7 +1104,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): filter: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Get a list of all the collection names in this database. For example, to list all non-system collections:: @@ -1150,7 +1148,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): def _drop_helper( self, name: str, session: Optional[ClientSession] = None, comment: Optional[Any] = None - ) -> Dict[str, Any]: + ) -> dict[str, Any]: command = SON([("drop", name)]) if comment is not None: command["comment"] = comment @@ -1172,7 +1170,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): session: Optional[ClientSession] = None, comment: Optional[Any] = None, encrypted_fields: Optional[Mapping[str, Any]] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Drop a collection. :Parameters: @@ -1252,7 +1250,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): session: Optional[ClientSession] = None, background: Optional[bool] = None, comment: Optional[Any] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Validate a collection. Returns a dict of validation info. Raises CollectionInvalid if diff --git a/pymongo/driver_info.py b/pymongo/driver_info.py index 86ddfcfb3..7252d8f64 100644 --- a/pymongo/driver_info.py +++ b/pymongo/driver_info.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """Advanced options for MongoDB drivers implemented on top of PyMongo.""" +from __future__ import annotations from collections import namedtuple from typing import Optional diff --git a/pymongo/encryption.py b/pymongo/encryption.py index c8a33db7d..dbde9c367 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -29,7 +29,6 @@ from typing import ( MutableMapping, Optional, Sequence, - Tuple, ) try: @@ -594,7 +593,7 @@ class ClientEncryption(Generic[_DocumentType]): kms_provider: Optional[str] = None, master_key: Optional[Mapping[str, Any]] = None, **kwargs: Any, - ) -> Tuple[Collection[_DocumentType], Mapping[str, Any]]: + ) -> tuple[Collection[_DocumentType], Mapping[str, Any]]: """Create a collection with encryptedFields. .. warning:: diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index b4ffd92a8..33c6fe240 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -15,7 +15,7 @@ """Support for automatic client-side field level encryption.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional +from typing import TYPE_CHECKING, Any, Mapping, Optional try: import pymongocrypt # noqa: F401 @@ -45,7 +45,7 @@ class AutoEncryptionOpts: mongocryptd_uri: str = "mongodb://localhost:27020", mongocryptd_bypass_spawn: bool = False, mongocryptd_spawn_path: str = "mongocryptd", - mongocryptd_spawn_args: Optional[List[str]] = None, + mongocryptd_spawn_args: Optional[list[str]] = None, kms_tls_options: Optional[Mapping[str, Any]] = None, crypt_shared_lib_path: Optional[str] = None, crypt_shared_lib_required: bool = False, @@ -245,7 +245,7 @@ class RangeOpts: self.precision = precision @property - def document(self) -> Dict[str, Any]: + def document(self) -> dict[str, Any]: doc = {} for k, v in [ ("sparsity", int64.Int64(self.sparsity)), diff --git a/pymongo/errors.py b/pymongo/errors.py index c2cc6bbb6..c1e90b164 100644 --- a/pymongo/errors.py +++ b/pymongo/errors.py @@ -15,17 +15,7 @@ """Exceptions raised by PyMongo.""" from __future__ import annotations -from typing import ( - TYPE_CHECKING, - Any, - Iterable, - List, - Mapping, - Optional, - Sequence, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional, Sequence, Union from bson.errors import InvalidDocument @@ -138,7 +128,7 @@ class NetworkTimeout(AutoReconnect): return True -def _format_detailed_error(message: str, details: Optional[Union[Mapping[str, Any], List]]) -> str: +def _format_detailed_error(message: str, details: Optional[Union[Mapping[str, Any], list]]) -> str: if details is not None: message = f"{message}, full error: {details}" return message @@ -161,7 +151,7 @@ class NotPrimaryError(AutoReconnect): """ def __init__( - self, message: str = "", errors: Optional[Union[Mapping[str, Any], List]] = None + self, message: str = "", errors: Optional[Union[Mapping[str, Any], list]] = None ) -> None: super().__init__(_format_detailed_error(message, errors), errors=errors) @@ -306,7 +296,7 @@ class BulkWriteError(OperationFailure): def __init__(self, results: _DocumentOut) -> None: super().__init__("batch op errors occurred", 65, results) - def __reduce__(self) -> Tuple[Any, Any]: + def __reduce__(self) -> tuple[Any, Any]: return self.__class__, (self.details,) @property diff --git a/pymongo/event_loggers.py b/pymongo/event_loggers.py index 70c386ab0..287db3fc4 100644 --- a/pymongo/event_loggers.py +++ b/pymongo/event_loggers.py @@ -26,6 +26,8 @@ or ``MongoClient(event_listeners=[CommandLogger()])`` """ +from __future__ import annotations + import logging from pymongo import monitoring diff --git a/pymongo/hello.py b/pymongo/hello.py index 1715beb5c..d38c285ab 100644 --- a/pymongo/hello.py +++ b/pymongo/hello.py @@ -13,11 +13,12 @@ # limitations under the License. """Helpers for the 'hello' and legacy hello commands.""" +from __future__ import annotations import copy import datetime import itertools -from typing import Any, Generic, List, Mapping, Optional, Set, Tuple +from typing import Any, Generic, Mapping, Optional from bson.objectid import ObjectId from pymongo import common @@ -95,7 +96,7 @@ class Hello(Generic[_DocumentType]): return self._server_type @property - def all_hosts(self) -> Set[Tuple[str, int]]: + def all_hosts(self) -> set[tuple[str, int]]: """List of hosts, passives, and arbiters known to this server.""" return set( map( @@ -114,7 +115,7 @@ class Hello(Generic[_DocumentType]): return self._doc.get("tags", {}) @property - def primary(self) -> Optional[Tuple[str, int]]: + def primary(self) -> Optional[tuple[str, int]]: """This server's opinion about who the primary is, or None.""" if self._doc.get("primary"): return common.partition_node(self._doc["primary"]) @@ -171,7 +172,7 @@ class Hello(Generic[_DocumentType]): return self._is_readable @property - def me(self) -> Optional[Tuple[str, int]]: + def me(self) -> Optional[tuple[str, int]]: me = self._doc.get("me") if me: return common.clean_node(me) @@ -182,11 +183,11 @@ class Hello(Generic[_DocumentType]): return self._doc.get("lastWrite", {}).get("lastWriteDate") @property - def compressors(self) -> Optional[List[str]]: + def compressors(self) -> Optional[list[str]]: return self._doc.get("compression") @property - def sasl_supported_mechs(self) -> List[str]: + def sasl_supported_mechs(self) -> list[str]: """Supported authentication mechanisms for the current user. For example:: diff --git a/pymongo/helpers.py b/pymongo/helpers.py index afab67e4e..e8c09aaa8 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -24,12 +24,10 @@ from typing import ( Callable, Container, Iterable, - List, Mapping, NoReturn, Optional, Sequence, - Tuple, TypeVar, Union, cast, @@ -100,7 +98,7 @@ def _gen_index_name(keys: _IndexList) -> str: def _index_list( key_or_list: _Hint, direction: Optional[Union[int, str]] = None -) -> Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]]: +) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]: """Helper to generate a list of (key, direction) pairs. Takes such a list, or a single key, or a single key and direction. @@ -116,7 +114,7 @@ def _index_list( return list(key_or_list) elif not isinstance(key_or_list, (list, tuple)): raise TypeError("if no direction is specified, key_or_list must be an instance of list") - values: List[Tuple[str, int]] = [] + values: list[tuple[str, int]] = [] for item in key_or_list: if isinstance(item, str): item = (item, ASCENDING) @@ -223,7 +221,7 @@ def _check_command_response( raise OperationFailure(errmsg, code, response, max_wire_version) -def _raise_last_write_error(write_errors: List[Any]) -> NoReturn: +def _raise_last_write_error(write_errors: list[Any]) -> NoReturn: # If the last batch had multiple errors only report # the last error to emulate continue_on_error. error = write_errors[-1] diff --git a/pymongo/lock.py b/pymongo/lock.py index 741876afc..5a89a0263 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import os import threading diff --git a/pymongo/message.py b/pymongo/message.py index 0c2034229..c370cfa03 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -29,14 +29,11 @@ from typing import ( TYPE_CHECKING, Any, Callable, - Dict, Iterable, - List, Mapping, MutableMapping, NoReturn, Optional, - Tuple, Union, cast, ) @@ -137,14 +134,14 @@ def _maybe_add_read_preference( return spec -def _convert_exception(exception: Exception) -> Dict[str, Any]: +def _convert_exception(exception: Exception) -> dict[str, Any]: """Convert an Exception into a failure document for publishing.""" return {"errmsg": str(exception), "errtype": exception.__class__.__name__} def _convert_write_result( operation: str, command: Mapping[str, Any], result: Mapping[str, Any] -) -> Dict[str, Any]: +) -> dict[str, Any]: """Convert a legacy write result to write command format.""" # Based on _merge_legacy from bulk.py affected = result.get("n", 0) @@ -340,7 +337,7 @@ class _Query: self.client = client self.allow_disk_use = allow_disk_use self.name = "find" - self._as_command: Optional[Tuple[SON[str, Any], str]] = None + self._as_command: Optional[tuple[SON[str, Any], str]] = None self.exhaust = exhaust def reset(self) -> None: @@ -367,7 +364,7 @@ class _Query: def as_command( self, conn: Connection, apply_timeout: bool = False - ) -> Tuple[SON[str, Any], str]: + ) -> tuple[SON[str, Any], str]: """Return a find command document for this query.""" # We use the command twice: on the wire and for command monitoring. # Generate it once, for speed and to avoid repeating side-effects. @@ -411,7 +408,7 @@ class _Query: def get_message( self, read_preference: _ServerMode, conn: Connection, use_cmd: bool = False - ) -> Tuple[int, bytes, int]: + ) -> tuple[int, bytes, int]: """Get a query message, possibly setting the secondaryOk bit.""" # Use the read_preference decided by _socket_from_server. self.read_preference = read_preference @@ -508,7 +505,7 @@ class _GetMore: self.client = client self.max_await_time_ms = max_await_time_ms self.conn_mgr = conn_mgr - self._as_command: Optional[Tuple[SON[str, Any], str]] = None + self._as_command: Optional[tuple[SON[str, Any], str]] = None self.exhaust = exhaust self.comment = comment @@ -531,7 +528,7 @@ class _GetMore: def as_command( self, conn: Connection, apply_timeout: bool = False - ) -> Tuple[SON[str, Any], str]: + ) -> tuple[SON[str, Any], str]: """Return a getMore command document for this query.""" # See _Query.as_command for an explanation of this caching. if self._as_command is not None: @@ -561,7 +558,7 @@ class _GetMore: def get_message( self, dummy0: Any, conn: Connection, use_cmd: bool = False - ) -> Union[Tuple[int, bytes, int], Tuple[int, bytes]]: + ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]: """Get a getmore message.""" ns = self.namespace() ctx = conn.compression_context @@ -639,7 +636,7 @@ _COMPRESSION_HEADER_SIZE = 25 def _compress( operation: int, data: bytes, ctx: Union[SnappyContext, ZlibContext, ZstdContext] -) -> Tuple[int, bytes]: +) -> tuple[int, bytes]: """Takes message data, compresses it, and adds an OP_COMPRESSED header.""" compressed = ctx.compress(data) request_id = _randint() @@ -659,7 +656,7 @@ def _compress( _pack_header = struct.Struct(" Tuple[int, bytes]: +def __pack_message(operation: int, data: bytes) -> tuple[int, bytes]: """Takes message data and adds a message header based on the operation. Returns the resultant message string. @@ -678,9 +675,9 @@ def _op_msg_no_header( flags: int, command: Mapping[str, Any], identifier: str, - docs: Optional[List[Mapping[str, Any]]], + docs: Optional[list[Mapping[str, Any]]], opts: CodecOptions, -) -> Tuple[bytes, int, int]: +) -> tuple[bytes, int, int]: """Get a OP_MSG message. Note: this method handles multiple documents in a type one payload but @@ -710,10 +707,10 @@ def _op_msg_compressed( flags: int, command: Mapping[str, Any], identifier: str, - docs: Optional[List[Mapping[str, Any]]], + docs: Optional[list[Mapping[str, Any]]], opts: CodecOptions, ctx: Union[SnappyContext, ZlibContext, ZstdContext], -) -> Tuple[int, bytes, int, int]: +) -> tuple[int, bytes, int, int]: """Internal OP_MSG message helper.""" msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) rid, msg = _compress(2013, msg, ctx) @@ -724,9 +721,9 @@ def _op_msg_uncompressed( flags: int, command: Mapping[str, Any], identifier: str, - docs: Optional[List[Mapping[str, Any]]], + docs: Optional[list[Mapping[str, Any]]], opts: CodecOptions, -) -> Tuple[int, bytes, int, int]: +) -> tuple[int, bytes, int, int]: """Internal compressed OP_MSG message helper.""" data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) request_id, op_message = __pack_message(2013, data) @@ -744,7 +741,7 @@ def _op_msg( read_preference: Optional[_ServerMode], opts: CodecOptions, ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, -) -> Tuple[int, bytes, int, int]: +) -> tuple[int, bytes, int, int]: """Get a OP_MSG message.""" command["$db"] = dbname # getMore commands do not send $readPreference. @@ -777,7 +774,7 @@ def _query_impl( query: Mapping[str, Any], field_selector: Optional[Mapping[str, Any]], opts: CodecOptions, -) -> Tuple[bytes, int]: +) -> tuple[bytes, int]: """Get an OP_QUERY message.""" encoded = _dict_to_bson(query, False, opts) if field_selector: @@ -809,7 +806,7 @@ def _query_compressed( field_selector: Optional[Mapping[str, Any]], opts: CodecOptions, ctx: Union[SnappyContext, ZlibContext, ZstdContext], -) -> Tuple[int, bytes, int]: +) -> tuple[int, bytes, int]: """Internal compressed query message helper.""" op_query, max_bson_size = _query_impl( options, collection_name, num_to_skip, num_to_return, query, field_selector, opts @@ -826,7 +823,7 @@ def _query_uncompressed( query: Mapping[str, Any], field_selector: Optional[Mapping[str, Any]], opts: CodecOptions, -) -> Tuple[int, bytes, int]: +) -> tuple[int, bytes, int]: """Internal query message helper.""" op_query, max_bson_size = _query_impl( options, collection_name, num_to_skip, num_to_return, query, field_selector, opts @@ -848,7 +845,7 @@ def _query( field_selector: Optional[Mapping[str, Any]], opts: CodecOptions, ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, -) -> Tuple[int, bytes, int]: +) -> tuple[int, bytes, int]: """Get a **query** message.""" if ctx: return _query_compressed( @@ -879,14 +876,14 @@ def _get_more_compressed( num_to_return: int, cursor_id: int, ctx: Union[SnappyContext, ZlibContext, ZstdContext], -) -> Tuple[int, bytes]: +) -> tuple[int, bytes]: """Internal compressed getMore message helper.""" return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx) def _get_more_uncompressed( collection_name: str, num_to_return: int, cursor_id: int -) -> Tuple[int, bytes]: +) -> tuple[int, bytes]: """Internal getMore message helper.""" return __pack_message(2005, _get_more_impl(collection_name, num_to_return, cursor_id)) @@ -900,7 +897,7 @@ def _get_more( num_to_return: int, cursor_id: int, ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, -) -> Tuple[int, bytes]: +) -> tuple[int, bytes]: """Get a **getMore** message.""" if ctx: return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx) @@ -950,8 +947,8 @@ class _BulkWriteContext: self.codec = codec def __batch_command( - self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]] - ) -> Tuple[int, bytes, List[Mapping[str, Any]]]: + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] + ) -> tuple[int, bytes, list[Mapping[str, Any]]]: namespace = self.db_name + ".$cmd" request_id, msg, to_send = _do_batched_op_msg( namespace, self.op_type, cmd, docs, self.codec, self @@ -961,16 +958,16 @@ class _BulkWriteContext: return request_id, msg, to_send def execute( - self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient - ) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]: + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient + ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]: request_id, msg, to_send = self.__batch_command(cmd, docs) result = self.write_command(cmd, request_id, msg, to_send) client._process_response(result, self.session) return result, to_send def execute_unack( - self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient - ) -> List[Mapping[str, Any]]: + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient + ) -> list[Mapping[str, Any]]: request_id, msg, to_send = self.__batch_command(cmd, docs) # Though this isn't strictly a "legacy" write, the helper # handles publishing commands and sending our message @@ -1009,7 +1006,7 @@ class _BulkWriteContext: request_id: int, msg: bytes, max_doc_size: int, - docs: List[Mapping[str, Any]], + docs: list[Mapping[str, Any]], ) -> Optional[Mapping[str, Any]]: """A proxy for Connection.unack_write that handles event publishing.""" if self.publish: @@ -1049,8 +1046,8 @@ class _BulkWriteContext: cmd: MutableMapping[str, Any], request_id: int, msg: bytes, - docs: List[Mapping[str, Any]], - ) -> Dict[str, Any]: + docs: list[Mapping[str, Any]], + ) -> dict[str, Any]: """A proxy for SocketInfo.write_command that handles event publishing.""" if self.publish: assert self.start_time is not None @@ -1076,7 +1073,7 @@ class _BulkWriteContext: return reply def _start( - self, cmd: MutableMapping[str, Any], request_id: int, docs: List[Mapping[str, Any]] + self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]] ) -> MutableMapping[str, Any]: """Publish a CommandStartedEvent.""" cmd[self.field] = docs @@ -1126,8 +1123,8 @@ class _EncryptedBulkWriteContext(_BulkWriteContext): __slots__ = () def __batch_command( - self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]] - ) -> Tuple[MutableMapping[str, Any], List[Mapping[str, Any]]]: + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] + ) -> tuple[MutableMapping[str, Any], list[Mapping[str, Any]]]: namespace = self.db_name + ".$cmd" msg, to_send = _encode_batched_write_command( namespace, self.op_type, cmd, docs, self.codec, self @@ -1141,8 +1138,8 @@ class _EncryptedBulkWriteContext(_BulkWriteContext): return outgoing, to_send def execute( - self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient - ) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]: + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient + ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]: batched_cmd, to_send = self.__batch_command(cmd, docs) result: Mapping[str, Any] = self.conn.command( self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client @@ -1150,8 +1147,8 @@ class _EncryptedBulkWriteContext(_BulkWriteContext): return result, to_send def execute_unack( - self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient - ) -> List[Mapping[str, Any]]: + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient + ) -> list[Mapping[str, Any]]: batched_cmd, to_send = self.__batch_command(cmd, docs) self.conn.command( self.db_name, @@ -1196,12 +1193,12 @@ _OP_MSG_MAP = { def _batched_op_msg_impl( operation: int, command: Mapping[str, Any], - docs: List[Mapping[str, Any]], + docs: list[Mapping[str, Any]], ack: bool, opts: CodecOptions, ctx: _BulkWriteContext, buf: _BytesIO, -) -> Tuple[List[Mapping[str, Any]], int]: +) -> tuple[list[Mapping[str, Any]], int]: """Create a batched OP_MSG write.""" max_bson_size = ctx.max_bson_size max_write_batch_size = ctx.max_write_batch_size @@ -1264,11 +1261,11 @@ def _batched_op_msg_impl( def _encode_batched_op_msg( operation: int, command: Mapping[str, Any], - docs: List[Mapping[str, Any]], + docs: list[Mapping[str, Any]], ack: bool, opts: CodecOptions, ctx: _BulkWriteContext, -) -> Tuple[bytes, List[Mapping[str, Any]]]: +) -> tuple[bytes, list[Mapping[str, Any]]]: """Encode the next batched insert, update, or delete operation as OP_MSG. """ @@ -1285,11 +1282,11 @@ if _use_c: def _batched_op_msg_compressed( operation: int, command: Mapping[str, Any], - docs: List[Mapping[str, Any]], + docs: list[Mapping[str, Any]], ack: bool, opts: CodecOptions, ctx: _BulkWriteContext, -) -> Tuple[int, bytes, List[Mapping[str, Any]]]: +) -> tuple[int, bytes, list[Mapping[str, Any]]]: """Create the next batched insert, update, or delete operation with OP_MSG, compressed. """ @@ -1303,11 +1300,11 @@ def _batched_op_msg_compressed( def _batched_op_msg( operation: int, command: Mapping[str, Any], - docs: List[Mapping[str, Any]], + docs: list[Mapping[str, Any]], ack: bool, opts: CodecOptions, ctx: _BulkWriteContext, -) -> Tuple[int, bytes, List[Mapping[str, Any]]]: +) -> tuple[int, bytes, list[Mapping[str, Any]]]: """OP_MSG implementation entry point.""" buf = _BytesIO() @@ -1336,10 +1333,10 @@ def _do_batched_op_msg( namespace: str, operation: int, command: MutableMapping[str, Any], - docs: List[Mapping[str, Any]], + docs: list[Mapping[str, Any]], opts: CodecOptions, ctx: _BulkWriteContext, -) -> Tuple[int, bytes, List[Mapping[str, Any]]]: +) -> tuple[int, bytes, list[Mapping[str, Any]]]: """Create the next batched insert, update, or delete operation using OP_MSG. """ @@ -1360,10 +1357,10 @@ def _encode_batched_write_command( namespace: str, operation: int, command: MutableMapping[str, Any], - docs: List[Mapping[str, Any]], + docs: list[Mapping[str, Any]], opts: CodecOptions, ctx: _BulkWriteContext, -) -> Tuple[bytes, List[Mapping[str, Any]]]: +) -> tuple[bytes, list[Mapping[str, Any]]]: """Encode the next batched insert, update, or delete command.""" buf = _BytesIO() @@ -1379,11 +1376,11 @@ def _batched_write_command_impl( namespace: str, operation: int, command: MutableMapping[str, Any], - docs: List[Mapping[str, Any]], + docs: list[Mapping[str, Any]], opts: CodecOptions, ctx: _BulkWriteContext, buf: _BytesIO, -) -> Tuple[List[Mapping[str, Any]], int]: +) -> tuple[list[Mapping[str, Any]], int]: """Create a batched OP_QUERY write command.""" max_bson_size = ctx.max_bson_size max_write_batch_size = ctx.max_write_batch_size @@ -1468,7 +1465,7 @@ class _OpReply: def raw_response( self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = None - ) -> List[bytes]: + ) -> list[bytes]: """Check the response header from the database, without decoding BSON. Check the response for errors and unpack. @@ -1517,7 +1514,7 @@ class _OpReply: codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Unpack a response from the database and decode the BSON document(s). Check the response for errors and unpack, returning a dictionary @@ -1541,7 +1538,7 @@ class _OpReply: return bson.decode_all(self.documents, codec_options) return bson._decode_all_selective(self.documents, codec_options, user_fields) - def command_response(self, codec_options: CodecOptions) -> Dict[str, Any]: + def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: """Unpack a command response.""" docs = self.unpack_response(codec_options=codec_options) assert self.number_returned == 1 @@ -1588,7 +1585,7 @@ class _OpMsg: self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = {}, # noqa: B006 - ) -> List[Mapping[str, Any]]: + ) -> list[Mapping[str, Any]]: """ cursor_id is ignored user_fields is used to determine which fields must not be decoded @@ -1604,7 +1601,7 @@ class _OpMsg: codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Unpack a OP_MSG command response. :Parameters: @@ -1619,7 +1616,7 @@ class _OpMsg: assert not legacy_response return bson._decode_all_selective(self.payload_document, codec_options, user_fields) - def command_response(self, codec_options: CodecOptions) -> Dict[str, Any]: + def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: """Unpack a command response.""" return self.unpack_response(codec_options=codec_options)[0] @@ -1652,7 +1649,7 @@ class _OpMsg: return cls(flags, payload_document) -_UNPACK_REPLY: Dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = { +_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = { _OpReply.OP_CODE: _OpReply.unpack, _OpMsg.OP_CODE: _OpMsg.unpack, } diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 51a6d20d6..becddb65f 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -41,18 +41,14 @@ from typing import ( Any, Callable, ContextManager, - Dict, FrozenSet, Generic, Iterator, - List, Mapping, MutableMapping, NoReturn, Optional, Sequence, - Set, - Tuple, Type, TypeVar, Union, @@ -716,7 +712,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): client.__my_database__ """ doc_class = document_class or dict - self.__init_kwargs: Dict[str, Any] = { + self.__init_kwargs: dict[str, Any] = { "host": host, "port": port, "document_class": doc_class, @@ -824,7 +820,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): self.__default_database_name = dbase self.__lock = _create_lock() - self.__kill_cursors_queue: List = [] + self.__kill_cursors_queue: list = [] self._event_listeners = options.pool_options._event_listeners super().__init__( @@ -1068,7 +1064,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): return self._topology.description @property - def address(self) -> Optional[Tuple[str, int]]: + def address(self) -> Optional[tuple[str, int]]: """(host, port) of the current standalone, primary, or mongos, or None. Accessing :attr:`address` raises :exc:`~.errors.InvalidOperation` if @@ -1100,7 +1096,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): return self._server_property("address") @property - def primary(self) -> Optional[Tuple[str, int]]: + def primary(self) -> Optional[tuple[str, int]]: """The (host, port) of the current primary of the replica set. Returns ``None`` if this client is not connected to a replica set, @@ -1113,7 +1109,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): return self._topology.get_primary() # type: ignore[return-value] @property - def secondaries(self) -> Set[_Address]: + def secondaries(self) -> set[_Address]: """The secondary members known to this client. A sequence of (host, port) pairs. Empty if this client is not @@ -1126,7 +1122,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): return self._topology.get_secondaries() @property - def arbiters(self) -> Set[_Address]: + def arbiters(self) -> set[_Address]: """Arbiters in the replica set. A sequence of (host, port) pairs. Empty if this client is not @@ -1179,7 +1175,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): """ return self.__options - def _end_sessions(self, session_ids: List[_ServerSession]) -> None: + def _end_sessions(self, session_ids: list[_ServerSession]) -> None: """Send endSessions command(s) with the given session ids.""" try: # Use Connection.command directly to avoid implicitly creating @@ -1313,7 +1309,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): @contextlib.contextmanager def _conn_from_server( self, read_preference: _ServerMode, server: Server, session: Optional[ClientSession] - ) -> Iterator[Tuple[Connection, _ServerMode]]: + ) -> Iterator[tuple[Connection, _ServerMode]]: assert read_preference is not None, "read_preference must not be None" # Get a connection for a server matching the read preference, and yield # conn with the effective read preference. The Server Selection @@ -1337,7 +1333,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): def _conn_for_reads( self, read_preference: _ServerMode, session: Optional[ClientSession] - ) -> ContextManager[Tuple[Connection, _ServerMode]]: + ) -> ContextManager[tuple[Connection, _ServerMode]]: assert read_preference is not None, "read_preference must not be None" _ = self._get_topology() server = self._select_server(read_preference, session) @@ -1872,7 +1868,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): if session is not None: session._process_response(reply) - def server_info(self, session: Optional[client_session.ClientSession] = None) -> Dict[str, Any]: + def server_info(self, session: Optional[client_session.ClientSession] = None) -> dict[str, Any]: """Get information about the MongoDB server we're connected to. :Parameters: @@ -1894,7 +1890,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): session: Optional[client_session.ClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, - ) -> CommandCursor[Dict[str, Any]]: + ) -> CommandCursor[dict[str, Any]]: """Get a cursor over the databases of the connected server. :Parameters: @@ -1932,7 +1928,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): self, session: Optional[client_session.ClientSession] = None, comment: Optional[Any] = None, - ) -> List[str]: + ) -> list[str]: """Get a list of the names of all databases on the connected server. :Parameters: diff --git a/pymongo/monitor.py b/pymongo/monitor.py index b2ff3404f..d52e1e4c2 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -19,7 +19,7 @@ from __future__ import annotations import atexit import time import weakref -from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Mapping, Optional, cast from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum @@ -274,7 +274,7 @@ class Monitor(MonitorBase): ) return sd - def _check_with_socket(self, conn: Connection) -> Tuple[Hello, float]: + def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: """Return (Hello, round_trip_time). Can raise ConnectionFailure or OperationFailure. @@ -326,7 +326,7 @@ class SrvMonitor(MonitorBase): # Topology was garbage-collected. self.close() - def _get_seedlist(self) -> Optional[List[Tuple[str, Any]]]: + def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: """Poll SRV records for a seedlist. Returns a list of ServerDescriptions. @@ -383,7 +383,7 @@ class _RttMonitor(MonitorBase): self._moving_average.add_sample(sample) self._moving_min.add_sample(sample) - def get(self) -> Tuple[Optional[float], float]: + def get(self) -> tuple[Optional[float], float]: """Get the calculated average, or None if no samples yet and the min.""" with self._lock: return self._moving_average.get(), self._moving_min.get() diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index 73e15821d..a14d9a911 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -187,7 +187,7 @@ from __future__ import annotations import datetime from collections import abc, namedtuple -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence from bson.objectid import ObjectId from pymongo.hello import Hello, HelloCompat @@ -812,12 +812,12 @@ class PoolCreatedEvent(_PoolEvent): __slots__ = ("__options",) - def __init__(self, address: _Address, options: Dict[str, Any]) -> None: + def __init__(self, address: _Address, options: dict[str, Any]) -> None: super().__init__(address) self.__options = options @property - def options(self) -> Dict[str, Any]: + def options(self) -> dict[str, Any]: """Any non-default pool options that were set on this Connection Pool.""" return self.__options @@ -1439,7 +1439,7 @@ class _EventListeners: """Are any ConnectionPoolListener instances registered?""" return self.__enabled_for_cmap - def event_listeners(self) -> List[_EventListeners]: + def event_listeners(self) -> list[_EventListeners]: """List of registered event listeners.""" return ( self.__command_listeners @@ -1712,7 +1712,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_pool_created(self, address: _Address, options: Dict[str, Any]) -> None: + def publish_pool_created(self, address: _Address, options: dict[str, Any]) -> None: """Publish a :class:`PoolCreatedEvent` to all pool listeners.""" event = PoolCreatedEvent(address, options) for subscriber in self.__cmap_listeners: diff --git a/pymongo/ocsp_cache.py b/pymongo/ocsp_cache.py index 033a7b607..742579312 100644 --- a/pymongo/ocsp_cache.py +++ b/pymongo/ocsp_cache.py @@ -19,7 +19,7 @@ from __future__ import annotations from collections import namedtuple from datetime import datetime as _datetime from datetime import timezone -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any from pymongo.lock import _create_lock @@ -36,7 +36,7 @@ class _OCSPCache: ) def __init__(self) -> None: - self._data: Dict[Any, OCSPResponse] = {} + self._data: dict[Any, OCSPResponse] = {} # Hold this lock when accessing _data. self._lock = _create_lock() diff --git a/pymongo/ocsp_support.py b/pymongo/ocsp_support.py index 292ee1bbf..ccc03ae67 100644 --- a/pymongo/ocsp_support.py +++ b/pymongo/ocsp_support.py @@ -19,7 +19,7 @@ import logging as _logging import re as _re from datetime import datetime as _datetime from datetime import timezone -from typing import TYPE_CHECKING, Iterable, List, Optional, Type, Union +from typing import TYPE_CHECKING, Iterable, Optional, Type, Union from cryptography.exceptions import InvalidSignature as _InvalidSignature from cryptography.hazmat.backends import default_backend as _default_backend @@ -101,7 +101,7 @@ _CERT_REGEX = _re.compile( ) -def _load_trusted_ca_certs(cafile: str) -> List[Certificate]: +def _load_trusted_ca_certs(cafile: str) -> list[Certificate]: """Parse the tlsCAFile into a list of certificates.""" with open(cafile, "rb") as f: data = f.read() @@ -115,7 +115,7 @@ def _load_trusted_ca_certs(cafile: str) -> List[Certificate]: def _get_issuer_cert( - cert: Certificate, chain: Iterable[Certificate], trusted_ca_certs: Optional[List[Certificate]] + cert: Certificate, chain: Iterable[Certificate], trusted_ca_certs: Optional[list[Certificate]] ) -> Optional[Certificate]: issuer_name = cert.issuer for candidate in chain: @@ -187,7 +187,7 @@ def _public_key_hash(cert: Certificate) -> bytes: def _get_certs_by_key_hash( certificates: Iterable[Certificate], issuer: Certificate, responder_key_hash: Optional[bytes] -) -> List[Certificate]: +) -> list[Certificate]: return [ cert for cert in certificates @@ -197,7 +197,7 @@ def _get_certs_by_key_hash( def _get_certs_by_name( certificates: Iterable[Certificate], issuer: Certificate, responder_name: Optional[Name] -) -> List[Certificate]: +) -> list[Certificate]: return [ cert for cert in certificates diff --git a/pymongo/operations.py b/pymongo/operations.py index a72dd523b..92d920bf0 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -18,9 +18,7 @@ from __future__ import annotations from typing import ( TYPE_CHECKING, Any, - Dict, Generic, - List, Mapping, Optional, Sequence, @@ -293,7 +291,7 @@ class _UpdateOp: doc: Union[Mapping[str, Any], _Pipeline], upsert: bool, collation: Optional[_CollationIn], - array_filters: Optional[List[Mapping[str, Any]]], + array_filters: Optional[list[Mapping[str, Any]]], hint: Optional[_IndexKeyHint], ): if filter is not None: @@ -355,7 +353,7 @@ class UpdateOne(_UpdateOp): update: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, collation: Optional[_CollationIn] = None, - array_filters: Optional[List[Mapping[str, Any]]] = None, + array_filters: Optional[list[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, ) -> None: """Represents an update_one operation. @@ -413,7 +411,7 @@ class UpdateMany(_UpdateOp): update: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, collation: Optional[_CollationIn] = None, - array_filters: Optional[List[Mapping[str, Any]]] = None, + array_filters: Optional[list[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, ) -> None: """Create an UpdateMany instance. @@ -537,7 +535,7 @@ class IndexModel: self.__document["collation"] = collation @property - def document(self) -> Dict[str, Any]: + def document(self) -> dict[str, Any]: """An index document suitable for passing to the createIndexes command. """ diff --git a/pymongo/pool.py b/pymongo/pool.py index 68052f649..a088e2eeb 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -28,16 +28,12 @@ import weakref from typing import ( TYPE_CHECKING, Any, - Dict, Iterator, - List, Mapping, MutableMapping, NoReturn, Optional, Sequence, - Set, - Tuple, Union, ) @@ -305,8 +301,8 @@ def _getenv_int(key: str) -> Optional[int]: return None -def _metadata_env() -> Dict[str, Any]: - env: Dict[str, Any] = {} +def _metadata_env() -> dict[str, Any]: + env: dict[str, Any] = {} # Skip if multiple (or no) envs are matched. if (_is_lambda(), _is_azure_func(), _is_gcp_func(), _is_vercel()).count(True) != 1: return env @@ -520,7 +516,7 @@ class PoolOptions: return self.__credentials @property - def non_default_options(self) -> Dict[str, Any]: + def non_default_options(self) -> dict[str, Any]: """The non-default options this pool was created with. Added for CMAP's :class:`PoolCreatedEvent`. @@ -668,7 +664,7 @@ class Connection: """ def __init__( - self, conn: Union[socket.socket, _sslConn], pool: Pool, address: Tuple[str, int], id: int + self, conn: Union[socket.socket, _sslConn], pool: Pool, address: tuple[str, int], id: int ): self.pool_ref = weakref.ref(pool) self.conn = conn @@ -693,7 +689,7 @@ class Connection: self.socket_checker: SocketChecker = SocketChecker() self.oidc_token_gen_id: Optional[int] = None # Support for mechanism negotiation on the initial handshake. - self.negotiated_mechs: Optional[List[str]] = None + self.negotiated_mechs: Optional[list[str]] = None self.auth_ctx: Optional[_AuthContext] = None # The pool's generation changes with each reset() so we can close @@ -774,7 +770,7 @@ class Connection: else: return SON([(HelloCompat.LEGACY_CMD, 1), ("helloOk", True)]) - def hello(self) -> Hello[Dict[str, Any]]: + def hello(self) -> Hello[dict[str, Any]]: return self._hello(None, None, None) def _hello( @@ -782,7 +778,7 @@ class Connection: cluster_time: Optional[ClusterTime], topology_version: Optional[Any], heartbeat_frequency: Optional[int], - ) -> Hello[Dict[str, Any]]: + ) -> Hello[dict[str, Any]]: cmd = self.hello_cmd() performing_handshake = not self.performed_handshake awaitable = False @@ -860,7 +856,7 @@ class Connection: self.generation = self.pool_gen.get(self.service_id) return hello - def _next_reply(self) -> Dict[str, Any]: + def _next_reply(self) -> dict[str, Any]: reply = self.receive_message(None) self.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response() @@ -887,7 +883,7 @@ class Connection: publish_events: bool = True, user_fields: Optional[Mapping[str, Any]] = None, exhaust_allowed: bool = False, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Execute a command or raise an error. :Parameters: @@ -1007,7 +1003,7 @@ class Connection: def write_command( self, request_id: int, msg: bytes, codec_options: CodecOptions - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Send "insert" etc. command, returning response as a dict. Can raise ConnectionFailure or OperationFailure. @@ -1280,7 +1276,7 @@ class _PoolClosedError(PyMongoError): class _PoolGeneration: def __init__(self) -> None: # Maps service_id to generation. - self._generations: Dict[ObjectId, int] = collections.defaultdict(int) + self._generations: dict[ObjectId, int] = collections.defaultdict(int) # Overall pool generation. self._generation = 0 @@ -1381,7 +1377,7 @@ class Pool: # Retain references to pinned connections to prevent the CPython GC # from thinking that a cursor's pinned connection can be GC'd when the # cursor is GC'd (see PYTHON-2751). - self.__pinned_sockets: Set[Connection] = set() + self.__pinned_sockets: set[Connection] = set() self.ncursors = 0 self.ntxns = 0 diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index c16917381..4bef96ec7 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -23,7 +23,7 @@ import sys as _sys import time as _time from errno import EINTR as _EINTR from ipaddress import ip_address as _ip_address -from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union from cryptography.x509 import load_der_x509_certificate as _load_der_x509_certificate from OpenSSL import SSL as _SSL @@ -184,7 +184,7 @@ class _CallbackData: """Data class which is passed to the OCSP callback.""" def __init__(self) -> None: - self.trusted_ca_certs: Optional[List[Certificate]] = None + self.trusted_ca_certs: Optional[list[Certificate]] = None self.check_ocsp_endpoint: Optional[bool] = None self.ocsp_response_cache = _OCSPCache() diff --git a/pymongo/read_concern.py b/pymongo/read_concern.py index ddc90a817..0b54ee86f 100644 --- a/pymongo/read_concern.py +++ b/pymongo/read_concern.py @@ -13,8 +13,9 @@ # limitations under the License. """Tools for working with read concerns.""" +from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any, Optional class ReadConcern: @@ -50,7 +51,7 @@ class ReadConcern: return self.level is None or self.level == "local" @property - def document(self) -> Dict[str, Any]: + def document(self) -> dict[str, Any]: """The document representation of this read concern. .. note:: diff --git a/pymongo/read_preferences.py b/pymongo/read_preferences.py index 477efeda3..f731fdfaf 100644 --- a/pymongo/read_preferences.py +++ b/pymongo/read_preferences.py @@ -17,7 +17,7 @@ from __future__ import annotations from collections import abc -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence from pymongo import max_staleness_selectors from pymongo.errors import ConfigurationError @@ -131,9 +131,9 @@ class _ServerMode: return self.__mongos_mode @property - def document(self) -> Dict[str, Any]: + def document(self) -> dict[str, Any]: """Read preference as a document.""" - doc: Dict[str, Any] = {"mode": self.__mongos_mode} + doc: dict[str, Any] = {"mode": self.__mongos_mode} if self.__tag_sets not in (None, [{}]): doc["tags"] = self.__tag_sets if self.__max_staleness != -1: @@ -235,7 +235,7 @@ class _ServerMode: def __ne__(self, other: Any) -> bool: return not self == other - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: """Return value of object for pickling. Needed explicitly because __slots__() defined. diff --git a/pymongo/response.py b/pymongo/response.py index c236754b3..5ff6ca707 100644 --- a/pymongo/response.py +++ b/pymongo/response.py @@ -15,7 +15,7 @@ """Represent a response from the server.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union if TYPE_CHECKING: from datetime import timedelta @@ -95,7 +95,7 @@ class PinnedResponse(Response): request_id: int, duration: Optional[timedelta], from_command: bool, - docs: List[_DocumentOut], + docs: list[_DocumentOut], more_to_come: bool, ): """Represent a response to an exhaust cursor's initial query. diff --git a/pymongo/results.py b/pymongo/results.py index 3676d4a48..16e682a79 100644 --- a/pymongo/results.py +++ b/pymongo/results.py @@ -13,7 +13,9 @@ # limitations under the License. """Result class definitions.""" -from typing import Any, Dict, List, Mapping, Optional, cast +from __future__ import annotations + +from typing import Any, Mapping, Optional, cast from pymongo.errors import InvalidOperation @@ -76,12 +78,12 @@ class InsertManyResult(_WriteResult): __slots__ = ("__inserted_ids",) - def __init__(self, inserted_ids: List[Any], acknowledged: bool) -> None: + def __init__(self, inserted_ids: list[Any], acknowledged: bool) -> None: self.__inserted_ids = inserted_ids super().__init__(acknowledged) @property - def inserted_ids(self) -> List: + def inserted_ids(self) -> list: """A list of _ids of the inserted documents, in the order provided. .. note:: If ``False`` is passed for the `ordered` parameter to @@ -163,7 +165,7 @@ class BulkWriteResult(_WriteResult): __slots__ = ("__bulk_api_result",) - def __init__(self, bulk_api_result: Dict[str, Any], acknowledged: bool) -> None: + def __init__(self, bulk_api_result: dict[str, Any], acknowledged: bool) -> None: """Create a BulkWriteResult instance. :Parameters: @@ -176,7 +178,7 @@ class BulkWriteResult(_WriteResult): super().__init__(acknowledged) @property - def bulk_api_result(self) -> Dict[str, Any]: + def bulk_api_result(self) -> dict[str, Any]: """The raw bulk API result.""" return self.__bulk_api_result @@ -211,7 +213,7 @@ class BulkWriteResult(_WriteResult): return cast(int, self.__bulk_api_result.get("nUpserted")) @property - def upserted_ids(self) -> Optional[Dict[int, Any]]: + def upserted_ids(self) -> Optional[dict[int, Any]]: """A map of operation index to the _id of the upserted document.""" self._raise_if_unacknowledged("upserted_ids") if self.__bulk_api_result: diff --git a/pymongo/saslprep.py b/pymongo/saslprep.py index 34c0182a5..ff92d4b4f 100644 --- a/pymongo/saslprep.py +++ b/pymongo/saslprep.py @@ -13,6 +13,8 @@ # limitations under the License. """An implementation of RFC4013 SASLprep.""" +from __future__ import annotations + from typing import Any, Optional try: diff --git a/pymongo/server.py b/pymongo/server.py index 2fe2443ee..93bdeadeb 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -16,16 +16,7 @@ from __future__ import annotations from datetime import datetime -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ContextManager, - List, - Optional, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union from bson import _decode_all_selective from pymongo.errors import NotPrimaryError, OperationFailure @@ -110,7 +101,7 @@ class Server: operation: Union[_Query, _GetMore], read_preference: _ServerMode, listeners: Optional[_EventListeners], - unpack_res: Callable[..., List[_DocumentOut]], + unpack_res: Callable[..., list[_DocumentOut]], ) -> Response: """Run a _Query or _GetMore operation and return a Response object. @@ -275,8 +266,8 @@ class Server: return self._pool def _split_message( - self, message: Union[Tuple[int, Any], Tuple[int, Any, int]] - ) -> Tuple[int, Any, int]: + self, message: Union[tuple[int, Any], tuple[int, Any, int]] + ) -> tuple[int, Any, int]: """Return request_id, data, max_doc_size. :Parameters: diff --git a/pymongo/server_description.py b/pymongo/server_description.py index c2fa03053..490970581 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -13,10 +13,11 @@ # limitations under the License. """Represent one server the driver is connected to.""" +from __future__ import annotations import time import warnings -from typing import Any, Dict, Mapping, Optional, Set, Tuple +from typing import Any, Mapping, Optional from bson import EPOCH_NAIVE from bson.objectid import ObjectId @@ -129,7 +130,7 @@ class ServerDescription: return SERVER_TYPE._fields[self._server_type] @property - def all_hosts(self) -> Set[Tuple[str, int]]: + def all_hosts(self) -> set[tuple[str, int]]: """List of hosts, passives, and arbiters known to this server.""" return self._all_hosts @@ -143,7 +144,7 @@ class ServerDescription: return self._replica_set_name @property - def primary(self) -> Optional[Tuple[str, int]]: + def primary(self) -> Optional[tuple[str, int]]: """This server's opinion about who the primary is, or None.""" return self._primary @@ -180,7 +181,7 @@ class ServerDescription: return self._cluster_time @property - def election_tuple(self) -> Tuple[Optional[int], Optional[ObjectId]]: + def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]: warnings.warn( "'election_tuple' is deprecated, use 'set_version' and 'election_id' instead", DeprecationWarning, @@ -189,7 +190,7 @@ class ServerDescription: return self._set_version, self._election_id @property - def me(self) -> Optional[Tuple[str, int]]: + def me(self) -> Optional[tuple[str, int]]: return self._me @property @@ -297,4 +298,4 @@ class ServerDescription: ) # For unittesting only. Use under no circumstances! - _host_to_round_trip_time: Dict = {} + _host_to_round_trip_time: dict = {} diff --git a/pymongo/server_selectors.py b/pymongo/server_selectors.py index ee6441d7d..c22ad599e 100644 --- a/pymongo/server_selectors.py +++ b/pymongo/server_selectors.py @@ -15,7 +15,7 @@ """Criteria to select some ServerDescriptions from a TopologyDescription.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, TypeVar, cast +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, TypeVar, cast from pymongo.server_type import SERVER_TYPE @@ -51,7 +51,7 @@ class Selection: def __init__( self, topology_description: TopologyDescription, - server_descriptions: List[ServerDescription], + server_descriptions: list[ServerDescription], common_wire_version: Optional[int], primary: Optional[ServerDescription], ): @@ -60,7 +60,7 @@ class Selection: self.primary = primary self.common_wire_version = common_wire_version - def with_server_descriptions(self, server_descriptions: List[ServerDescription]) -> Selection: + def with_server_descriptions(self, server_descriptions: list[ServerDescription]) -> Selection: return Selection( self.topology_description, server_descriptions, self.common_wire_version, self.primary ) diff --git a/pymongo/server_type.py b/pymongo/server_type.py index ee53b6b97..937855cc7 100644 --- a/pymongo/server_type.py +++ b/pymongo/server_type.py @@ -13,6 +13,7 @@ # limitations under the License. """Type codes for MongoDB servers.""" +from __future__ import annotations from typing import NamedTuple diff --git a/pymongo/settings.py b/pymongo/settings.py index d6ef93e5c..a4be2295d 100644 --- a/pymongo/settings.py +++ b/pymongo/settings.py @@ -13,10 +13,11 @@ # permissions and limitations under the License. """Represent MongoClient's configuration.""" +from __future__ import annotations import threading import traceback -from typing import Any, Collection, Dict, Optional, Tuple, Type, Union +from typing import Any, Collection, Optional, Type, Union from bson.objectid import ObjectId from pymongo import common, monitor, pool @@ -30,7 +31,7 @@ from pymongo.topology_description import TOPOLOGY_TYPE, _ServerSelector class TopologySettings: def __init__( self, - seeds: Optional[Collection[Tuple[str, int]]] = None, + seeds: Optional[Collection[tuple[str, int]]] = None, replica_set_name: Optional[str] = None, pool_class: Optional[Type[Pool]] = None, pool_options: Optional[PoolOptions] = None, @@ -56,7 +57,7 @@ class TopologySettings: % (common.MIN_HEARTBEAT_INTERVAL * 1000,) ) - self._seeds: Collection[Tuple[str, int]] = seeds or [("localhost", 27017)] + self._seeds: Collection[tuple[str, int]] = seeds or [("localhost", 27017)] self._replica_set_name = replica_set_name self._pool_class: Type[Pool] = pool_class or pool.Pool self._pool_options: PoolOptions = pool_options or PoolOptions() @@ -78,7 +79,7 @@ class TopologySettings: self._stack = "".join(traceback.format_stack()) @property - def seeds(self) -> Collection[Tuple[str, int]]: + def seeds(self) -> Collection[tuple[str, int]]: """List of server addresses.""" return self._seeds @@ -155,6 +156,6 @@ class TopologySettings: else: return TOPOLOGY_TYPE.Unknown - def get_server_descriptions(self) -> Dict[Union[Tuple[str, int], Any], ServerDescription]: + def get_server_descriptions(self) -> dict[Union[tuple[str, int], Any], ServerDescription]: """Initial dict of (address, ServerDescription) for all seeds.""" return {address: ServerDescription(address) for address in self.seeds} diff --git a/pymongo/socket_checker.py b/pymongo/socket_checker.py index a83311cd0..6d989a3d7 100644 --- a/pymongo/socket_checker.py +++ b/pymongo/socket_checker.py @@ -13,6 +13,7 @@ # limitations under the License. """Select / poll helper""" +from __future__ import annotations import errno import select diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py index 67b781cf9..afde6bba6 100644 --- a/pymongo/srv_resolver.py +++ b/pymongo/srv_resolver.py @@ -17,7 +17,7 @@ from __future__ import annotations import ipaddress import random -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union try: from dns import resolver @@ -107,7 +107,7 @@ class _SrvResolver: def _get_srv_response_and_hosts( self, encapsulate_errors: bool - ) -> Tuple[resolver.Answer, List[Tuple[str, Any]]]: + ) -> tuple[resolver.Answer, list[tuple[str, Any]]]: results = self._resolve_uri(encapsulate_errors) # Construct address tuples @@ -127,11 +127,11 @@ class _SrvResolver: nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) return results, nodes - def get_hosts(self) -> List[Tuple[str, Any]]: + def get_hosts(self) -> list[tuple[str, Any]]: _, nodes = self._get_srv_response_and_hosts(True) return nodes - def get_hosts_and_min_ttl(self) -> Tuple[List[Tuple[str, Any]], int]: + def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]: results, nodes = self._get_srv_response_and_hosts(False) rrset = results.rrset ttl = rrset.ttl if rrset else 0 diff --git a/pymongo/ssl_context.py b/pymongo/ssl_context.py index 63970cb5e..1a0424208 100644 --- a/pymongo/ssl_context.py +++ b/pymongo/ssl_context.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """A fake SSLContext implementation.""" +from __future__ import annotations import ssl as _ssl diff --git a/pymongo/ssl_support.py b/pymongo/ssl_support.py index dafd88bdb..4c5e46dfd 100644 --- a/pymongo/ssl_support.py +++ b/pymongo/ssl_support.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """Support for SSL in PyMongo.""" +from __future__ import annotations from typing import Optional diff --git a/pymongo/topology.py b/pymongo/topology.py index 5b4197bc1..45e018c3b 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -22,18 +22,7 @@ import random import time import warnings import weakref -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Mapping, - Optional, - Set, - Tuple, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast from pymongo import _csot, common, helpers, periodic_executor from pymongo.client_session import _ServerSession, _ServerSessionPool @@ -146,7 +135,7 @@ class Topology: self._closed = False self._lock = _create_lock() self._condition = self._settings.condition_class(self._lock) - self._servers: Dict[_Address, Server] = {} + self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None self._session_pool = _ServerSessionPool() @@ -222,7 +211,7 @@ class Topology: selector: Callable[[Selection], Selection], server_selection_timeout: Optional[float] = None, address: Optional[_Address] = None, - ) -> List[Server]: + ) -> list[Server]: """Return a list of Servers matching selector, or time out. :Parameters: @@ -255,7 +244,7 @@ class Topology: selector: Callable[[Selection], Selection], timeout: float, address: Optional[_Address], - ) -> List[ServerDescription]: + ) -> list[ServerDescription]: """select_servers() guts. Hold the lock when calling this.""" now = time.monotonic() end_time = now + timeout @@ -414,7 +403,7 @@ class Topology: if self._opened and self._description.has_server(server_description.address): self._process_change(server_description, reset_pool) - def _process_srv_update(self, seedlist: List[Tuple[str, Any]]) -> None: + def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: """Process a new seedlist on an opened topology. Hold the lock when calling this. """ @@ -434,7 +423,7 @@ class Topology: ) ) - def on_srv_update(self, seedlist: List[Tuple[str, Any]]) -> None: + def on_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: """Process a new list of nodes obtained from scanning SRV records.""" # We do no I/O holding the lock. with self._lock: @@ -464,7 +453,7 @@ class Topology: return writable_server_selector(self._new_selection())[0].address - def _get_replica_set_members(self, selector: Callable[[Selection], Selection]) -> Set[_Address]: + def _get_replica_set_members(self, selector: Callable[[Selection], Selection]) -> set[_Address]: """Return set of replica set member addresses.""" # Implemented here in Topology instead of MongoClient, so it can lock. with self._lock: @@ -477,11 +466,11 @@ class Topology: return {sd.address for sd in iter(selector(self._new_selection()))} - def get_secondaries(self) -> Set[_Address]: + def get_secondaries(self) -> set[_Address]: """Return set of secondary addresses.""" return self._get_replica_set_members(secondary_server_selector) - def get_arbiters(self) -> Set[_Address]: + def get_arbiters(self) -> set[_Address]: """Return set of arbiter addresses.""" return self._get_replica_set_members(arbiter_server_selector) @@ -514,7 +503,7 @@ class Topology: self._request_check_all() self._condition.wait(wait_time) - def data_bearing_servers(self) -> List[ServerDescription]: + def data_bearing_servers(self) -> list[ServerDescription]: """Return a list of all data-bearing servers. This includes any server that might be selected for an operation. @@ -573,7 +562,7 @@ class Topology: def description(self) -> TopologyDescription: return self._description - def pop_all_sessions(self) -> List[_ServerSession]: + def pop_all_sessions(self) -> list[_ServerSession]: """Pop all session ids from the pool.""" with self._lock: return self._session_pool.pop_all() @@ -890,7 +879,7 @@ class Topology: msg = "CLOSED " return f"<{self.__class__.__name__} {msg}{self._description!r}>" - def eq_props(self) -> Tuple[Tuple[_Address, ...], Optional[str], Optional[str], str]: + def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]: """The properties to use for MongoClient/Topology equality checks.""" ts = self._settings return (tuple(sorted(ts.seeds)), ts.replica_set_name, ts.fqdn, ts.srv_service_name) diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py index 21d47a531..9f9855714 100644 --- a/pymongo/topology_description.py +++ b/pymongo/topology_description.py @@ -13,18 +13,17 @@ # permissions and limitations under the License. """Represent a deployment of MongoDB servers.""" +from __future__ import annotations from random import sample from typing import ( Any, Callable, - Dict, List, Mapping, MutableMapping, NamedTuple, Optional, - Tuple, cast, ) @@ -52,7 +51,7 @@ class _TopologyType(NamedTuple): TOPOLOGY_TYPE = _TopologyType(*range(6)) # Topologies compatible with SRV record polling. -SRV_POLLING_TOPOLOGIES: Tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded) +SRV_POLLING_TOPOLOGIES: tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded) _ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]] @@ -62,7 +61,7 @@ class TopologyDescription: def __init__( self, topology_type: int, - server_descriptions: Dict[_Address, ServerDescription], + server_descriptions: dict[_Address, ServerDescription], replica_set_name: Optional[str], max_set_version: Optional[int], max_election_id: Optional[ObjectId], @@ -193,8 +192,8 @@ class TopologyDescription: self._topology_settings, ) - def server_descriptions(self) -> Dict[_Address, ServerDescription]: - """Dict of (address, + def server_descriptions(self) -> dict[_Address, ServerDescription]: + """dict of (address, :class:`~pymongo.server_description.ServerDescription`). """ return self._server_descriptions.copy() @@ -233,7 +232,7 @@ class TopologyDescription: return self._ls_timeout_minutes @property - def known_servers(self) -> List[ServerDescription]: + def known_servers(self) -> list[ServerDescription]: """List of Servers of types besides Unknown.""" return [s for s in self._server_descriptions.values() if s.is_server_type_known] @@ -243,7 +242,7 @@ class TopologyDescription: return any(s for s in self._server_descriptions.values() if s.is_server_type_known) @property - def readable_servers(self) -> List[ServerDescription]: + def readable_servers(self) -> list[ServerDescription]: """List of readable Servers.""" return [s for s in self._server_descriptions.values() if s.is_readable] @@ -264,7 +263,7 @@ class TopologyDescription: def srv_max_hosts(self) -> int: return self._topology_settings._srv_max_hosts - def _apply_local_threshold(self, selection: Optional[Selection]) -> List[ServerDescription]: + def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]: if not selection: return [] # Round trip time in seconds. @@ -281,7 +280,7 @@ class TopologyDescription: selector: Any, address: Optional[_Address] = None, custom_selector: Optional[_ServerSelector] = None, - ) -> List[ServerDescription]: + ) -> list[ServerDescription]: """List of servers matching the provided selector(s). :Parameters: @@ -486,7 +485,7 @@ def updated_topology_description( def _updated_topology_description_srv_polling( - topology_description: TopologyDescription, seedlist: List[Tuple[str, Any]] + topology_description: TopologyDescription, seedlist: list[tuple[str, Any]] ) -> TopologyDescription: """Return an updated copy of a TopologyDescription. @@ -535,7 +534,7 @@ def _update_rs_from_primary( server_description: ServerDescription, max_set_version: Optional[int], max_election_id: Optional[ObjectId], -) -> Tuple[int, Optional[str], Optional[int], Optional[ObjectId]]: +) -> tuple[int, Optional[str], Optional[int], Optional[ObjectId]]: """Update topology description from a primary's hello response. Pass in a dict of ServerDescriptions, current replica set name, the @@ -555,8 +554,8 @@ def _update_rs_from_primary( return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id if server_description.max_wire_version is None or server_description.max_wire_version < 17: - new_election_tuple: Tuple = (server_description.set_version, server_description.election_id) - max_election_tuple: Tuple = (max_set_version, max_election_id) + new_election_tuple: tuple = (server_description.set_version, server_description.election_id) + max_election_tuple: tuple = (max_set_version, max_election_id) if None not in new_election_tuple: if None not in max_election_tuple and new_election_tuple < max_election_tuple: # Stale primary, set to type Unknown. @@ -635,7 +634,7 @@ def _update_rs_no_primary_from_member( sds: MutableMapping[_Address, ServerDescription], replica_set_name: Optional[str], server_description: ServerDescription, -) -> Tuple[int, Optional[str]]: +) -> tuple[int, Optional[str]]: """RS without known primary. Update from a non-primary's response. Pass in a dict of ServerDescriptions, current replica set name, and the diff --git a/pymongo/typings.py b/pymongo/typings.py index 3464c9294..174a0e361 100644 --- a/pymongo/typings.py +++ b/pymongo/typings.py @@ -13,6 +13,8 @@ # limitations under the License. """Type aliases used by PyMongo""" +from __future__ import annotations + from typing import ( TYPE_CHECKING, Any, diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index c0dcc4f8b..ab86e11bf 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -22,13 +22,10 @@ import warnings from typing import ( TYPE_CHECKING, Any, - Dict, - List, Mapping, MutableMapping, Optional, Sized, - Tuple, Union, cast, ) @@ -73,7 +70,7 @@ def _unquoted_percent(s: str) -> bool: return False -def parse_userinfo(userinfo: str) -> Tuple[str, str]: +def parse_userinfo(userinfo: str) -> tuple[str, str]: """Validates the format of user information in a MongoDB URI. Reserved characters that are gen-delimiters (":", "/", "?", "#", "[", "]", "@") as per RFC 3986 must be escaped. @@ -100,7 +97,7 @@ def parse_userinfo(userinfo: str) -> Tuple[str, str]: def parse_ipv6_literal_host( entity: str, default_port: Optional[int] -) -> Tuple[str, Optional[Union[str, int]]]: +) -> tuple[str, Optional[Union[str, int]]]: """Validates an IPv6 literal host:port string. Returns a 2-tuple of IPv6 literal followed by port where @@ -370,7 +367,7 @@ def split_options( return options -def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> List[_Address]: +def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]: """Takes a string of the form host1[:port],host2[:port]... and splits it into (host, port) tuples. If [:port] isn't present the default_port is used. @@ -427,7 +424,7 @@ def parse_uri( connect_timeout: Optional[float] = None, srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Parse and validate a MongoDB URI. Returns a dict of the form:: @@ -594,7 +591,7 @@ def parse_uri( } -def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> Dict[str, SSLContext]: +def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]: """Parse KMS TLS connection options.""" if not kms_tls_options: return {} diff --git a/pymongo/write_concern.py b/pymongo/write_concern.py index 6487197b0..ab6629fbb 100644 --- a/pymongo/write_concern.py +++ b/pymongo/write_concern.py @@ -13,8 +13,9 @@ # limitations under the License. """Tools for working with write concerns.""" +from __future__ import annotations -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from pymongo.errors import ConfigurationError @@ -62,7 +63,7 @@ class WriteConcern: j: Optional[bool] = None, fsync: Optional[bool] = None, ) -> None: - self.__document: Dict[str, Any] = {} + self.__document: dict[str, Any] = {} self.__acknowledged = True if wtimeout is not None: @@ -102,7 +103,7 @@ class WriteConcern: return self.__server_default @property - def document(self) -> Dict[str, Any]: + def document(self) -> dict[str, Any]: """The document representation of this write concern. .. note:: diff --git a/tools/ensure_future_annotations_import.py b/tools/ensure_future_annotations_import.py new file mode 100644 index 000000000..03e1bd036 --- /dev/null +++ b/tools/ensure_future_annotations_import.py @@ -0,0 +1,41 @@ +# Copyright 2023-Present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ensure that 'from __future__ import annotations' is used in all package files +""" + +import glob +import os +import sys + +pattern = "from __future__ import annotations" +missing = [] +for dirname in ["pymongo", "bson", "gridfs"]: + for path in glob.glob(f"{dirname}/*.py"): + if os.path.basename(path) in ["_version.py", "errors.py"]: + continue + found = False + with open(path) as fid: + for line in fid.readlines(): + if line.strip() == pattern: + found = True + break + if not found: + missing.append(path) + +if missing: + print(f"Missing '{pattern}' import in:") + for item in missing: + print(item) + sys.exit(1) diff --git a/tox.ini b/tox.ini index 66e33a983..221738015 100644 --- a/tox.ini +++ b/tox.ini @@ -117,6 +117,7 @@ deps = {[testenv:typecheck-pyright]deps} allowlist_externals=echo commands = + python tools/ensure_future_annotations_import.py {[testenv:typecheck-mypy]commands} {[testenv:typecheck-pyright]commands} {[testenv:typecheck-pyright-strict]commands}