PYTHON-3806 add types to message.py (#1312)

This commit is contained in:
Iris 2023-08-02 20:11:25 -07:00 committed by GitHub
parent 43046e04c0
commit 02a365276c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 405 additions and 181 deletions

View File

@ -153,7 +153,7 @@ def _merge_command(
full_result["writeConcernErrors"].append(wce)
def _raise_bulk_write_error(full_result: Mapping[str, Any]) -> NoReturn:
def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn:
"""Raise a BulkWriteError from the full bulk api result."""
if full_result["writeErrors"]:
full_result["writeErrors"].sort(key=lambda error: error["index"])

View File

@ -986,7 +986,7 @@ class ClientSession:
self,
command: MutableMapping[str, Any],
is_retryable: bool,
read_preference: ReadPreference,
read_preference: _ServerMode,
conn: Connection,
) -> None:
self._check_ended()

View File

@ -33,7 +33,7 @@ from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _ConnectionManager
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.message import _CursorAddress, _GetMore, _OpMsg, _OpReply, _RawBatchGetMore
from pymongo.response import PinnedResponse
from pymongo.typings import _Address, _DocumentType
from pymongo.typings import _Address, _DocumentOut, _DocumentType
if TYPE_CHECKING:
from pymongo.client_session import ClientSession
@ -94,6 +94,7 @@ class CommandCursor(Generic[_DocumentType]):
self.__killed = True
if self.__id and not already_killed:
cursor_id = self.__id
assert self.__address is not None
address = _CursorAddress(self.__address, self.__ns)
else:
# Skip killCursors.
@ -219,7 +220,7 @@ class CommandCursor(Generic[_DocumentType]):
codec_options: CodecOptions[Mapping[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[Mapping[str, Any]]:
) -> List[_DocumentOut]:
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
def _refresh(self) -> int:
@ -380,7 +381,7 @@ class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]):
comment,
)
def _unpack_response(
def _unpack_response( # type: ignore[override]
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
@ -393,7 +394,7 @@ class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]):
# OP_MSG returns firstBatch/nextBatch documents as a BSON array
# Re-assemble the array of documents into a document stream
_convert_raw_document_lists_to_streams(raw_response[0])
return raw_response
return raw_response # type: ignore[return-value]
def __getitem__(self, index: int) -> NoReturn:
raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor")

View File

@ -57,7 +57,7 @@ from pymongo.message import (
_RawBatchQuery,
)
from pymongo.response import PinnedResponse
from pymongo.typings import _Address, _CollationIn, _DocumentType
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
if TYPE_CHECKING:
from _typeshed import SupportsItems
@ -420,6 +420,7 @@ class Cursor(Generic[_DocumentType]):
self.__killed = True
if self.__id and not already_killed:
cursor_id = self.__id
assert self.__address is not None
address = _CursorAddress(self.__address, f"{self.__dbname}.{self.__collname}")
else:
# Skip killCursors.
@ -1129,7 +1130,7 @@ class Cursor(Generic[_DocumentType]):
codec_options: CodecOptions,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[Mapping[str, Any]]:
) -> List[_DocumentOut]:
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
def _read_preference(self) -> _ServerMode:
@ -1355,13 +1356,13 @@ class RawBatchCursor(Cursor, Generic[_DocumentType]):
codec_options: CodecOptions[Mapping[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[Mapping[str, Any]]:
) -> List[_DocumentOut]:
raw_response = response.raw_response(cursor_id, user_fields=user_fields)
if not legacy_response:
# OP_MSG returns firstBatch/nextBatch documents as a BSON array
# Re-assemble the array of documents into a document stream
_convert_raw_document_lists_to_streams(raw_response[0])
return raw_response
return cast(List["_DocumentOut"], raw_response)
def explain(self) -> _DocumentType:
"""Returns an explain plan record for this cursor.

View File

@ -80,8 +80,6 @@ from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from pymongocrypt.mongocrypt import MongoCryptKmsContext
from pymongo.response import Response
_HTTPS_PORT = 443
_KMS_CONNECT_TIMEOUT = CONNECT_TIMEOUT # CDRIVER-3262 redefined this value to CONNECT_TIMEOUT
_MONGOCRYPTD_TIMEOUT_MS = 10000
@ -409,7 +407,7 @@ class _Encrypter:
encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
return encrypt_cmd
def decrypt(self, response: Response) -> Optional[bytes]:
def decrypt(self, response: bytes) -> Optional[bytes]:
"""Decrypt a MongoDB command response.
:Parameters:

View File

@ -13,10 +13,25 @@
# limitations under the License.
"""Exceptions raised by PyMongo."""
from typing import Any, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
from bson.errors import InvalidDocument
if TYPE_CHECKING:
from pymongo.typings import _DocumentOut
try:
# CPython 3.7+
from ssl import SSLCertVerificationError as _CertificateError
@ -286,9 +301,9 @@ class BulkWriteError(OperationFailure):
.. versionadded:: 2.7
"""
details: Mapping[str, Any]
details: _DocumentOut
def __init__(self, results: Mapping[str, Any]) -> None:
def __init__(self, results: _DocumentOut) -> None:
super().__init__("batch op errors occurred", 65, results)
def __reduce__(self) -> Tuple[Any, Any]:

View File

@ -53,6 +53,7 @@ from pymongo.hello import HelloCompat
if TYPE_CHECKING:
from pymongo.cursor import _Hint
from pymongo.operations import _IndexList
from pymongo.typings import _DocumentOut
# From the SDAM spec, the "node is shutting down" codes.
_SHUTDOWN_CODES: frozenset = frozenset(
@ -156,7 +157,7 @@ def _index_document(index_list: _IndexList) -> SON[str, Any]:
def _check_command_response(
response: Mapping[str, Any],
response: _DocumentOut,
max_wire_version: Optional[int],
allowable_errors: Optional[Container[Union[int, str]]] = None,
parse_write_concern_error: bool = False,

View File

@ -19,12 +19,27 @@ MongoDB.
.. note:: This module is for internal use and is generally not needed by
application developers.
"""
from __future__ import annotations
import datetime
import random
import struct
from io import BytesIO as _BytesIO
from typing import Any, Mapping, NoReturn
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Tuple,
Union,
cast,
)
import bson
from bson import CodecOptions, _decode_selective, _dict_to_bson, _make_c_string, encode
@ -58,6 +73,18 @@ from pymongo.helpers import _handle_reauth
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from datetime import timedelta
from pymongo.client_session import ClientSession
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
from pymongo.mongo_client import MongoClient
from pymongo.monitoring import _EventListeners
from pymongo.pool import Connection
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.typings import _Address, _DocumentOut
MAX_INT32 = 2147483647
MIN_INT32 = -2147483648
@ -87,12 +114,14 @@ _UNICODE_REPLACE_CODEC_OPTIONS: "CodecOptions[Mapping[str, Any]]" = CodecOptions
)
def _randint():
def _randint() -> int:
"""Generate a pseudo random 32 bit integer."""
return random.randint(MIN_INT32, MAX_INT32)
def _maybe_add_read_preference(spec, read_preference):
def _maybe_add_read_preference(
spec: MutableMapping[str, Any], read_preference: _ServerMode
) -> MutableMapping[str, Any]:
"""Add $readPreference to spec when appropriate."""
mode = read_preference.mode
document = read_preference.document
@ -108,12 +137,14 @@ def _maybe_add_read_preference(spec, read_preference):
return spec
def _convert_exception(exception):
def _convert_exception(exception: Exception) -> Dict[str, Any]:
"""Convert an Exception into a failure document for publishing."""
return {"errmsg": str(exception), "errtype": exception.__class__.__name__}
def _convert_write_result(operation, command, result):
def _convert_write_result(
operation: str, command: Mapping[str, Any], result: Mapping[str, Any]
) -> Dict[str, Any]:
"""Convert a legacy write result to write command format."""
# Based on _merge_legacy from bulk.py
affected = result.get("n", 0)
@ -177,20 +208,20 @@ _MODIFIERS = SON(
def _gen_find_command(
coll,
spec,
projection,
skip,
limit,
batch_size,
options,
read_concern,
collation=None,
session=None,
allow_disk_use=None,
):
coll: str,
spec: Mapping[str, Any],
projection: Optional[Union[Mapping[str, Any], Iterable[str]]],
skip: int,
limit: int,
batch_size: Optional[int],
options: Optional[int],
read_concern: ReadConcern,
collation: Optional[Mapping[str, Any]] = None,
session: Optional[ClientSession] = None,
allow_disk_use: Optional[bool] = None,
) -> SON[str, Any]:
"""Generate a find command document."""
cmd = SON([("find", coll)])
cmd: SON[str, Any] = SON([("find", coll)])
if "$query" in spec:
cmd.update(
[
@ -227,9 +258,16 @@ def _gen_find_command(
return cmd
def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, conn):
def _gen_get_more_command(
cursor_id: Optional[int],
coll: str,
batch_size: Optional[int],
max_await_time_ms: Optional[int],
comment: Optional[Any],
conn: Connection,
) -> SON[str, Any]:
"""Generate a getMore command document."""
cmd = SON([("getMore", cursor_id), ("collection", coll)])
cmd: SON[str, Any] = SON([("getMore", cursor_id), ("collection", coll)])
if batch_size:
cmd["batchSize"] = batch_size
if max_await_time_ms is not None:
@ -269,22 +307,22 @@ class _Query:
def __init__(
self,
flags,
db,
coll,
ntoskip,
spec,
fields,
codec_options,
read_preference,
limit,
batch_size,
read_concern,
collation,
session,
client,
allow_disk_use,
exhaust,
flags: int,
db: str,
coll: str,
ntoskip: int,
spec: Mapping[str, Any],
fields: Optional[Mapping[str, Any]],
codec_options: CodecOptions,
read_preference: _ServerMode,
limit: int,
batch_size: int,
read_concern: ReadConcern,
collation: Optional[Mapping[str, Any]],
session: Optional[ClientSession],
client: MongoClient,
allow_disk_use: Optional[bool],
exhaust: bool,
):
self.flags = flags
self.db = db
@ -302,16 +340,16 @@ class _Query:
self.client = client
self.allow_disk_use = allow_disk_use
self.name = "find"
self._as_command = None
self._as_command: Optional[Tuple[SON[str, Any], str]] = None
self.exhaust = exhaust
def reset(self):
def reset(self) -> None:
self._as_command = None
def namespace(self):
def namespace(self) -> str:
return f"{self.db}.{self.coll}"
def use_command(self, conn):
def use_command(self, conn: Connection) -> bool:
use_find_cmd = False
if not self.exhaust:
use_find_cmd = True
@ -327,7 +365,9 @@ class _Query:
conn.validate_session(self.client, self.session)
return use_find_cmd
def as_command(self, conn, apply_timeout=False):
def as_command(
self, conn: Connection, apply_timeout: bool = False
) -> Tuple[SON[str, Any], str]:
"""Return a find command document for this query."""
# We use the command twice: on the wire and for command monitoring.
# Generate it once, for speed and to avoid repeating side-effects.
@ -335,7 +375,7 @@ class _Query:
return self._as_command
explain = "$explain" in self.spec
cmd = _gen_find_command(
cmd: SON[str, Any] = _gen_find_command(
self.coll,
self.spec,
self.fields,
@ -362,14 +402,16 @@ class _Query:
# Support auto encryption
client = self.client
if client._encrypter and not client._encrypter._bypass_auto_encryption:
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
cmd = cast(SON[str, Any], client._encrypter.encrypt(self.db, cmd, self.codec_options))
# Support CSOT
if apply_timeout:
conn.apply_timeout(client, cmd)
self._as_command = cmd, self.db
return self._as_command
def get_message(self, read_preference, conn, use_cmd=False):
def get_message(
self, read_preference: _ServerMode, conn: Connection, use_cmd: bool = False
) -> Tuple[int, bytes, int]:
"""Get a query message, possibly setting the secondaryOk bit."""
# Use the read_preference decided by _socket_from_server.
self.read_preference = read_preference
@ -405,6 +447,7 @@ class _Query:
ntoreturn = self.limit
if conn.is_mongos:
assert isinstance(spec, MutableMapping)
spec = _maybe_add_read_preference(spec, read_preference)
return _query(
@ -442,18 +485,18 @@ class _GetMore:
def __init__(
self,
db,
coll,
ntoreturn,
cursor_id,
codec_options,
read_preference,
session,
client,
max_await_time_ms,
conn_mgr,
exhaust,
comment,
db: str,
coll: str,
ntoreturn: int,
cursor_id: int,
codec_options: CodecOptions,
read_preference: _ServerMode,
session: Optional[ClientSession],
client: MongoClient,
max_await_time_ms: Optional[int],
conn_mgr: Any,
exhaust: bool,
comment: Any,
):
self.db = db
self.coll = coll
@ -465,17 +508,17 @@ class _GetMore:
self.client = client
self.max_await_time_ms = max_await_time_ms
self.conn_mgr = conn_mgr
self._as_command = None
self._as_command: Optional[Tuple[SON[str, Any], str]] = None
self.exhaust = exhaust
self.comment = comment
def reset(self):
def reset(self) -> None:
self._as_command = None
def namespace(self):
def namespace(self) -> str:
return f"{self.db}.{self.coll}"
def use_command(self, conn):
def use_command(self, conn: Connection) -> bool:
use_cmd = False
if not self.exhaust:
use_cmd = True
@ -486,13 +529,15 @@ class _GetMore:
conn.validate_session(self.client, self.session)
return use_cmd
def as_command(self, conn, apply_timeout=False):
def as_command(
self, conn: Connection, apply_timeout: bool = False
) -> Tuple[SON[str, Any], str]:
"""Return a getMore command document for this query."""
# See _Query.as_command for an explanation of this caching.
if self._as_command is not None:
return self._as_command
cmd = _gen_get_more_command(
cmd: SON[str, Any] = _gen_get_more_command(
self.cursor_id,
self.coll,
self.ntoreturn,
@ -507,14 +552,16 @@ class _GetMore:
# Support auto encryption
client = self.client
if client._encrypter and not client._encrypter._bypass_auto_encryption:
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
cmd = cast(SON[str, Any], client._encrypter.encrypt(self.db, cmd, self.codec_options))
# Support CSOT
if apply_timeout:
conn.apply_timeout(client, cmd=None)
self._as_command = cmd, self.db
return self._as_command
def get_message(self, dummy0, conn, use_cmd=False):
def get_message(
self, dummy0: Any, conn: Connection, use_cmd: bool = False
) -> Union[Tuple[int, bytes, int], Tuple[int, bytes]]:
"""Get a getmore message."""
ns = self.namespace()
ctx = conn.compression_context
@ -534,7 +581,7 @@ class _GetMore:
class _RawBatchQuery(_Query):
def use_command(self, conn):
def use_command(self, conn: Connection) -> bool:
# Compatibility checks.
super().use_command(conn)
if conn.max_wire_version >= 8:
@ -546,7 +593,7 @@ class _RawBatchQuery(_Query):
class _RawBatchGetMore(_GetMore):
def use_command(self, conn):
def use_command(self, conn: Connection) -> bool:
# Compatibility checks.
super().use_command(conn)
if conn.max_wire_version >= 8:
@ -562,27 +609,27 @@ class _CursorAddress(tuple):
__namespace: Any
def __new__(cls, address, namespace):
def __new__(cls, address: _Address, namespace: str) -> _CursorAddress:
self = tuple.__new__(cls, address)
self.__namespace = namespace
return self
@property
def namespace(self):
def namespace(self) -> str:
"""The namespace this cursor."""
return self.__namespace
def __hash__(self):
def __hash__(self) -> int:
# Two _CursorAddress instances with different namespaces
# must not hash the same.
return ((*self, self.__namespace)).__hash__()
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if isinstance(other, _CursorAddress):
return tuple(self) == tuple(other) and self.namespace == other.namespace
return NotImplemented
def __ne__(self, other):
def __ne__(self, other: object) -> bool:
return not self == other
@ -590,7 +637,9 @@ _pack_compression_header = struct.Struct("<iiiiiiB").pack
_COMPRESSION_HEADER_SIZE = 25
def _compress(operation, data, ctx):
def _compress(
operation: int, data: bytes, ctx: Union[SnappyContext, ZlibContext, ZstdContext]
) -> Tuple[int, bytes]:
"""Takes message data, compresses it, and adds an OP_COMPRESSED header."""
compressed = ctx.compress(data)
request_id = _randint()
@ -610,7 +659,7 @@ def _compress(operation, data, ctx):
_pack_header = struct.Struct("<iiii").pack
def __pack_message(operation, data):
def __pack_message(operation: int, data: bytes) -> Tuple[int, bytes]:
"""Takes message data and adds a message header based on the operation.
Returns the resultant message string.
@ -625,7 +674,13 @@ _pack_op_msg_flags_type = struct.Struct("<IB").pack
_pack_byte = struct.Struct("<B").pack
def _op_msg_no_header(flags, command, identifier, docs, opts):
def _op_msg_no_header(
flags: int,
command: Mapping[str, Any],
identifier: str,
docs: Optional[List[Mapping[str, Any]]],
opts: CodecOptions,
) -> Tuple[bytes, int, int]:
"""Get a OP_MSG message.
Note: this method handles multiple documents in a type one payload but
@ -637,7 +692,7 @@ def _op_msg_no_header(flags, command, identifier, docs, opts):
flags_type = _pack_op_msg_flags_type(flags, 0)
total_size = len(encoded)
max_doc_size = 0
if identifier:
if identifier and docs is not None:
type_one = _pack_byte(1)
cstring = _make_c_string(identifier)
encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs]
@ -651,14 +706,27 @@ def _op_msg_no_header(flags, command, identifier, docs, opts):
return b"".join(data), total_size, max_doc_size
def _op_msg_compressed(flags, command, identifier, docs, opts, ctx):
def _op_msg_compressed(
flags: int,
command: Mapping[str, Any],
identifier: str,
docs: Optional[List[Mapping[str, Any]]],
opts: CodecOptions,
ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> Tuple[int, bytes, int, int]:
"""Internal OP_MSG message helper."""
msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
rid, msg = _compress(2013, msg, ctx)
return rid, msg, total_size, max_bson_size
def _op_msg_uncompressed(flags, command, identifier, docs, opts):
def _op_msg_uncompressed(
flags: int,
command: Mapping[str, Any],
identifier: str,
docs: Optional[List[Mapping[str, Any]]],
opts: CodecOptions,
) -> Tuple[int, bytes, int, int]:
"""Internal compressed OP_MSG message helper."""
data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
request_id, op_message = __pack_message(2013, data)
@ -669,7 +737,14 @@ if _use_c:
_op_msg_uncompressed = _cmessage._op_msg # noqa: F811
def _op_msg(flags, command, dbname, read_preference, opts, ctx=None):
def _op_msg(
flags: int,
command: MutableMapping[str, Any],
dbname: str,
read_preference: Optional[_ServerMode],
opts: CodecOptions,
ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> Tuple[int, bytes, int, int]:
"""Get a OP_MSG message."""
command["$db"] = dbname
# getMore commands do not send $readPreference.
@ -679,7 +754,7 @@ def _op_msg(flags, command, dbname, read_preference, opts, ctx=None):
command["$readPreference"] = read_preference.document
name = next(iter(command))
try:
identifier = _FIELD_MAP.get(name)
identifier = _FIELD_MAP[name]
docs = command.pop(identifier)
except KeyError:
identifier = ""
@ -694,7 +769,15 @@ def _op_msg(flags, command, dbname, read_preference, opts, ctx=None):
command[identifier] = docs
def _query_impl(options, collection_name, num_to_skip, num_to_return, query, field_selector, opts):
def _query_impl(
options: int,
collection_name: str,
num_to_skip: int,
num_to_return: int,
query: Mapping[str, Any],
field_selector: Optional[Mapping[str, Any]],
opts: CodecOptions,
) -> Tuple[bytes, int]:
"""Get an OP_QUERY message."""
encoded = _dict_to_bson(query, False, opts)
if field_selector:
@ -718,8 +801,15 @@ def _query_impl(options, collection_name, num_to_skip, num_to_return, query, fie
def _query_compressed(
options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx=None
):
options: int,
collection_name: str,
num_to_skip: int,
num_to_return: int,
query: Mapping[str, Any],
field_selector: Optional[Mapping[str, Any]],
opts: CodecOptions,
ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> Tuple[int, bytes, int]:
"""Internal compressed query message helper."""
op_query, max_bson_size = _query_impl(
options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
@ -729,8 +819,14 @@ def _query_compressed(
def _query_uncompressed(
options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
):
options: int,
collection_name: str,
num_to_skip: int,
num_to_return: int,
query: Mapping[str, Any],
field_selector: Optional[Mapping[str, Any]],
opts: CodecOptions,
) -> Tuple[int, bytes, int]:
"""Internal query message helper."""
op_query, max_bson_size = _query_impl(
options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
@ -744,8 +840,15 @@ if _use_c:
def _query(
options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx=None
):
options: int,
collection_name: str,
num_to_skip: int,
num_to_return: int,
query: Mapping[str, Any],
field_selector: Optional[Mapping[str, Any]],
opts: CodecOptions,
ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> Tuple[int, bytes, int]:
"""Get a **query** message."""
if ctx:
return _query_compressed(
@ -759,7 +862,7 @@ def _query(
_pack_long_long = struct.Struct("<q").pack
def _get_more_impl(collection_name, num_to_return, cursor_id):
def _get_more_impl(collection_name: str, num_to_return: int, cursor_id: int) -> bytes:
"""Get an OP_GET_MORE message."""
return b"".join(
[
@ -771,12 +874,19 @@ def _get_more_impl(collection_name, num_to_return, cursor_id):
)
def _get_more_compressed(collection_name, num_to_return, cursor_id, ctx):
def _get_more_compressed(
collection_name: str,
num_to_return: int,
cursor_id: int,
ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> Tuple[int, bytes]:
"""Internal compressed getMore message helper."""
return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx)
def _get_more_uncompressed(collection_name, num_to_return, cursor_id):
def _get_more_uncompressed(
collection_name: str, num_to_return: int, cursor_id: int
) -> Tuple[int, bytes]:
"""Internal getMore message helper."""
return __pack_message(2005, _get_more_impl(collection_name, num_to_return, cursor_id))
@ -785,7 +895,12 @@ if _use_c:
_get_more_uncompressed = _cmessage._get_more_message # noqa: F811
def _get_more(collection_name, num_to_return, cursor_id, ctx=None):
def _get_more(
collection_name: str,
num_to_return: int,
cursor_id: int,
ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> Tuple[int, bytes]:
"""Get a **getMore** message."""
if ctx:
return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx)
@ -811,7 +926,15 @@ class _BulkWriteContext:
)
def __init__(
self, database_name, cmd_name, conn, operation_id, listeners, session, op_type, codec
self,
database_name: str,
cmd_name: str,
conn: Connection,
operation_id: int,
listeners: _EventListeners,
session: ClientSession,
op_type: int,
codec: CodecOptions,
):
self.db_name = database_name
self.conn = conn
@ -826,7 +949,9 @@ class _BulkWriteContext:
self.op_type = op_type
self.codec = codec
def _batch_command(self, cmd, docs):
def __batch_command(
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]]
) -> Tuple[int, bytes, List[Mapping[str, Any]]]:
namespace = self.db_name + ".$cmd"
request_id, msg, to_send = _do_batched_op_msg(
namespace, self.op_type, cmd, docs, self.codec, self
@ -835,14 +960,18 @@ class _BulkWriteContext:
raise InvalidOperation("cannot do an empty bulk write")
return request_id, msg, to_send
def execute(self, cmd, docs, client):
request_id, msg, to_send = self._batch_command(cmd, docs)
def execute(
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient
) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]:
request_id, msg, to_send = self.__batch_command(cmd, docs)
result = self.write_command(cmd, request_id, msg, to_send)
client._process_response(result, self.session)
return result, to_send
def execute_unack(self, cmd, docs, client):
request_id, msg, to_send = self._batch_command(cmd, docs)
def execute_unack(
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient
) -> List[Mapping[str, Any]]:
request_id, msg, to_send = self.__batch_command(cmd, docs)
# Though this isn't strictly a "legacy" write, the helper
# handles publishing commands and sending our message
# without receiving a result. Send 0 for max_doc_size
@ -852,12 +981,12 @@ class _BulkWriteContext:
return to_send
@property
def max_bson_size(self):
def max_bson_size(self) -> int:
"""A proxy for SockInfo.max_bson_size."""
return self.conn.max_bson_size
@property
def max_message_size(self):
def max_message_size(self) -> int:
"""A proxy for SockInfo.max_message_size."""
if self.compress:
# Subtract 16 bytes for the message header.
@ -865,16 +994,23 @@ class _BulkWriteContext:
return self.conn.max_message_size
@property
def max_write_batch_size(self):
def max_write_batch_size(self) -> int:
"""A proxy for SockInfo.max_write_batch_size."""
return self.conn.max_write_batch_size
@property
def max_split_size(self):
def max_split_size(self) -> int:
"""The maximum size of a BSON command before batch splitting."""
return self.max_bson_size
def unack_write(self, cmd, request_id, msg, max_doc_size, docs):
def unack_write(
self,
cmd: MutableMapping[str, Any],
request_id: int,
msg: bytes,
max_doc_size: int,
docs: List[Mapping[str, Any]],
) -> Optional[Mapping[str, Any]]:
"""A proxy for Connection.unack_write that handles event publishing."""
if self.publish:
assert self.start_time is not None
@ -896,9 +1032,9 @@ class _BulkWriteContext:
assert self.start_time is not None
duration = (datetime.datetime.now() - start) + duration
if isinstance(exc, OperationFailure):
failure = _convert_write_result(self.name, cmd, exc.details)
failure: _DocumentOut = _convert_write_result(self.name, cmd, exc.details) # type: ignore[arg-type]
elif isinstance(exc, NotPrimaryError):
failure = exc.details
failure = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
self._fail(request_id, failure, duration)
@ -908,8 +1044,14 @@ class _BulkWriteContext:
return result
@_handle_reauth
def write_command(self, cmd, request_id, msg, docs):
"""A proxy for Connection.write_command that handles event publishing."""
def write_command(
self,
cmd: MutableMapping[str, Any],
request_id: int,
msg: bytes,
docs: List[Mapping[str, Any]],
) -> Mapping[str, Any]:
"""A proxy for SocketInfo.write_command that handles event publishing."""
if self.publish:
assert self.start_time is not None
duration = datetime.datetime.now() - self.start_time
@ -924,7 +1066,7 @@ class _BulkWriteContext:
if self.publish:
duration = (datetime.datetime.now() - start) + duration
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure = exc.details
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
self._fail(request_id, failure, duration)
@ -933,7 +1075,9 @@ class _BulkWriteContext:
self.start_time = datetime.datetime.now()
return reply
def _start(self, cmd, request_id, docs):
def _start(
self, cmd: MutableMapping[str, Any], request_id: int, docs: List[Mapping[str, Any]]
) -> MutableMapping[str, Any]:
"""Publish a CommandStartedEvent."""
cmd[self.field] = docs
self.listeners.publish_command_start(
@ -946,7 +1090,7 @@ class _BulkWriteContext:
)
return cmd
def _succeed(self, request_id, reply, duration):
def _succeed(self, request_id: int, reply: _DocumentOut, duration: timedelta) -> None:
"""Publish a CommandSucceededEvent."""
self.listeners.publish_command_success(
duration,
@ -958,7 +1102,7 @@ class _BulkWriteContext:
self.conn.service_id,
)
def _fail(self, request_id, failure, duration):
def _fail(self, request_id: int, failure: _DocumentOut, duration: timedelta) -> None:
"""Publish a CommandFailedEvent."""
self.listeners.publish_command_failure(
duration,
@ -981,7 +1125,9 @@ _MAX_SPLIT_SIZE_ENC = 2097152
class _EncryptedBulkWriteContext(_BulkWriteContext):
__slots__ = ()
def _batch_command(self, cmd, docs):
def __batch_command(
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]]
) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]:
namespace = self.db_name + ".$cmd"
msg, to_send = _encode_batched_write_command(
namespace, self.op_type, cmd, docs, self.codec, self
@ -994,15 +1140,19 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
cmd = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS)
return cmd, to_send
def execute(self, cmd, docs, client):
batched_cmd, to_send = self._batch_command(cmd, docs)
result = self.conn.command(
def execute(
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient
) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]:
batched_cmd, to_send = self.__batch_command(cmd, docs)
result: Mapping[str, Any] = self.conn.command(
self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client
)
return result, to_send
def execute_unack(self, cmd, docs, client):
batched_cmd, to_send = self._batch_command(cmd, docs)
def execute_unack(
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient
) -> List[Mapping[str, Any]]:
batched_cmd, to_send = self.__batch_command(cmd, docs)
self.conn.command(
self.db_name,
batched_cmd,
@ -1013,7 +1163,7 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
return to_send
@property
def max_split_size(self):
def max_split_size(self) -> int:
"""Reduce the batch splitting size."""
return _MAX_SPLIT_SIZE_ENC
@ -1043,7 +1193,15 @@ _OP_MSG_MAP = {
}
def _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf):
def _batched_op_msg_impl(
operation: int,
command: Mapping[str, Any],
docs: List[Mapping[str, Any]],
ack: bool,
opts: CodecOptions,
ctx: _BulkWriteContext,
buf: _BytesIO,
) -> Tuple[List[Mapping[str, Any]], int]:
"""Create a batched OP_MSG write."""
max_bson_size = ctx.max_bson_size
max_write_batch_size = ctx.max_write_batch_size
@ -1103,7 +1261,14 @@ def _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf):
return to_send, length
def _encode_batched_op_msg(operation, command, docs, ack, opts, ctx):
def _encode_batched_op_msg(
operation: int,
command: Mapping[str, Any],
docs: List[Mapping[str, Any]],
ack: bool,
opts: CodecOptions,
ctx: _BulkWriteContext,
) -> Tuple[bytes, List[Mapping[str, Any]]]:
"""Encode the next batched insert, update, or delete operation
as OP_MSG.
"""
@ -1117,17 +1282,32 @@ if _use_c:
_encode_batched_op_msg = _cmessage._encode_batched_op_msg # noqa: F811
def _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx):
def _batched_op_msg_compressed(
operation: int,
command: Mapping[str, Any],
docs: List[Mapping[str, Any]],
ack: bool,
opts: CodecOptions,
ctx: _BulkWriteContext,
) -> Tuple[int, bytes, List[Mapping[str, Any]]]:
"""Create the next batched insert, update, or delete operation
with OP_MSG, compressed.
"""
data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx)
assert ctx.conn.compression_context is not None
request_id, msg = _compress(2013, data, ctx.conn.compression_context)
return request_id, msg, to_send
def _batched_op_msg(operation, command, docs, ack, opts, ctx):
def _batched_op_msg(
operation: int,
command: Mapping[str, Any],
docs: List[Mapping[str, Any]],
ack: bool,
opts: CodecOptions,
ctx: _BulkWriteContext,
) -> Tuple[int, bytes, List[Mapping[str, Any]]]:
"""OP_MSG implementation entry point."""
buf = _BytesIO()
@ -1152,7 +1332,14 @@ if _use_c:
_batched_op_msg = _cmessage._batched_op_msg # noqa: F811
def _do_batched_op_msg(namespace, operation, command, docs, opts, ctx):
def _do_batched_op_msg(
namespace: str,
operation: int,
command: MutableMapping[str, Any],
docs: List[Mapping[str, Any]],
opts: CodecOptions,
ctx: _BulkWriteContext,
) -> Tuple[int, bytes, List[Mapping[str, Any]]]:
"""Create the next batched insert, update, or delete operation
using OP_MSG.
"""
@ -1169,7 +1356,14 @@ def _do_batched_op_msg(namespace, operation, command, docs, opts, ctx):
# End OP_MSG -----------------------------------------------------
def _encode_batched_write_command(namespace, operation, command, docs, opts, ctx):
def _encode_batched_write_command(
namespace: str,
operation: int,
command: MutableMapping[str, Any],
docs: List[Mapping[str, Any]],
opts: CodecOptions,
ctx: _BulkWriteContext,
) -> Tuple[bytes, List[Mapping[str, Any]]]:
"""Encode the next batched insert, update, or delete command."""
buf = _BytesIO()
@ -1181,7 +1375,15 @@ if _use_c:
_encode_batched_write_command = _cmessage._encode_batched_write_command # noqa: F811
def _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, buf):
def _batched_write_command_impl(
namespace: str,
operation: int,
command: MutableMapping[str, Any],
docs: List[Mapping[str, Any]],
opts: CodecOptions,
ctx: _BulkWriteContext,
buf: _BytesIO,
) -> Tuple[List[Mapping[str, Any]], int]:
"""Create a batched OP_QUERY write command."""
max_bson_size = ctx.max_bson_size
max_write_batch_size = ctx.max_write_batch_size
@ -1258,13 +1460,15 @@ class _OpReply:
UNPACK_FROM = struct.Struct("<iqii").unpack_from
OP_CODE = 1
def __init__(self, flags, cursor_id, number_returned, documents):
def __init__(self, flags: int, cursor_id: int, number_returned: int, documents: bytes):
self.flags = flags
self.cursor_id = Int64(cursor_id)
self.number_returned = number_returned
self.documents = documents
def raw_response(self, cursor_id=None, user_fields=None):
def raw_response(
self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = None
) -> List[bytes]:
"""Check the response header from the database, without decoding BSON.
Check the response for errors and unpack.
@ -1309,11 +1513,11 @@ class _OpReply:
def unpack_response(
self,
cursor_id=None,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
user_fields=None,
legacy_response=False,
):
cursor_id: Optional[int] = None,
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[_DocumentOut]:
"""Unpack a response from the database and decode the BSON document(s).
Check the response for errors and unpack, returning a dictionary
@ -1337,24 +1541,24 @@ class _OpReply:
return bson.decode_all(self.documents, codec_options)
return bson._decode_all_selective(self.documents, codec_options, user_fields)
def command_response(self, codec_options):
def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]:
"""Unpack a command response."""
docs = self.unpack_response(codec_options=codec_options)
assert self.number_returned == 1
return docs[0]
def raw_command_response(self):
def raw_command_response(self) -> NoReturn:
"""Return the bytes of the command response."""
# This should never be called on _OpReply.
raise NotImplementedError
@property
def more_to_come(self):
def more_to_come(self) -> bool:
"""Is the moreToCome bit set on this response?"""
return False
@classmethod
def unpack(cls, msg):
def unpack(cls, msg: bytes) -> _OpReply:
"""Construct an _OpReply from raw bytes."""
# PYTHON-945: ignore starting_from field.
flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg)
@ -1376,11 +1580,15 @@ class _OpMsg:
MORE_TO_COME = 1 << 1
EXHAUST_ALLOWED = 1 << 16 # Only present on requests.
def __init__(self, flags, payload_document):
def __init__(self, flags: int, payload_document: bytes):
self.flags = flags
self.payload_document = payload_document
def raw_response(self, cursor_id=None, user_fields={}): # noqa: B006
def raw_response(
self,
cursor_id: Optional[int] = None,
user_fields: Optional[Mapping[str, Any]] = {}, # noqa: B006
) -> List[Mapping[str, Any]]:
"""
cursor_id is ignored
user_fields is used to determine which fields must not be decoded
@ -1392,11 +1600,11 @@ class _OpMsg:
def unpack_response(
self,
cursor_id=None,
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
user_fields=None,
legacy_response=False,
):
cursor_id: Optional[int] = None,
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[_DocumentOut]:
"""Unpack a OP_MSG command response.
:Parameters:
@ -1411,21 +1619,21 @@ class _OpMsg:
assert not legacy_response
return bson._decode_all_selective(self.payload_document, codec_options, user_fields)
def command_response(self, codec_options):
def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]:
"""Unpack a command response."""
return self.unpack_response(codec_options=codec_options)[0]
def raw_command_response(self):
def raw_command_response(self) -> bytes:
"""Return the bytes of the command response."""
return self.payload_document
@property
def more_to_come(self):
def more_to_come(self) -> bool:
"""Is the moreToCome bit set on this response?"""
return self.flags & self.MORE_TO_COME
return bool(self.flags & self.MORE_TO_COME)
@classmethod
def unpack(cls, msg):
def unpack(cls, msg: bytes) -> _OpMsg:
"""Construct an _OpMsg from raw bytes."""
flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg)
if flags != 0:
@ -1444,7 +1652,7 @@ class _OpMsg:
return cls(flags, payload_document)
_UNPACK_REPLY = {
_UNPACK_REPLY: Dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = {
_OpReply.OP_CODE: _OpReply.unpack,
_OpMsg.OP_CODE: _OpMsg.unpack,
}

View File

@ -23,7 +23,6 @@ import time
from typing import (
TYPE_CHECKING,
Any,
Dict,
Mapping,
MutableMapping,
Optional,
@ -55,7 +54,7 @@ if TYPE_CHECKING:
from pymongo.pool import Connection
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.typings import _Address, _DocumentOut
from pymongo.typings import _Address, _DocumentOut, _DocumentType
from pymongo.write_concern import WriteConcern
_UNPACK_HEADER = struct.Struct("<iiii").unpack
@ -67,7 +66,7 @@ def command(
spec: MutableMapping[str, Any],
is_mongos: bool,
read_preference: _ServerMode,
codec_options: CodecOptions,
codec_options: CodecOptions[_DocumentType],
session: Optional[ClientSession],
client: Optional[MongoClient],
check: bool = True,
@ -84,7 +83,7 @@ def command(
user_fields: Optional[Mapping[str, Any]] = None,
exhaust_allowed: bool = False,
write_concern: Optional[WriteConcern] = None,
) -> Dict[str, Any]:
) -> _DocumentType:
"""Execute a command over the socket, or raise socket.error.
:Parameters:
@ -177,7 +176,7 @@ def command(
if use_op_msg and unacknowledged:
# Unacknowledged, fake a successful command response.
reply = None
response_doc = {"ok": 1}
response_doc: _DocumentOut = {"ok": 1}
else:
reply = receive_message(conn, request_id)
conn.more_to_come = reply.more_to_come
@ -226,7 +225,7 @@ def command(
decrypted = client._encrypter.decrypt(reply.raw_command_response())
response_doc = _decode_all_selective(decrypted, codec_options, user_fields)[0]
return response_doc
return response_doc # type: ignore[return-value]
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack

View File

@ -15,14 +15,14 @@
"""Represent a response from the server."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Union
if TYPE_CHECKING:
from datetime import timedelta
from pymongo.message import _OpMsg, _OpReply
from pymongo.pool import Connection
from pymongo.typings import _Address
from pymongo.typings import _Address, _DocumentOut
class Response:
@ -35,7 +35,7 @@ class Response:
request_id: int,
duration: Optional[timedelta],
from_command: bool,
docs: List[Mapping[str, Any]],
docs: Sequence[Mapping[str, Any]],
):
"""Represent a response from the server.
@ -79,7 +79,7 @@ class Response:
return self._from_command
@property
def docs(self) -> List[Mapping[str, Any]]:
def docs(self) -> Sequence[Mapping[str, Any]]:
"""The decoded document(s)."""
return self._docs
@ -95,7 +95,7 @@ class PinnedResponse(Response):
request_id: int,
duration: Optional[timedelta],
from_command: bool,
docs: List[Mapping[str, Any]],
docs: List[_DocumentOut],
more_to_come: bool,
):
"""Represent a response to an exhaust cursor's initial query.

View File

@ -22,7 +22,6 @@ from typing import (
Callable,
ContextManager,
List,
Mapping,
Optional,
Tuple,
Union,
@ -111,7 +110,7 @@ class Server:
operation: Union[_Query, _GetMore],
read_preference: _ServerMode,
listeners: _EventListeners,
unpack_res: Callable[..., List[Mapping[str, Any]]],
unpack_res: Callable[..., List[_DocumentOut]],
) -> Response:
"""Run a _Query or _GetMore operation and return a Response object.
@ -193,9 +192,9 @@ class Server:
# Must publish in find / getMore / explain command response
# format.
if use_cmd:
res: _DocumentOut = docs[0] # type: ignore[assignment]
res: _DocumentOut = docs[0]
elif operation.name == "explain":
res = docs[0] if docs else {} # type: ignore[assignment]
res = docs[0] if docs else {}
else:
res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1}
if operation.name == "find":

View File

@ -111,7 +111,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
# Insert record and verify failure.
with self.assertRaises(NotPrimaryError) as exc:
self.coll.insert_one({"test": 1})
self.assertEqual(exc.exception.details["code"], error_code) # type: ignore
self.assertEqual(exc.exception.details["code"], error_code) # type: ignore[call-overload]
# Retry before CMAPListener assertion if retry_before=True.
if retry:
self.coll.insert_one({"test": 1})

View File

@ -754,6 +754,7 @@ Bye"""
# Kill the cursor to simulate the cursor timing out on the server
# when an application spends a long time between two calls to
# readchunk().
assert client.address is not None
client._close_cursor_now(
outfile._GridOut__chunk_iter._cursor.cursor_id,
_CursorAddress(client.address, db.fs.chunks.full_name),

View File

@ -720,7 +720,8 @@ def server_started_with_auth(client):
try:
command_line = get_command_line(client)
except OperationFailure as e:
msg = e.details.get("errmsg", "") # type: ignore
assert e.details is not None
msg = e.details.get("errmsg", "")
if e.code == 13 or "unauthorized" in msg or "login" in msg:
# Unauthorized.
return True