PYTHON-3777 add types to command_cursor.py (#1285)

This commit is contained in:
Iris 2023-07-03 15:16:33 -07:00 committed by GitHub
parent 94fabf5e98
commit fd095955f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 14 deletions

View File

@ -16,18 +16,30 @@
from __future__ import annotations
from collections import deque
from typing import TYPE_CHECKING, Any, Generic, Iterator, Mapping, NoReturn, Optional
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterator,
List,
Mapping,
NoReturn,
Optional,
Union,
)
from bson import _convert_raw_document_lists_to_streams
from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _SocketManager
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.message import _CursorAddress, _GetMore, _RawBatchGetMore
from pymongo.message import _CursorAddress, _GetMore, _OpMsg, _OpReply, _RawBatchGetMore
from pymongo.response import PinnedResponse
from pymongo.typings import _Address, _DocumentType
if TYPE_CHECKING:
from bson.codec_options import CodecOptions
from pymongo.client_session import ClientSession
from pymongo.collection import Collection
from pymongo.pool import SocketInfo
class CommandCursor(Generic[_DocumentType]):
@ -51,7 +63,9 @@ class CommandCursor(Generic[_DocumentType]):
self.__collection: Collection[_DocumentType] = collection
self.__id = cursor_info["id"]
self.__data = deque(cursor_info["firstBatch"])
self.__postbatchresumetoken = cursor_info.get("postBatchResumeToken")
self.__postbatchresumetoken: Optional[Mapping[str, Any]] = cursor_info.get(
"postBatchResumeToken"
)
self.__address = address
self.__batch_size = batch_size
self.__max_await_time_ms = max_await_time_ms
@ -75,7 +89,7 @@ class CommandCursor(Generic[_DocumentType]):
def __del__(self) -> None:
self.__die()
def __die(self, synchronous=False):
def __die(self, synchronous: bool = False) -> None:
"""Closes this cursor."""
already_killed = self.__killed
self.__killed = True
@ -98,7 +112,7 @@ class CommandCursor(Generic[_DocumentType]):
self.__session = None
self.__sock_mgr = None
def __end_session(self, synchronous):
def __end_session(self, synchronous: bool) -> None:
if self.__session and not self.__explicit_session:
self.__session._end_session(lock=synchronous)
self.__session = None
@ -131,20 +145,20 @@ class CommandCursor(Generic[_DocumentType]):
self.__batch_size = batch_size == 1 and 2 or batch_size
return self
def _has_next(self):
def _has_next(self) -> bool:
"""Returns `True` if the cursor has documents remaining from the
previous batch.
"""
return len(self.__data) > 0
@property
def _post_batch_resume_token(self):
def _post_batch_resume_token(self) -> Optional[Mapping[str, Any]]:
"""Retrieve the postBatchResumeToken from the response to a
changeStream aggregate or getMore.
"""
return self.__postbatchresumetoken
def _maybe_pin_connection(self, sock_info):
def _maybe_pin_connection(self, sock_info: SocketInfo) -> None:
client = self.__collection.database.client
if not client._should_pin_cursor(self.__session):
return
@ -158,7 +172,7 @@ class CommandCursor(Generic[_DocumentType]):
else:
self.__sock_mgr = sock_mgr
def __send_message(self, operation):
def __send_message(self, operation: _GetMore) -> None:
"""Send a getmore message and handle the response."""
client = self.__collection.database.client
try:
@ -199,11 +213,16 @@ class CommandCursor(Generic[_DocumentType]):
self.__data = deque(documents)
def _unpack_response(
self, response, cursor_id, codec_options, user_fields=None, legacy_response=False
):
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions[Mapping[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[Mapping[str, Any]]:
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
def _refresh(self):
def _refresh(self) -> int:
"""Refreshes the cursor with more data from the server.
Returns the length of self.__data after refresh. Will exit early if
@ -362,8 +381,13 @@ class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]):
)
def _unpack_response(
self, response, cursor_id, codec_options, user_fields=None, legacy_response=False
):
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[Mapping[str, Any]]:
raw_response = response.raw_response(cursor_id, user_fields=user_fields)
if not legacy_response:
# OP_MSG returns firstBatch/nextBatch documents as a BSON array

View File

@ -1329,6 +1329,9 @@ class _OpReply:
valid at server response
- `codec_options` (optional): an instance of
:class:`~bson.codec_options.CodecOptions`
- `user_fields` (optional): Response fields that should be decoded
using the TypeDecoders from codec_options, passed to
bson._decode_all_selective.
"""
self.raw_response(cursor_id)
if legacy_response:
@ -1401,6 +1404,9 @@ class _OpMsg:
- `cursor_id` (optional): Ignored, for compatibility with _OpReply.
- `codec_options` (optional): an instance of
:class:`~bson.codec_options.CodecOptions`
- `user_fields` (optional): Response fields that should be decoded
using the TypeDecoders from codec_options, passed to
bson._decode_all_selective.
"""
# If _OpMsg is in-use, this cannot be a legacy response.
assert not legacy_response