From fd095955f55d0be1b69cf9f72846be658aae0e99 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Mon, 3 Jul 2023 15:16:33 -0700 Subject: [PATCH] PYTHON-3777 add types to command_cursor.py (#1285) --- pymongo/command_cursor.py | 52 ++++++++++++++++++++++++++++----------- pymongo/message.py | 6 +++++ 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index 7a2e52868..662bbb884 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -16,18 +16,30 @@ from __future__ import annotations from collections import deque -from typing import TYPE_CHECKING, Any, Generic, Iterator, Mapping, NoReturn, Optional +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Iterator, + List, + Mapping, + NoReturn, + Optional, + Union, +) from bson import _convert_raw_document_lists_to_streams from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _SocketManager from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure -from pymongo.message import _CursorAddress, _GetMore, _RawBatchGetMore +from pymongo.message import _CursorAddress, _GetMore, _OpMsg, _OpReply, _RawBatchGetMore from pymongo.response import PinnedResponse from pymongo.typings import _Address, _DocumentType if TYPE_CHECKING: + from bson.codec_options import CodecOptions from pymongo.client_session import ClientSession from pymongo.collection import Collection + from pymongo.pool import SocketInfo class CommandCursor(Generic[_DocumentType]): @@ -51,7 +63,9 @@ class CommandCursor(Generic[_DocumentType]): self.__collection: Collection[_DocumentType] = collection self.__id = cursor_info["id"] self.__data = deque(cursor_info["firstBatch"]) - self.__postbatchresumetoken = cursor_info.get("postBatchResumeToken") + self.__postbatchresumetoken: Optional[Mapping[str, Any]] = cursor_info.get( + "postBatchResumeToken" + ) self.__address = address self.__batch_size = batch_size self.__max_await_time_ms = max_await_time_ms @@ -75,7 +89,7 @@ class CommandCursor(Generic[_DocumentType]): def __del__(self) -> None: self.__die() - def __die(self, synchronous=False): + def __die(self, synchronous: bool = False) -> None: """Closes this cursor.""" already_killed = self.__killed self.__killed = True @@ -98,7 +112,7 @@ class CommandCursor(Generic[_DocumentType]): self.__session = None self.__sock_mgr = None - def __end_session(self, synchronous): + def __end_session(self, synchronous: bool) -> None: if self.__session and not self.__explicit_session: self.__session._end_session(lock=synchronous) self.__session = None @@ -131,20 +145,20 @@ class CommandCursor(Generic[_DocumentType]): self.__batch_size = batch_size == 1 and 2 or batch_size return self - def _has_next(self): + def _has_next(self) -> bool: """Returns `True` if the cursor has documents remaining from the previous batch. """ return len(self.__data) > 0 @property - def _post_batch_resume_token(self): + def _post_batch_resume_token(self) -> Optional[Mapping[str, Any]]: """Retrieve the postBatchResumeToken from the response to a changeStream aggregate or getMore. """ return self.__postbatchresumetoken - def _maybe_pin_connection(self, sock_info): + def _maybe_pin_connection(self, sock_info: SocketInfo) -> None: client = self.__collection.database.client if not client._should_pin_cursor(self.__session): return @@ -158,7 +172,7 @@ class CommandCursor(Generic[_DocumentType]): else: self.__sock_mgr = sock_mgr - def __send_message(self, operation): + def __send_message(self, operation: _GetMore) -> None: """Send a getmore message and handle the response.""" client = self.__collection.database.client try: @@ -199,11 +213,16 @@ class CommandCursor(Generic[_DocumentType]): self.__data = deque(documents) def _unpack_response( - self, response, cursor_id, codec_options, user_fields=None, legacy_response=False - ): + self, + response: Union[_OpReply, _OpMsg], + cursor_id: Optional[int], + codec_options: CodecOptions[Mapping[str, Any]], + user_fields: Optional[Mapping[str, Any]] = None, + legacy_response: bool = False, + ) -> List[Mapping[str, Any]]: return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response) - def _refresh(self): + def _refresh(self) -> int: """Refreshes the cursor with more data from the server. Returns the length of self.__data after refresh. Will exit early if @@ -362,8 +381,13 @@ class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]): ) def _unpack_response( - self, response, cursor_id, codec_options, user_fields=None, legacy_response=False - ): + self, + response: Union[_OpReply, _OpMsg], + cursor_id: Optional[int], + codec_options: CodecOptions, + user_fields: Optional[Mapping[str, Any]] = None, + legacy_response: bool = False, + ) -> List[Mapping[str, Any]]: raw_response = response.raw_response(cursor_id, user_fields=user_fields) if not legacy_response: # OP_MSG returns firstBatch/nextBatch documents as a BSON array diff --git a/pymongo/message.py b/pymongo/message.py index 34f6e6235..735f8a8cc 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -1329,6 +1329,9 @@ class _OpReply: valid at server response - `codec_options` (optional): an instance of :class:`~bson.codec_options.CodecOptions` + - `user_fields` (optional): Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. """ self.raw_response(cursor_id) if legacy_response: @@ -1401,6 +1404,9 @@ class _OpMsg: - `cursor_id` (optional): Ignored, for compatibility with _OpReply. - `codec_options` (optional): an instance of :class:`~bson.codec_options.CodecOptions` + - `user_fields` (optional): Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. """ # If _OpMsg is in-use, this cannot be a legacy response. assert not legacy_response