From 02a365276c4f0a08adada01cb10c61fb709addcf Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Wed, 2 Aug 2023 20:11:25 -0700 Subject: [PATCH] PYTHON-3806 add types to message.py (#1312) --- pymongo/bulk.py | 2 +- pymongo/client_session.py | 2 +- pymongo/command_cursor.py | 9 +- pymongo/cursor.py | 9 +- pymongo/encryption.py | 4 +- pymongo/errors.py | 21 +- pymongo/helpers.py | 3 +- pymongo/message.py | 502 +++++++++++++----- pymongo/network.py | 11 +- pymongo/response.py | 10 +- pymongo/server.py | 7 +- ...nnections_survive_primary_stepdown_spec.py | 2 +- test/test_grid_file.py | 1 + test/utils.py | 3 +- 14 files changed, 405 insertions(+), 181 deletions(-) diff --git a/pymongo/bulk.py b/pymongo/bulk.py index a0ada2fb3..3401398d7 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -153,7 +153,7 @@ def _merge_command( full_result["writeConcernErrors"].append(wce) -def _raise_bulk_write_error(full_result: Mapping[str, Any]) -> NoReturn: +def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn: """Raise a BulkWriteError from the full bulk api result.""" if full_result["writeErrors"]: full_result["writeErrors"].sort(key=lambda error: error["index"]) diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 6b0d24917..4a796154d 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -986,7 +986,7 @@ class ClientSession: self, command: MutableMapping[str, Any], is_retryable: bool, - read_preference: ReadPreference, + read_preference: _ServerMode, conn: Connection, ) -> None: self._check_ended() diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index 5c346f5df..ddd81acec 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -33,7 +33,7 @@ from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _ConnectionManager from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure from pymongo.message import _CursorAddress, _GetMore, _OpMsg, _OpReply, _RawBatchGetMore from pymongo.response import PinnedResponse -from pymongo.typings import _Address, _DocumentType +from pymongo.typings import _Address, _DocumentOut, _DocumentType if TYPE_CHECKING: from pymongo.client_session import ClientSession @@ -94,6 +94,7 @@ class CommandCursor(Generic[_DocumentType]): self.__killed = True if self.__id and not already_killed: cursor_id = self.__id + assert self.__address is not None address = _CursorAddress(self.__address, self.__ns) else: # Skip killCursors. @@ -219,7 +220,7 @@ class CommandCursor(Generic[_DocumentType]): codec_options: CodecOptions[Mapping[str, Any]], user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, - ) -> List[Mapping[str, Any]]: + ) -> List[_DocumentOut]: return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response) def _refresh(self) -> int: @@ -380,7 +381,7 @@ class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]): comment, ) - def _unpack_response( + def _unpack_response( # type: ignore[override] self, response: Union[_OpReply, _OpMsg], cursor_id: Optional[int], @@ -393,7 +394,7 @@ class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]): # OP_MSG returns firstBatch/nextBatch documents as a BSON array # Re-assemble the array of documents into a document stream _convert_raw_document_lists_to_streams(raw_response[0]) - return raw_response + return raw_response # type: ignore[return-value] def __getitem__(self, index: int) -> NoReturn: raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor") diff --git a/pymongo/cursor.py b/pymongo/cursor.py index 91c66ee7e..4dab6858f 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -57,7 +57,7 @@ from pymongo.message import ( _RawBatchQuery, ) from pymongo.response import PinnedResponse -from pymongo.typings import _Address, _CollationIn, _DocumentType +from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType if TYPE_CHECKING: from _typeshed import SupportsItems @@ -420,6 +420,7 @@ class Cursor(Generic[_DocumentType]): self.__killed = True if self.__id and not already_killed: cursor_id = self.__id + assert self.__address is not None address = _CursorAddress(self.__address, f"{self.__dbname}.{self.__collname}") else: # Skip killCursors. @@ -1129,7 +1130,7 @@ class Cursor(Generic[_DocumentType]): codec_options: CodecOptions, user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, - ) -> List[Mapping[str, Any]]: + ) -> List[_DocumentOut]: return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response) def _read_preference(self) -> _ServerMode: @@ -1355,13 +1356,13 @@ class RawBatchCursor(Cursor, Generic[_DocumentType]): codec_options: CodecOptions[Mapping[str, Any]], user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, - ) -> List[Mapping[str, Any]]: + ) -> List[_DocumentOut]: raw_response = response.raw_response(cursor_id, user_fields=user_fields) if not legacy_response: # OP_MSG returns firstBatch/nextBatch documents as a BSON array # Re-assemble the array of documents into a document stream _convert_raw_document_lists_to_streams(raw_response[0]) - return raw_response + return cast(List["_DocumentOut"], raw_response) def explain(self) -> _DocumentType: """Returns an explain plan record for this cursor. diff --git a/pymongo/encryption.py b/pymongo/encryption.py index 325355e36..c3e2e3407 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -80,8 +80,6 @@ from pymongo.write_concern import WriteConcern if TYPE_CHECKING: from pymongocrypt.mongocrypt import MongoCryptKmsContext - from pymongo.response import Response - _HTTPS_PORT = 443 _KMS_CONNECT_TIMEOUT = CONNECT_TIMEOUT # CDRIVER-3262 redefined this value to CONNECT_TIMEOUT _MONGOCRYPTD_TIMEOUT_MS = 10000 @@ -409,7 +407,7 @@ class _Encrypter: encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS) return encrypt_cmd - def decrypt(self, response: Response) -> Optional[bytes]: + def decrypt(self, response: bytes) -> Optional[bytes]: """Decrypt a MongoDB command response. :Parameters: diff --git a/pymongo/errors.py b/pymongo/errors.py index 890bf723b..c2cc6bbb6 100644 --- a/pymongo/errors.py +++ b/pymongo/errors.py @@ -13,10 +13,25 @@ # limitations under the License. """Exceptions raised by PyMongo.""" -from typing import Any, Iterable, List, Mapping, Optional, Sequence, Tuple, Union +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) from bson.errors import InvalidDocument +if TYPE_CHECKING: + from pymongo.typings import _DocumentOut + try: # CPython 3.7+ from ssl import SSLCertVerificationError as _CertificateError @@ -286,9 +301,9 @@ class BulkWriteError(OperationFailure): .. versionadded:: 2.7 """ - details: Mapping[str, Any] + details: _DocumentOut - def __init__(self, results: Mapping[str, Any]) -> None: + def __init__(self, results: _DocumentOut) -> None: super().__init__("batch op errors occurred", 65, results) def __reduce__(self) -> Tuple[Any, Any]: diff --git a/pymongo/helpers.py b/pymongo/helpers.py index c0c4f51ed..0fcd84b44 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -53,6 +53,7 @@ from pymongo.hello import HelloCompat if TYPE_CHECKING: from pymongo.cursor import _Hint from pymongo.operations import _IndexList + from pymongo.typings import _DocumentOut # From the SDAM spec, the "node is shutting down" codes. _SHUTDOWN_CODES: frozenset = frozenset( @@ -156,7 +157,7 @@ def _index_document(index_list: _IndexList) -> SON[str, Any]: def _check_command_response( - response: Mapping[str, Any], + response: _DocumentOut, max_wire_version: Optional[int], allowable_errors: Optional[Container[Union[int, str]]] = None, parse_write_concern_error: bool = False, diff --git a/pymongo/message.py b/pymongo/message.py index c4cf4d569..7565384e3 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -19,12 +19,27 @@ MongoDB. .. note:: This module is for internal use and is generally not needed by application developers. """ +from __future__ import annotations import datetime import random import struct from io import BytesIO as _BytesIO -from typing import Any, Mapping, NoReturn +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + MutableMapping, + NoReturn, + Optional, + Tuple, + Union, + cast, +) import bson from bson import CodecOptions, _decode_selective, _dict_to_bson, _make_c_string, encode @@ -58,6 +73,18 @@ from pymongo.helpers import _handle_reauth from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern +if TYPE_CHECKING: + from datetime import timedelta + + from pymongo.client_session import ClientSession + from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext + from pymongo.mongo_client import MongoClient + from pymongo.monitoring import _EventListeners + from pymongo.pool import Connection + from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode + from pymongo.typings import _Address, _DocumentOut + MAX_INT32 = 2147483647 MIN_INT32 = -2147483648 @@ -87,12 +114,14 @@ _UNICODE_REPLACE_CODEC_OPTIONS: "CodecOptions[Mapping[str, Any]]" = CodecOptions ) -def _randint(): +def _randint() -> int: """Generate a pseudo random 32 bit integer.""" return random.randint(MIN_INT32, MAX_INT32) -def _maybe_add_read_preference(spec, read_preference): +def _maybe_add_read_preference( + spec: MutableMapping[str, Any], read_preference: _ServerMode +) -> MutableMapping[str, Any]: """Add $readPreference to spec when appropriate.""" mode = read_preference.mode document = read_preference.document @@ -108,12 +137,14 @@ def _maybe_add_read_preference(spec, read_preference): return spec -def _convert_exception(exception): +def _convert_exception(exception: Exception) -> Dict[str, Any]: """Convert an Exception into a failure document for publishing.""" return {"errmsg": str(exception), "errtype": exception.__class__.__name__} -def _convert_write_result(operation, command, result): +def _convert_write_result( + operation: str, command: Mapping[str, Any], result: Mapping[str, Any] +) -> Dict[str, Any]: """Convert a legacy write result to write command format.""" # Based on _merge_legacy from bulk.py affected = result.get("n", 0) @@ -177,20 +208,20 @@ _MODIFIERS = SON( def _gen_find_command( - coll, - spec, - projection, - skip, - limit, - batch_size, - options, - read_concern, - collation=None, - session=None, - allow_disk_use=None, -): + coll: str, + spec: Mapping[str, Any], + projection: Optional[Union[Mapping[str, Any], Iterable[str]]], + skip: int, + limit: int, + batch_size: Optional[int], + options: Optional[int], + read_concern: ReadConcern, + collation: Optional[Mapping[str, Any]] = None, + session: Optional[ClientSession] = None, + allow_disk_use: Optional[bool] = None, +) -> SON[str, Any]: """Generate a find command document.""" - cmd = SON([("find", coll)]) + cmd: SON[str, Any] = SON([("find", coll)]) if "$query" in spec: cmd.update( [ @@ -227,9 +258,16 @@ def _gen_find_command( return cmd -def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, conn): +def _gen_get_more_command( + cursor_id: Optional[int], + coll: str, + batch_size: Optional[int], + max_await_time_ms: Optional[int], + comment: Optional[Any], + conn: Connection, +) -> SON[str, Any]: """Generate a getMore command document.""" - cmd = SON([("getMore", cursor_id), ("collection", coll)]) + cmd: SON[str, Any] = SON([("getMore", cursor_id), ("collection", coll)]) if batch_size: cmd["batchSize"] = batch_size if max_await_time_ms is not None: @@ -269,22 +307,22 @@ class _Query: def __init__( self, - flags, - db, - coll, - ntoskip, - spec, - fields, - codec_options, - read_preference, - limit, - batch_size, - read_concern, - collation, - session, - client, - allow_disk_use, - exhaust, + flags: int, + db: str, + coll: str, + ntoskip: int, + spec: Mapping[str, Any], + fields: Optional[Mapping[str, Any]], + codec_options: CodecOptions, + read_preference: _ServerMode, + limit: int, + batch_size: int, + read_concern: ReadConcern, + collation: Optional[Mapping[str, Any]], + session: Optional[ClientSession], + client: MongoClient, + allow_disk_use: Optional[bool], + exhaust: bool, ): self.flags = flags self.db = db @@ -302,16 +340,16 @@ class _Query: self.client = client self.allow_disk_use = allow_disk_use self.name = "find" - self._as_command = None + self._as_command: Optional[Tuple[SON[str, Any], str]] = None self.exhaust = exhaust - def reset(self): + def reset(self) -> None: self._as_command = None - def namespace(self): + def namespace(self) -> str: return f"{self.db}.{self.coll}" - def use_command(self, conn): + def use_command(self, conn: Connection) -> bool: use_find_cmd = False if not self.exhaust: use_find_cmd = True @@ -327,7 +365,9 @@ class _Query: conn.validate_session(self.client, self.session) return use_find_cmd - def as_command(self, conn, apply_timeout=False): + def as_command( + self, conn: Connection, apply_timeout: bool = False + ) -> Tuple[SON[str, Any], str]: """Return a find command document for this query.""" # We use the command twice: on the wire and for command monitoring. # Generate it once, for speed and to avoid repeating side-effects. @@ -335,7 +375,7 @@ class _Query: return self._as_command explain = "$explain" in self.spec - cmd = _gen_find_command( + cmd: SON[str, Any] = _gen_find_command( self.coll, self.spec, self.fields, @@ -362,14 +402,16 @@ class _Query: # Support auto encryption client = self.client if client._encrypter and not client._encrypter._bypass_auto_encryption: - cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) + cmd = cast(SON[str, Any], client._encrypter.encrypt(self.db, cmd, self.codec_options)) # Support CSOT if apply_timeout: conn.apply_timeout(client, cmd) self._as_command = cmd, self.db return self._as_command - def get_message(self, read_preference, conn, use_cmd=False): + def get_message( + self, read_preference: _ServerMode, conn: Connection, use_cmd: bool = False + ) -> Tuple[int, bytes, int]: """Get a query message, possibly setting the secondaryOk bit.""" # Use the read_preference decided by _socket_from_server. self.read_preference = read_preference @@ -405,6 +447,7 @@ class _Query: ntoreturn = self.limit if conn.is_mongos: + assert isinstance(spec, MutableMapping) spec = _maybe_add_read_preference(spec, read_preference) return _query( @@ -442,18 +485,18 @@ class _GetMore: def __init__( self, - db, - coll, - ntoreturn, - cursor_id, - codec_options, - read_preference, - session, - client, - max_await_time_ms, - conn_mgr, - exhaust, - comment, + db: str, + coll: str, + ntoreturn: int, + cursor_id: int, + codec_options: CodecOptions, + read_preference: _ServerMode, + session: Optional[ClientSession], + client: MongoClient, + max_await_time_ms: Optional[int], + conn_mgr: Any, + exhaust: bool, + comment: Any, ): self.db = db self.coll = coll @@ -465,17 +508,17 @@ class _GetMore: self.client = client self.max_await_time_ms = max_await_time_ms self.conn_mgr = conn_mgr - self._as_command = None + self._as_command: Optional[Tuple[SON[str, Any], str]] = None self.exhaust = exhaust self.comment = comment - def reset(self): + def reset(self) -> None: self._as_command = None - def namespace(self): + def namespace(self) -> str: return f"{self.db}.{self.coll}" - def use_command(self, conn): + def use_command(self, conn: Connection) -> bool: use_cmd = False if not self.exhaust: use_cmd = True @@ -486,13 +529,15 @@ class _GetMore: conn.validate_session(self.client, self.session) return use_cmd - def as_command(self, conn, apply_timeout=False): + def as_command( + self, conn: Connection, apply_timeout: bool = False + ) -> Tuple[SON[str, Any], str]: """Return a getMore command document for this query.""" # See _Query.as_command for an explanation of this caching. if self._as_command is not None: return self._as_command - cmd = _gen_get_more_command( + cmd: SON[str, Any] = _gen_get_more_command( self.cursor_id, self.coll, self.ntoreturn, @@ -507,14 +552,16 @@ class _GetMore: # Support auto encryption client = self.client if client._encrypter and not client._encrypter._bypass_auto_encryption: - cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) + cmd = cast(SON[str, Any], client._encrypter.encrypt(self.db, cmd, self.codec_options)) # Support CSOT if apply_timeout: conn.apply_timeout(client, cmd=None) self._as_command = cmd, self.db return self._as_command - def get_message(self, dummy0, conn, use_cmd=False): + def get_message( + self, dummy0: Any, conn: Connection, use_cmd: bool = False + ) -> Union[Tuple[int, bytes, int], Tuple[int, bytes]]: """Get a getmore message.""" ns = self.namespace() ctx = conn.compression_context @@ -534,7 +581,7 @@ class _GetMore: class _RawBatchQuery(_Query): - def use_command(self, conn): + def use_command(self, conn: Connection) -> bool: # Compatibility checks. super().use_command(conn) if conn.max_wire_version >= 8: @@ -546,7 +593,7 @@ class _RawBatchQuery(_Query): class _RawBatchGetMore(_GetMore): - def use_command(self, conn): + def use_command(self, conn: Connection) -> bool: # Compatibility checks. super().use_command(conn) if conn.max_wire_version >= 8: @@ -562,27 +609,27 @@ class _CursorAddress(tuple): __namespace: Any - def __new__(cls, address, namespace): + def __new__(cls, address: _Address, namespace: str) -> _CursorAddress: self = tuple.__new__(cls, address) self.__namespace = namespace return self @property - def namespace(self): + def namespace(self) -> str: """The namespace this cursor.""" return self.__namespace - def __hash__(self): + def __hash__(self) -> int: # Two _CursorAddress instances with different namespaces # must not hash the same. return ((*self, self.__namespace)).__hash__() - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, _CursorAddress): return tuple(self) == tuple(other) and self.namespace == other.namespace return NotImplemented - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self == other @@ -590,7 +637,9 @@ _pack_compression_header = struct.Struct(" Tuple[int, bytes]: """Takes message data, compresses it, and adds an OP_COMPRESSED header.""" compressed = ctx.compress(data) request_id = _randint() @@ -610,7 +659,7 @@ def _compress(operation, data, ctx): _pack_header = struct.Struct(" Tuple[int, bytes]: """Takes message data and adds a message header based on the operation. Returns the resultant message string. @@ -625,7 +674,13 @@ _pack_op_msg_flags_type = struct.Struct(" Tuple[bytes, int, int]: """Get a OP_MSG message. Note: this method handles multiple documents in a type one payload but @@ -637,7 +692,7 @@ def _op_msg_no_header(flags, command, identifier, docs, opts): flags_type = _pack_op_msg_flags_type(flags, 0) total_size = len(encoded) max_doc_size = 0 - if identifier: + if identifier and docs is not None: type_one = _pack_byte(1) cstring = _make_c_string(identifier) encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs] @@ -651,14 +706,27 @@ def _op_msg_no_header(flags, command, identifier, docs, opts): return b"".join(data), total_size, max_doc_size -def _op_msg_compressed(flags, command, identifier, docs, opts, ctx): +def _op_msg_compressed( + flags: int, + command: Mapping[str, Any], + identifier: str, + docs: Optional[List[Mapping[str, Any]]], + opts: CodecOptions, + ctx: Union[SnappyContext, ZlibContext, ZstdContext], +) -> Tuple[int, bytes, int, int]: """Internal OP_MSG message helper.""" msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) rid, msg = _compress(2013, msg, ctx) return rid, msg, total_size, max_bson_size -def _op_msg_uncompressed(flags, command, identifier, docs, opts): +def _op_msg_uncompressed( + flags: int, + command: Mapping[str, Any], + identifier: str, + docs: Optional[List[Mapping[str, Any]]], + opts: CodecOptions, +) -> Tuple[int, bytes, int, int]: """Internal compressed OP_MSG message helper.""" data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) request_id, op_message = __pack_message(2013, data) @@ -669,7 +737,14 @@ if _use_c: _op_msg_uncompressed = _cmessage._op_msg # noqa: F811 -def _op_msg(flags, command, dbname, read_preference, opts, ctx=None): +def _op_msg( + flags: int, + command: MutableMapping[str, Any], + dbname: str, + read_preference: Optional[_ServerMode], + opts: CodecOptions, + ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, +) -> Tuple[int, bytes, int, int]: """Get a OP_MSG message.""" command["$db"] = dbname # getMore commands do not send $readPreference. @@ -679,7 +754,7 @@ def _op_msg(flags, command, dbname, read_preference, opts, ctx=None): command["$readPreference"] = read_preference.document name = next(iter(command)) try: - identifier = _FIELD_MAP.get(name) + identifier = _FIELD_MAP[name] docs = command.pop(identifier) except KeyError: identifier = "" @@ -694,7 +769,15 @@ def _op_msg(flags, command, dbname, read_preference, opts, ctx=None): command[identifier] = docs -def _query_impl(options, collection_name, num_to_skip, num_to_return, query, field_selector, opts): +def _query_impl( + options: int, + collection_name: str, + num_to_skip: int, + num_to_return: int, + query: Mapping[str, Any], + field_selector: Optional[Mapping[str, Any]], + opts: CodecOptions, +) -> Tuple[bytes, int]: """Get an OP_QUERY message.""" encoded = _dict_to_bson(query, False, opts) if field_selector: @@ -718,8 +801,15 @@ def _query_impl(options, collection_name, num_to_skip, num_to_return, query, fie def _query_compressed( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx=None -): + options: int, + collection_name: str, + num_to_skip: int, + num_to_return: int, + query: Mapping[str, Any], + field_selector: Optional[Mapping[str, Any]], + opts: CodecOptions, + ctx: Union[SnappyContext, ZlibContext, ZstdContext], +) -> Tuple[int, bytes, int]: """Internal compressed query message helper.""" op_query, max_bson_size = _query_impl( options, collection_name, num_to_skip, num_to_return, query, field_selector, opts @@ -729,8 +819,14 @@ def _query_compressed( def _query_uncompressed( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts -): + options: int, + collection_name: str, + num_to_skip: int, + num_to_return: int, + query: Mapping[str, Any], + field_selector: Optional[Mapping[str, Any]], + opts: CodecOptions, +) -> Tuple[int, bytes, int]: """Internal query message helper.""" op_query, max_bson_size = _query_impl( options, collection_name, num_to_skip, num_to_return, query, field_selector, opts @@ -744,8 +840,15 @@ if _use_c: def _query( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx=None -): + options: int, + collection_name: str, + num_to_skip: int, + num_to_return: int, + query: Mapping[str, Any], + field_selector: Optional[Mapping[str, Any]], + opts: CodecOptions, + ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, +) -> Tuple[int, bytes, int]: """Get a **query** message.""" if ctx: return _query_compressed( @@ -759,7 +862,7 @@ def _query( _pack_long_long = struct.Struct(" bytes: """Get an OP_GET_MORE message.""" return b"".join( [ @@ -771,12 +874,19 @@ def _get_more_impl(collection_name, num_to_return, cursor_id): ) -def _get_more_compressed(collection_name, num_to_return, cursor_id, ctx): +def _get_more_compressed( + collection_name: str, + num_to_return: int, + cursor_id: int, + ctx: Union[SnappyContext, ZlibContext, ZstdContext], +) -> Tuple[int, bytes]: """Internal compressed getMore message helper.""" return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx) -def _get_more_uncompressed(collection_name, num_to_return, cursor_id): +def _get_more_uncompressed( + collection_name: str, num_to_return: int, cursor_id: int +) -> Tuple[int, bytes]: """Internal getMore message helper.""" return __pack_message(2005, _get_more_impl(collection_name, num_to_return, cursor_id)) @@ -785,7 +895,12 @@ if _use_c: _get_more_uncompressed = _cmessage._get_more_message # noqa: F811 -def _get_more(collection_name, num_to_return, cursor_id, ctx=None): +def _get_more( + collection_name: str, + num_to_return: int, + cursor_id: int, + ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, +) -> Tuple[int, bytes]: """Get a **getMore** message.""" if ctx: return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx) @@ -811,7 +926,15 @@ class _BulkWriteContext: ) def __init__( - self, database_name, cmd_name, conn, operation_id, listeners, session, op_type, codec + self, + database_name: str, + cmd_name: str, + conn: Connection, + operation_id: int, + listeners: _EventListeners, + session: ClientSession, + op_type: int, + codec: CodecOptions, ): self.db_name = database_name self.conn = conn @@ -826,7 +949,9 @@ class _BulkWriteContext: self.op_type = op_type self.codec = codec - def _batch_command(self, cmd, docs): + def __batch_command( + self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]] + ) -> Tuple[int, bytes, List[Mapping[str, Any]]]: namespace = self.db_name + ".$cmd" request_id, msg, to_send = _do_batched_op_msg( namespace, self.op_type, cmd, docs, self.codec, self @@ -835,14 +960,18 @@ class _BulkWriteContext: raise InvalidOperation("cannot do an empty bulk write") return request_id, msg, to_send - def execute(self, cmd, docs, client): - request_id, msg, to_send = self._batch_command(cmd, docs) + def execute( + self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient + ) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]: + request_id, msg, to_send = self.__batch_command(cmd, docs) result = self.write_command(cmd, request_id, msg, to_send) client._process_response(result, self.session) return result, to_send - def execute_unack(self, cmd, docs, client): - request_id, msg, to_send = self._batch_command(cmd, docs) + def execute_unack( + self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient + ) -> List[Mapping[str, Any]]: + request_id, msg, to_send = self.__batch_command(cmd, docs) # Though this isn't strictly a "legacy" write, the helper # handles publishing commands and sending our message # without receiving a result. Send 0 for max_doc_size @@ -852,12 +981,12 @@ class _BulkWriteContext: return to_send @property - def max_bson_size(self): + def max_bson_size(self) -> int: """A proxy for SockInfo.max_bson_size.""" return self.conn.max_bson_size @property - def max_message_size(self): + def max_message_size(self) -> int: """A proxy for SockInfo.max_message_size.""" if self.compress: # Subtract 16 bytes for the message header. @@ -865,16 +994,23 @@ class _BulkWriteContext: return self.conn.max_message_size @property - def max_write_batch_size(self): + def max_write_batch_size(self) -> int: """A proxy for SockInfo.max_write_batch_size.""" return self.conn.max_write_batch_size @property - def max_split_size(self): + def max_split_size(self) -> int: """The maximum size of a BSON command before batch splitting.""" return self.max_bson_size - def unack_write(self, cmd, request_id, msg, max_doc_size, docs): + def unack_write( + self, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + max_doc_size: int, + docs: List[Mapping[str, Any]], + ) -> Optional[Mapping[str, Any]]: """A proxy for Connection.unack_write that handles event publishing.""" if self.publish: assert self.start_time is not None @@ -896,9 +1032,9 @@ class _BulkWriteContext: assert self.start_time is not None duration = (datetime.datetime.now() - start) + duration if isinstance(exc, OperationFailure): - failure = _convert_write_result(self.name, cmd, exc.details) + failure: _DocumentOut = _convert_write_result(self.name, cmd, exc.details) # type: ignore[arg-type] elif isinstance(exc, NotPrimaryError): - failure = exc.details + failure = exc.details # type: ignore[assignment] else: failure = _convert_exception(exc) self._fail(request_id, failure, duration) @@ -908,8 +1044,14 @@ class _BulkWriteContext: return result @_handle_reauth - def write_command(self, cmd, request_id, msg, docs): - """A proxy for Connection.write_command that handles event publishing.""" + def write_command( + self, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + docs: List[Mapping[str, Any]], + ) -> Mapping[str, Any]: + """A proxy for SocketInfo.write_command that handles event publishing.""" if self.publish: assert self.start_time is not None duration = datetime.datetime.now() - self.start_time @@ -924,7 +1066,7 @@ class _BulkWriteContext: if self.publish: duration = (datetime.datetime.now() - start) + duration if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure = exc.details + failure: _DocumentOut = exc.details # type: ignore[assignment] else: failure = _convert_exception(exc) self._fail(request_id, failure, duration) @@ -933,7 +1075,9 @@ class _BulkWriteContext: self.start_time = datetime.datetime.now() return reply - def _start(self, cmd, request_id, docs): + def _start( + self, cmd: MutableMapping[str, Any], request_id: int, docs: List[Mapping[str, Any]] + ) -> MutableMapping[str, Any]: """Publish a CommandStartedEvent.""" cmd[self.field] = docs self.listeners.publish_command_start( @@ -946,7 +1090,7 @@ class _BulkWriteContext: ) return cmd - def _succeed(self, request_id, reply, duration): + def _succeed(self, request_id: int, reply: _DocumentOut, duration: timedelta) -> None: """Publish a CommandSucceededEvent.""" self.listeners.publish_command_success( duration, @@ -958,7 +1102,7 @@ class _BulkWriteContext: self.conn.service_id, ) - def _fail(self, request_id, failure, duration): + def _fail(self, request_id: int, failure: _DocumentOut, duration: timedelta) -> None: """Publish a CommandFailedEvent.""" self.listeners.publish_command_failure( duration, @@ -981,7 +1125,9 @@ _MAX_SPLIT_SIZE_ENC = 2097152 class _EncryptedBulkWriteContext(_BulkWriteContext): __slots__ = () - def _batch_command(self, cmd, docs): + def __batch_command( + self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]] + ) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]: namespace = self.db_name + ".$cmd" msg, to_send = _encode_batched_write_command( namespace, self.op_type, cmd, docs, self.codec, self @@ -994,15 +1140,19 @@ class _EncryptedBulkWriteContext(_BulkWriteContext): cmd = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS) return cmd, to_send - def execute(self, cmd, docs, client): - batched_cmd, to_send = self._batch_command(cmd, docs) - result = self.conn.command( + def execute( + self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient + ) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]: + batched_cmd, to_send = self.__batch_command(cmd, docs) + result: Mapping[str, Any] = self.conn.command( self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client ) return result, to_send - def execute_unack(self, cmd, docs, client): - batched_cmd, to_send = self._batch_command(cmd, docs) + def execute_unack( + self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient + ) -> List[Mapping[str, Any]]: + batched_cmd, to_send = self.__batch_command(cmd, docs) self.conn.command( self.db_name, batched_cmd, @@ -1013,7 +1163,7 @@ class _EncryptedBulkWriteContext(_BulkWriteContext): return to_send @property - def max_split_size(self): + def max_split_size(self) -> int: """Reduce the batch splitting size.""" return _MAX_SPLIT_SIZE_ENC @@ -1043,7 +1193,15 @@ _OP_MSG_MAP = { } -def _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf): +def _batched_op_msg_impl( + operation: int, + command: Mapping[str, Any], + docs: List[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, + buf: _BytesIO, +) -> Tuple[List[Mapping[str, Any]], int]: """Create a batched OP_MSG write.""" max_bson_size = ctx.max_bson_size max_write_batch_size = ctx.max_write_batch_size @@ -1103,7 +1261,14 @@ def _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf): return to_send, length -def _encode_batched_op_msg(operation, command, docs, ack, opts, ctx): +def _encode_batched_op_msg( + operation: int, + command: Mapping[str, Any], + docs: List[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> Tuple[bytes, List[Mapping[str, Any]]]: """Encode the next batched insert, update, or delete operation as OP_MSG. """ @@ -1117,17 +1282,32 @@ if _use_c: _encode_batched_op_msg = _cmessage._encode_batched_op_msg # noqa: F811 -def _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx): +def _batched_op_msg_compressed( + operation: int, + command: Mapping[str, Any], + docs: List[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> Tuple[int, bytes, List[Mapping[str, Any]]]: """Create the next batched insert, update, or delete operation with OP_MSG, compressed. """ data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx) + assert ctx.conn.compression_context is not None request_id, msg = _compress(2013, data, ctx.conn.compression_context) return request_id, msg, to_send -def _batched_op_msg(operation, command, docs, ack, opts, ctx): +def _batched_op_msg( + operation: int, + command: Mapping[str, Any], + docs: List[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> Tuple[int, bytes, List[Mapping[str, Any]]]: """OP_MSG implementation entry point.""" buf = _BytesIO() @@ -1152,7 +1332,14 @@ if _use_c: _batched_op_msg = _cmessage._batched_op_msg # noqa: F811 -def _do_batched_op_msg(namespace, operation, command, docs, opts, ctx): +def _do_batched_op_msg( + namespace: str, + operation: int, + command: MutableMapping[str, Any], + docs: List[Mapping[str, Any]], + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> Tuple[int, bytes, List[Mapping[str, Any]]]: """Create the next batched insert, update, or delete operation using OP_MSG. """ @@ -1169,7 +1356,14 @@ def _do_batched_op_msg(namespace, operation, command, docs, opts, ctx): # End OP_MSG ----------------------------------------------------- -def _encode_batched_write_command(namespace, operation, command, docs, opts, ctx): +def _encode_batched_write_command( + namespace: str, + operation: int, + command: MutableMapping[str, Any], + docs: List[Mapping[str, Any]], + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> Tuple[bytes, List[Mapping[str, Any]]]: """Encode the next batched insert, update, or delete command.""" buf = _BytesIO() @@ -1181,7 +1375,15 @@ if _use_c: _encode_batched_write_command = _cmessage._encode_batched_write_command # noqa: F811 -def _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, buf): +def _batched_write_command_impl( + namespace: str, + operation: int, + command: MutableMapping[str, Any], + docs: List[Mapping[str, Any]], + opts: CodecOptions, + ctx: _BulkWriteContext, + buf: _BytesIO, +) -> Tuple[List[Mapping[str, Any]], int]: """Create a batched OP_QUERY write command.""" max_bson_size = ctx.max_bson_size max_write_batch_size = ctx.max_write_batch_size @@ -1258,13 +1460,15 @@ class _OpReply: UNPACK_FROM = struct.Struct(" List[bytes]: """Check the response header from the database, without decoding BSON. Check the response for errors and unpack. @@ -1309,11 +1513,11 @@ class _OpReply: def unpack_response( self, - cursor_id=None, - codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, - user_fields=None, - legacy_response=False, - ): + cursor_id: Optional[int] = None, + codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, + user_fields: Optional[Mapping[str, Any]] = None, + legacy_response: bool = False, + ) -> List[_DocumentOut]: """Unpack a response from the database and decode the BSON document(s). Check the response for errors and unpack, returning a dictionary @@ -1337,24 +1541,24 @@ class _OpReply: return bson.decode_all(self.documents, codec_options) return bson._decode_all_selective(self.documents, codec_options, user_fields) - def command_response(self, codec_options): + def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]: """Unpack a command response.""" docs = self.unpack_response(codec_options=codec_options) assert self.number_returned == 1 return docs[0] - def raw_command_response(self): + def raw_command_response(self) -> NoReturn: """Return the bytes of the command response.""" # This should never be called on _OpReply. raise NotImplementedError @property - def more_to_come(self): + def more_to_come(self) -> bool: """Is the moreToCome bit set on this response?""" return False @classmethod - def unpack(cls, msg): + def unpack(cls, msg: bytes) -> _OpReply: """Construct an _OpReply from raw bytes.""" # PYTHON-945: ignore starting_from field. flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg) @@ -1376,11 +1580,15 @@ class _OpMsg: MORE_TO_COME = 1 << 1 EXHAUST_ALLOWED = 1 << 16 # Only present on requests. - def __init__(self, flags, payload_document): + def __init__(self, flags: int, payload_document: bytes): self.flags = flags self.payload_document = payload_document - def raw_response(self, cursor_id=None, user_fields={}): # noqa: B006 + def raw_response( + self, + cursor_id: Optional[int] = None, + user_fields: Optional[Mapping[str, Any]] = {}, # noqa: B006 + ) -> List[Mapping[str, Any]]: """ cursor_id is ignored user_fields is used to determine which fields must not be decoded @@ -1392,11 +1600,11 @@ class _OpMsg: def unpack_response( self, - cursor_id=None, - codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, - user_fields=None, - legacy_response=False, - ): + cursor_id: Optional[int] = None, + codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, + user_fields: Optional[Mapping[str, Any]] = None, + legacy_response: bool = False, + ) -> List[_DocumentOut]: """Unpack a OP_MSG command response. :Parameters: @@ -1411,21 +1619,21 @@ class _OpMsg: assert not legacy_response return bson._decode_all_selective(self.payload_document, codec_options, user_fields) - def command_response(self, codec_options): + def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]: """Unpack a command response.""" return self.unpack_response(codec_options=codec_options)[0] - def raw_command_response(self): + def raw_command_response(self) -> bytes: """Return the bytes of the command response.""" return self.payload_document @property - def more_to_come(self): + def more_to_come(self) -> bool: """Is the moreToCome bit set on this response?""" - return self.flags & self.MORE_TO_COME + return bool(self.flags & self.MORE_TO_COME) @classmethod - def unpack(cls, msg): + def unpack(cls, msg: bytes) -> _OpMsg: """Construct an _OpMsg from raw bytes.""" flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg) if flags != 0: @@ -1444,7 +1652,7 @@ class _OpMsg: return cls(flags, payload_document) -_UNPACK_REPLY = { +_UNPACK_REPLY: Dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = { _OpReply.OP_CODE: _OpReply.unpack, _OpMsg.OP_CODE: _OpMsg.unpack, } diff --git a/pymongo/network.py b/pymongo/network.py index 0d40955be..5aaceda52 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -23,7 +23,6 @@ import time from typing import ( TYPE_CHECKING, Any, - Dict, Mapping, MutableMapping, Optional, @@ -55,7 +54,7 @@ if TYPE_CHECKING: from pymongo.pool import Connection from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode - from pymongo.typings import _Address, _DocumentOut + from pymongo.typings import _Address, _DocumentOut, _DocumentType from pymongo.write_concern import WriteConcern _UNPACK_HEADER = struct.Struct(" Dict[str, Any]: +) -> _DocumentType: """Execute a command over the socket, or raise socket.error. :Parameters: @@ -177,7 +176,7 @@ def command( if use_op_msg and unacknowledged: # Unacknowledged, fake a successful command response. reply = None - response_doc = {"ok": 1} + response_doc: _DocumentOut = {"ok": 1} else: reply = receive_message(conn, request_id) conn.more_to_come = reply.more_to_come @@ -226,7 +225,7 @@ def command( decrypted = client._encrypter.decrypt(reply.raw_command_response()) response_doc = _decode_all_selective(decrypted, codec_options, user_fields)[0] - return response_doc + return response_doc # type: ignore[return-value] _UNPACK_COMPRESSION_HEADER = struct.Struct(" List[Mapping[str, Any]]: + def docs(self) -> Sequence[Mapping[str, Any]]: """The decoded document(s).""" return self._docs @@ -95,7 +95,7 @@ class PinnedResponse(Response): request_id: int, duration: Optional[timedelta], from_command: bool, - docs: List[Mapping[str, Any]], + docs: List[_DocumentOut], more_to_come: bool, ): """Represent a response to an exhaust cursor's initial query. diff --git a/pymongo/server.py b/pymongo/server.py index ddb8de5c0..a23a87911 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -22,7 +22,6 @@ from typing import ( Callable, ContextManager, List, - Mapping, Optional, Tuple, Union, @@ -111,7 +110,7 @@ class Server: operation: Union[_Query, _GetMore], read_preference: _ServerMode, listeners: _EventListeners, - unpack_res: Callable[..., List[Mapping[str, Any]]], + unpack_res: Callable[..., List[_DocumentOut]], ) -> Response: """Run a _Query or _GetMore operation and return a Response object. @@ -193,9 +192,9 @@ class Server: # Must publish in find / getMore / explain command response # format. if use_cmd: - res: _DocumentOut = docs[0] # type: ignore[assignment] + res: _DocumentOut = docs[0] elif operation.name == "explain": - res = docs[0] if docs else {} # type: ignore[assignment] + res = docs[0] if docs else {} else: res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} if operation.name == "find": diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 21ea36aaf..38bda4e2a 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -111,7 +111,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): # Insert record and verify failure. with self.assertRaises(NotPrimaryError) as exc: self.coll.insert_one({"test": 1}) - self.assertEqual(exc.exception.details["code"], error_code) # type: ignore + self.assertEqual(exc.exception.details["code"], error_code) # type: ignore[call-overload] # Retry before CMAPListener assertion if retry_before=True. if retry: self.coll.insert_one({"test": 1}) diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 04003289e..ec88501a5 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -754,6 +754,7 @@ Bye""" # Kill the cursor to simulate the cursor timing out on the server # when an application spends a long time between two calls to # readchunk(). + assert client.address is not None client._close_cursor_now( outfile._GridOut__chunk_iter._cursor.cursor_id, _CursorAddress(client.address, db.fs.chunks.full_name), diff --git a/test/utils.py b/test/utils.py index ec17e2862..776aba823 100644 --- a/test/utils.py +++ b/test/utils.py @@ -720,7 +720,8 @@ def server_started_with_auth(client): try: command_line = get_command_line(client) except OperationFailure as e: - msg = e.details.get("errmsg", "") # type: ignore + assert e.details is not None + msg = e.details.get("errmsg", "") if e.code == 13 or "unauthorized" in msg or "login" in msg: # Unauthorized. return True