PYTHON-3806 add types to message.py (#1312)
This commit is contained in:
parent
43046e04c0
commit
02a365276c
@ -153,7 +153,7 @@ def _merge_command(
|
||||
full_result["writeConcernErrors"].append(wce)
|
||||
|
||||
|
||||
def _raise_bulk_write_error(full_result: Mapping[str, Any]) -> NoReturn:
|
||||
def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn:
|
||||
"""Raise a BulkWriteError from the full bulk api result."""
|
||||
if full_result["writeErrors"]:
|
||||
full_result["writeErrors"].sort(key=lambda error: error["index"])
|
||||
|
||||
@ -986,7 +986,7 @@ class ClientSession:
|
||||
self,
|
||||
command: MutableMapping[str, Any],
|
||||
is_retryable: bool,
|
||||
read_preference: ReadPreference,
|
||||
read_preference: _ServerMode,
|
||||
conn: Connection,
|
||||
) -> None:
|
||||
self._check_ended()
|
||||
|
||||
@ -33,7 +33,7 @@ from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _ConnectionManager
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.message import _CursorAddress, _GetMore, _OpMsg, _OpReply, _RawBatchGetMore
|
||||
from pymongo.response import PinnedResponse
|
||||
from pymongo.typings import _Address, _DocumentType
|
||||
from pymongo.typings import _Address, _DocumentOut, _DocumentType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.client_session import ClientSession
|
||||
@ -94,6 +94,7 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
self.__killed = True
|
||||
if self.__id and not already_killed:
|
||||
cursor_id = self.__id
|
||||
assert self.__address is not None
|
||||
address = _CursorAddress(self.__address, self.__ns)
|
||||
else:
|
||||
# Skip killCursors.
|
||||
@ -219,7 +220,7 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
codec_options: CodecOptions[Mapping[str, Any]],
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> List[Mapping[str, Any]]:
|
||||
) -> List[_DocumentOut]:
|
||||
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
|
||||
|
||||
def _refresh(self) -> int:
|
||||
@ -380,7 +381,7 @@ class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]):
|
||||
comment,
|
||||
)
|
||||
|
||||
def _unpack_response(
|
||||
def _unpack_response( # type: ignore[override]
|
||||
self,
|
||||
response: Union[_OpReply, _OpMsg],
|
||||
cursor_id: Optional[int],
|
||||
@ -393,7 +394,7 @@ class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]):
|
||||
# OP_MSG returns firstBatch/nextBatch documents as a BSON array
|
||||
# Re-assemble the array of documents into a document stream
|
||||
_convert_raw_document_lists_to_streams(raw_response[0])
|
||||
return raw_response
|
||||
return raw_response # type: ignore[return-value]
|
||||
|
||||
def __getitem__(self, index: int) -> NoReturn:
|
||||
raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor")
|
||||
|
||||
@ -57,7 +57,7 @@ from pymongo.message import (
|
||||
_RawBatchQuery,
|
||||
)
|
||||
from pymongo.response import PinnedResponse
|
||||
from pymongo.typings import _Address, _CollationIn, _DocumentType
|
||||
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import SupportsItems
|
||||
@ -420,6 +420,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
self.__killed = True
|
||||
if self.__id and not already_killed:
|
||||
cursor_id = self.__id
|
||||
assert self.__address is not None
|
||||
address = _CursorAddress(self.__address, f"{self.__dbname}.{self.__collname}")
|
||||
else:
|
||||
# Skip killCursors.
|
||||
@ -1129,7 +1130,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
codec_options: CodecOptions,
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> List[Mapping[str, Any]]:
|
||||
) -> List[_DocumentOut]:
|
||||
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
|
||||
|
||||
def _read_preference(self) -> _ServerMode:
|
||||
@ -1355,13 +1356,13 @@ class RawBatchCursor(Cursor, Generic[_DocumentType]):
|
||||
codec_options: CodecOptions[Mapping[str, Any]],
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> List[Mapping[str, Any]]:
|
||||
) -> List[_DocumentOut]:
|
||||
raw_response = response.raw_response(cursor_id, user_fields=user_fields)
|
||||
if not legacy_response:
|
||||
# OP_MSG returns firstBatch/nextBatch documents as a BSON array
|
||||
# Re-assemble the array of documents into a document stream
|
||||
_convert_raw_document_lists_to_streams(raw_response[0])
|
||||
return raw_response
|
||||
return cast(List["_DocumentOut"], raw_response)
|
||||
|
||||
def explain(self) -> _DocumentType:
|
||||
"""Returns an explain plan record for this cursor.
|
||||
|
||||
@ -80,8 +80,6 @@ from pymongo.write_concern import WriteConcern
|
||||
if TYPE_CHECKING:
|
||||
from pymongocrypt.mongocrypt import MongoCryptKmsContext
|
||||
|
||||
from pymongo.response import Response
|
||||
|
||||
_HTTPS_PORT = 443
|
||||
_KMS_CONNECT_TIMEOUT = CONNECT_TIMEOUT # CDRIVER-3262 redefined this value to CONNECT_TIMEOUT
|
||||
_MONGOCRYPTD_TIMEOUT_MS = 10000
|
||||
@ -409,7 +407,7 @@ class _Encrypter:
|
||||
encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
|
||||
return encrypt_cmd
|
||||
|
||||
def decrypt(self, response: Response) -> Optional[bytes]:
|
||||
def decrypt(self, response: bytes) -> Optional[bytes]:
|
||||
"""Decrypt a MongoDB command response.
|
||||
|
||||
:Parameters:
|
||||
|
||||
@ -13,10 +13,25 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Exceptions raised by PyMongo."""
|
||||
from typing import Any, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson.errors import InvalidDocument
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.typings import _DocumentOut
|
||||
|
||||
try:
|
||||
# CPython 3.7+
|
||||
from ssl import SSLCertVerificationError as _CertificateError
|
||||
@ -286,9 +301,9 @@ class BulkWriteError(OperationFailure):
|
||||
.. versionadded:: 2.7
|
||||
"""
|
||||
|
||||
details: Mapping[str, Any]
|
||||
details: _DocumentOut
|
||||
|
||||
def __init__(self, results: Mapping[str, Any]) -> None:
|
||||
def __init__(self, results: _DocumentOut) -> None:
|
||||
super().__init__("batch op errors occurred", 65, results)
|
||||
|
||||
def __reduce__(self) -> Tuple[Any, Any]:
|
||||
|
||||
@ -53,6 +53,7 @@ from pymongo.hello import HelloCompat
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.cursor import _Hint
|
||||
from pymongo.operations import _IndexList
|
||||
from pymongo.typings import _DocumentOut
|
||||
|
||||
# From the SDAM spec, the "node is shutting down" codes.
|
||||
_SHUTDOWN_CODES: frozenset = frozenset(
|
||||
@ -156,7 +157,7 @@ def _index_document(index_list: _IndexList) -> SON[str, Any]:
|
||||
|
||||
|
||||
def _check_command_response(
|
||||
response: Mapping[str, Any],
|
||||
response: _DocumentOut,
|
||||
max_wire_version: Optional[int],
|
||||
allowable_errors: Optional[Container[Union[int, str]]] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
|
||||
@ -19,12 +19,27 @@ MongoDB.
|
||||
.. note:: This module is for internal use and is generally not needed by
|
||||
application developers.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import random
|
||||
import struct
|
||||
from io import BytesIO as _BytesIO
|
||||
from typing import Any, Mapping, NoReturn
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import bson
|
||||
from bson import CodecOptions, _decode_selective, _dict_to_bson, _make_c_string, encode
|
||||
@ -58,6 +73,18 @@ from pymongo.helpers import _handle_reauth
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.monitoring import _EventListeners
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.typings import _Address, _DocumentOut
|
||||
|
||||
MAX_INT32 = 2147483647
|
||||
MIN_INT32 = -2147483648
|
||||
|
||||
@ -87,12 +114,14 @@ _UNICODE_REPLACE_CODEC_OPTIONS: "CodecOptions[Mapping[str, Any]]" = CodecOptions
|
||||
)
|
||||
|
||||
|
||||
def _randint():
|
||||
def _randint() -> int:
|
||||
"""Generate a pseudo random 32 bit integer."""
|
||||
return random.randint(MIN_INT32, MAX_INT32)
|
||||
|
||||
|
||||
def _maybe_add_read_preference(spec, read_preference):
|
||||
def _maybe_add_read_preference(
|
||||
spec: MutableMapping[str, Any], read_preference: _ServerMode
|
||||
) -> MutableMapping[str, Any]:
|
||||
"""Add $readPreference to spec when appropriate."""
|
||||
mode = read_preference.mode
|
||||
document = read_preference.document
|
||||
@ -108,12 +137,14 @@ def _maybe_add_read_preference(spec, read_preference):
|
||||
return spec
|
||||
|
||||
|
||||
def _convert_exception(exception):
|
||||
def _convert_exception(exception: Exception) -> Dict[str, Any]:
|
||||
"""Convert an Exception into a failure document for publishing."""
|
||||
return {"errmsg": str(exception), "errtype": exception.__class__.__name__}
|
||||
|
||||
|
||||
def _convert_write_result(operation, command, result):
|
||||
def _convert_write_result(
|
||||
operation: str, command: Mapping[str, Any], result: Mapping[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a legacy write result to write command format."""
|
||||
# Based on _merge_legacy from bulk.py
|
||||
affected = result.get("n", 0)
|
||||
@ -177,20 +208,20 @@ _MODIFIERS = SON(
|
||||
|
||||
|
||||
def _gen_find_command(
|
||||
coll,
|
||||
spec,
|
||||
projection,
|
||||
skip,
|
||||
limit,
|
||||
batch_size,
|
||||
options,
|
||||
read_concern,
|
||||
collation=None,
|
||||
session=None,
|
||||
allow_disk_use=None,
|
||||
):
|
||||
coll: str,
|
||||
spec: Mapping[str, Any],
|
||||
projection: Optional[Union[Mapping[str, Any], Iterable[str]]],
|
||||
skip: int,
|
||||
limit: int,
|
||||
batch_size: Optional[int],
|
||||
options: Optional[int],
|
||||
read_concern: ReadConcern,
|
||||
collation: Optional[Mapping[str, Any]] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
allow_disk_use: Optional[bool] = None,
|
||||
) -> SON[str, Any]:
|
||||
"""Generate a find command document."""
|
||||
cmd = SON([("find", coll)])
|
||||
cmd: SON[str, Any] = SON([("find", coll)])
|
||||
if "$query" in spec:
|
||||
cmd.update(
|
||||
[
|
||||
@ -227,9 +258,16 @@ def _gen_find_command(
|
||||
return cmd
|
||||
|
||||
|
||||
def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, conn):
|
||||
def _gen_get_more_command(
|
||||
cursor_id: Optional[int],
|
||||
coll: str,
|
||||
batch_size: Optional[int],
|
||||
max_await_time_ms: Optional[int],
|
||||
comment: Optional[Any],
|
||||
conn: Connection,
|
||||
) -> SON[str, Any]:
|
||||
"""Generate a getMore command document."""
|
||||
cmd = SON([("getMore", cursor_id), ("collection", coll)])
|
||||
cmd: SON[str, Any] = SON([("getMore", cursor_id), ("collection", coll)])
|
||||
if batch_size:
|
||||
cmd["batchSize"] = batch_size
|
||||
if max_await_time_ms is not None:
|
||||
@ -269,22 +307,22 @@ class _Query:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
flags,
|
||||
db,
|
||||
coll,
|
||||
ntoskip,
|
||||
spec,
|
||||
fields,
|
||||
codec_options,
|
||||
read_preference,
|
||||
limit,
|
||||
batch_size,
|
||||
read_concern,
|
||||
collation,
|
||||
session,
|
||||
client,
|
||||
allow_disk_use,
|
||||
exhaust,
|
||||
flags: int,
|
||||
db: str,
|
||||
coll: str,
|
||||
ntoskip: int,
|
||||
spec: Mapping[str, Any],
|
||||
fields: Optional[Mapping[str, Any]],
|
||||
codec_options: CodecOptions,
|
||||
read_preference: _ServerMode,
|
||||
limit: int,
|
||||
batch_size: int,
|
||||
read_concern: ReadConcern,
|
||||
collation: Optional[Mapping[str, Any]],
|
||||
session: Optional[ClientSession],
|
||||
client: MongoClient,
|
||||
allow_disk_use: Optional[bool],
|
||||
exhaust: bool,
|
||||
):
|
||||
self.flags = flags
|
||||
self.db = db
|
||||
@ -302,16 +340,16 @@ class _Query:
|
||||
self.client = client
|
||||
self.allow_disk_use = allow_disk_use
|
||||
self.name = "find"
|
||||
self._as_command = None
|
||||
self._as_command: Optional[Tuple[SON[str, Any], str]] = None
|
||||
self.exhaust = exhaust
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self._as_command = None
|
||||
|
||||
def namespace(self):
|
||||
def namespace(self) -> str:
|
||||
return f"{self.db}.{self.coll}"
|
||||
|
||||
def use_command(self, conn):
|
||||
def use_command(self, conn: Connection) -> bool:
|
||||
use_find_cmd = False
|
||||
if not self.exhaust:
|
||||
use_find_cmd = True
|
||||
@ -327,7 +365,9 @@ class _Query:
|
||||
conn.validate_session(self.client, self.session)
|
||||
return use_find_cmd
|
||||
|
||||
def as_command(self, conn, apply_timeout=False):
|
||||
def as_command(
|
||||
self, conn: Connection, apply_timeout: bool = False
|
||||
) -> Tuple[SON[str, Any], str]:
|
||||
"""Return a find command document for this query."""
|
||||
# We use the command twice: on the wire and for command monitoring.
|
||||
# Generate it once, for speed and to avoid repeating side-effects.
|
||||
@ -335,7 +375,7 @@ class _Query:
|
||||
return self._as_command
|
||||
|
||||
explain = "$explain" in self.spec
|
||||
cmd = _gen_find_command(
|
||||
cmd: SON[str, Any] = _gen_find_command(
|
||||
self.coll,
|
||||
self.spec,
|
||||
self.fields,
|
||||
@ -362,14 +402,16 @@ class _Query:
|
||||
# Support auto encryption
|
||||
client = self.client
|
||||
if client._encrypter and not client._encrypter._bypass_auto_encryption:
|
||||
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
|
||||
cmd = cast(SON[str, Any], client._encrypter.encrypt(self.db, cmd, self.codec_options))
|
||||
# Support CSOT
|
||||
if apply_timeout:
|
||||
conn.apply_timeout(client, cmd)
|
||||
self._as_command = cmd, self.db
|
||||
return self._as_command
|
||||
|
||||
def get_message(self, read_preference, conn, use_cmd=False):
|
||||
def get_message(
|
||||
self, read_preference: _ServerMode, conn: Connection, use_cmd: bool = False
|
||||
) -> Tuple[int, bytes, int]:
|
||||
"""Get a query message, possibly setting the secondaryOk bit."""
|
||||
# Use the read_preference decided by _socket_from_server.
|
||||
self.read_preference = read_preference
|
||||
@ -405,6 +447,7 @@ class _Query:
|
||||
ntoreturn = self.limit
|
||||
|
||||
if conn.is_mongos:
|
||||
assert isinstance(spec, MutableMapping)
|
||||
spec = _maybe_add_read_preference(spec, read_preference)
|
||||
|
||||
return _query(
|
||||
@ -442,18 +485,18 @@ class _GetMore:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db,
|
||||
coll,
|
||||
ntoreturn,
|
||||
cursor_id,
|
||||
codec_options,
|
||||
read_preference,
|
||||
session,
|
||||
client,
|
||||
max_await_time_ms,
|
||||
conn_mgr,
|
||||
exhaust,
|
||||
comment,
|
||||
db: str,
|
||||
coll: str,
|
||||
ntoreturn: int,
|
||||
cursor_id: int,
|
||||
codec_options: CodecOptions,
|
||||
read_preference: _ServerMode,
|
||||
session: Optional[ClientSession],
|
||||
client: MongoClient,
|
||||
max_await_time_ms: Optional[int],
|
||||
conn_mgr: Any,
|
||||
exhaust: bool,
|
||||
comment: Any,
|
||||
):
|
||||
self.db = db
|
||||
self.coll = coll
|
||||
@ -465,17 +508,17 @@ class _GetMore:
|
||||
self.client = client
|
||||
self.max_await_time_ms = max_await_time_ms
|
||||
self.conn_mgr = conn_mgr
|
||||
self._as_command = None
|
||||
self._as_command: Optional[Tuple[SON[str, Any], str]] = None
|
||||
self.exhaust = exhaust
|
||||
self.comment = comment
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self._as_command = None
|
||||
|
||||
def namespace(self):
|
||||
def namespace(self) -> str:
|
||||
return f"{self.db}.{self.coll}"
|
||||
|
||||
def use_command(self, conn):
|
||||
def use_command(self, conn: Connection) -> bool:
|
||||
use_cmd = False
|
||||
if not self.exhaust:
|
||||
use_cmd = True
|
||||
@ -486,13 +529,15 @@ class _GetMore:
|
||||
conn.validate_session(self.client, self.session)
|
||||
return use_cmd
|
||||
|
||||
def as_command(self, conn, apply_timeout=False):
|
||||
def as_command(
|
||||
self, conn: Connection, apply_timeout: bool = False
|
||||
) -> Tuple[SON[str, Any], str]:
|
||||
"""Return a getMore command document for this query."""
|
||||
# See _Query.as_command for an explanation of this caching.
|
||||
if self._as_command is not None:
|
||||
return self._as_command
|
||||
|
||||
cmd = _gen_get_more_command(
|
||||
cmd: SON[str, Any] = _gen_get_more_command(
|
||||
self.cursor_id,
|
||||
self.coll,
|
||||
self.ntoreturn,
|
||||
@ -507,14 +552,16 @@ class _GetMore:
|
||||
# Support auto encryption
|
||||
client = self.client
|
||||
if client._encrypter and not client._encrypter._bypass_auto_encryption:
|
||||
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
|
||||
cmd = cast(SON[str, Any], client._encrypter.encrypt(self.db, cmd, self.codec_options))
|
||||
# Support CSOT
|
||||
if apply_timeout:
|
||||
conn.apply_timeout(client, cmd=None)
|
||||
self._as_command = cmd, self.db
|
||||
return self._as_command
|
||||
|
||||
def get_message(self, dummy0, conn, use_cmd=False):
|
||||
def get_message(
|
||||
self, dummy0: Any, conn: Connection, use_cmd: bool = False
|
||||
) -> Union[Tuple[int, bytes, int], Tuple[int, bytes]]:
|
||||
"""Get a getmore message."""
|
||||
ns = self.namespace()
|
||||
ctx = conn.compression_context
|
||||
@ -534,7 +581,7 @@ class _GetMore:
|
||||
|
||||
|
||||
class _RawBatchQuery(_Query):
|
||||
def use_command(self, conn):
|
||||
def use_command(self, conn: Connection) -> bool:
|
||||
# Compatibility checks.
|
||||
super().use_command(conn)
|
||||
if conn.max_wire_version >= 8:
|
||||
@ -546,7 +593,7 @@ class _RawBatchQuery(_Query):
|
||||
|
||||
|
||||
class _RawBatchGetMore(_GetMore):
|
||||
def use_command(self, conn):
|
||||
def use_command(self, conn: Connection) -> bool:
|
||||
# Compatibility checks.
|
||||
super().use_command(conn)
|
||||
if conn.max_wire_version >= 8:
|
||||
@ -562,27 +609,27 @@ class _CursorAddress(tuple):
|
||||
|
||||
__namespace: Any
|
||||
|
||||
def __new__(cls, address, namespace):
|
||||
def __new__(cls, address: _Address, namespace: str) -> _CursorAddress:
|
||||
self = tuple.__new__(cls, address)
|
||||
self.__namespace = namespace
|
||||
return self
|
||||
|
||||
@property
|
||||
def namespace(self):
|
||||
def namespace(self) -> str:
|
||||
"""The namespace this cursor."""
|
||||
return self.__namespace
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
# Two _CursorAddress instances with different namespaces
|
||||
# must not hash the same.
|
||||
return ((*self, self.__namespace)).__hash__()
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, _CursorAddress):
|
||||
return tuple(self) == tuple(other) and self.namespace == other.namespace
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
@ -590,7 +637,9 @@ _pack_compression_header = struct.Struct("<iiiiiiB").pack
|
||||
_COMPRESSION_HEADER_SIZE = 25
|
||||
|
||||
|
||||
def _compress(operation, data, ctx):
|
||||
def _compress(
|
||||
operation: int, data: bytes, ctx: Union[SnappyContext, ZlibContext, ZstdContext]
|
||||
) -> Tuple[int, bytes]:
|
||||
"""Takes message data, compresses it, and adds an OP_COMPRESSED header."""
|
||||
compressed = ctx.compress(data)
|
||||
request_id = _randint()
|
||||
@ -610,7 +659,7 @@ def _compress(operation, data, ctx):
|
||||
_pack_header = struct.Struct("<iiii").pack
|
||||
|
||||
|
||||
def __pack_message(operation, data):
|
||||
def __pack_message(operation: int, data: bytes) -> Tuple[int, bytes]:
|
||||
"""Takes message data and adds a message header based on the operation.
|
||||
|
||||
Returns the resultant message string.
|
||||
@ -625,7 +674,13 @@ _pack_op_msg_flags_type = struct.Struct("<IB").pack
|
||||
_pack_byte = struct.Struct("<B").pack
|
||||
|
||||
|
||||
def _op_msg_no_header(flags, command, identifier, docs, opts):
|
||||
def _op_msg_no_header(
|
||||
flags: int,
|
||||
command: Mapping[str, Any],
|
||||
identifier: str,
|
||||
docs: Optional[List[Mapping[str, Any]]],
|
||||
opts: CodecOptions,
|
||||
) -> Tuple[bytes, int, int]:
|
||||
"""Get a OP_MSG message.
|
||||
|
||||
Note: this method handles multiple documents in a type one payload but
|
||||
@ -637,7 +692,7 @@ def _op_msg_no_header(flags, command, identifier, docs, opts):
|
||||
flags_type = _pack_op_msg_flags_type(flags, 0)
|
||||
total_size = len(encoded)
|
||||
max_doc_size = 0
|
||||
if identifier:
|
||||
if identifier and docs is not None:
|
||||
type_one = _pack_byte(1)
|
||||
cstring = _make_c_string(identifier)
|
||||
encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs]
|
||||
@ -651,14 +706,27 @@ def _op_msg_no_header(flags, command, identifier, docs, opts):
|
||||
return b"".join(data), total_size, max_doc_size
|
||||
|
||||
|
||||
def _op_msg_compressed(flags, command, identifier, docs, opts, ctx):
|
||||
def _op_msg_compressed(
|
||||
flags: int,
|
||||
command: Mapping[str, Any],
|
||||
identifier: str,
|
||||
docs: Optional[List[Mapping[str, Any]]],
|
||||
opts: CodecOptions,
|
||||
ctx: Union[SnappyContext, ZlibContext, ZstdContext],
|
||||
) -> Tuple[int, bytes, int, int]:
|
||||
"""Internal OP_MSG message helper."""
|
||||
msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
|
||||
rid, msg = _compress(2013, msg, ctx)
|
||||
return rid, msg, total_size, max_bson_size
|
||||
|
||||
|
||||
def _op_msg_uncompressed(flags, command, identifier, docs, opts):
|
||||
def _op_msg_uncompressed(
|
||||
flags: int,
|
||||
command: Mapping[str, Any],
|
||||
identifier: str,
|
||||
docs: Optional[List[Mapping[str, Any]]],
|
||||
opts: CodecOptions,
|
||||
) -> Tuple[int, bytes, int, int]:
|
||||
"""Internal compressed OP_MSG message helper."""
|
||||
data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
|
||||
request_id, op_message = __pack_message(2013, data)
|
||||
@ -669,7 +737,14 @@ if _use_c:
|
||||
_op_msg_uncompressed = _cmessage._op_msg # noqa: F811
|
||||
|
||||
|
||||
def _op_msg(flags, command, dbname, read_preference, opts, ctx=None):
|
||||
def _op_msg(
|
||||
flags: int,
|
||||
command: MutableMapping[str, Any],
|
||||
dbname: str,
|
||||
read_preference: Optional[_ServerMode],
|
||||
opts: CodecOptions,
|
||||
ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
|
||||
) -> Tuple[int, bytes, int, int]:
|
||||
"""Get a OP_MSG message."""
|
||||
command["$db"] = dbname
|
||||
# getMore commands do not send $readPreference.
|
||||
@ -679,7 +754,7 @@ def _op_msg(flags, command, dbname, read_preference, opts, ctx=None):
|
||||
command["$readPreference"] = read_preference.document
|
||||
name = next(iter(command))
|
||||
try:
|
||||
identifier = _FIELD_MAP.get(name)
|
||||
identifier = _FIELD_MAP[name]
|
||||
docs = command.pop(identifier)
|
||||
except KeyError:
|
||||
identifier = ""
|
||||
@ -694,7 +769,15 @@ def _op_msg(flags, command, dbname, read_preference, opts, ctx=None):
|
||||
command[identifier] = docs
|
||||
|
||||
|
||||
def _query_impl(options, collection_name, num_to_skip, num_to_return, query, field_selector, opts):
|
||||
def _query_impl(
|
||||
options: int,
|
||||
collection_name: str,
|
||||
num_to_skip: int,
|
||||
num_to_return: int,
|
||||
query: Mapping[str, Any],
|
||||
field_selector: Optional[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
) -> Tuple[bytes, int]:
|
||||
"""Get an OP_QUERY message."""
|
||||
encoded = _dict_to_bson(query, False, opts)
|
||||
if field_selector:
|
||||
@ -718,8 +801,15 @@ def _query_impl(options, collection_name, num_to_skip, num_to_return, query, fie
|
||||
|
||||
|
||||
def _query_compressed(
|
||||
options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx=None
|
||||
):
|
||||
options: int,
|
||||
collection_name: str,
|
||||
num_to_skip: int,
|
||||
num_to_return: int,
|
||||
query: Mapping[str, Any],
|
||||
field_selector: Optional[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
ctx: Union[SnappyContext, ZlibContext, ZstdContext],
|
||||
) -> Tuple[int, bytes, int]:
|
||||
"""Internal compressed query message helper."""
|
||||
op_query, max_bson_size = _query_impl(
|
||||
options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
|
||||
@ -729,8 +819,14 @@ def _query_compressed(
|
||||
|
||||
|
||||
def _query_uncompressed(
|
||||
options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
|
||||
):
|
||||
options: int,
|
||||
collection_name: str,
|
||||
num_to_skip: int,
|
||||
num_to_return: int,
|
||||
query: Mapping[str, Any],
|
||||
field_selector: Optional[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
) -> Tuple[int, bytes, int]:
|
||||
"""Internal query message helper."""
|
||||
op_query, max_bson_size = _query_impl(
|
||||
options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
|
||||
@ -744,8 +840,15 @@ if _use_c:
|
||||
|
||||
|
||||
def _query(
|
||||
options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx=None
|
||||
):
|
||||
options: int,
|
||||
collection_name: str,
|
||||
num_to_skip: int,
|
||||
num_to_return: int,
|
||||
query: Mapping[str, Any],
|
||||
field_selector: Optional[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
|
||||
) -> Tuple[int, bytes, int]:
|
||||
"""Get a **query** message."""
|
||||
if ctx:
|
||||
return _query_compressed(
|
||||
@ -759,7 +862,7 @@ def _query(
|
||||
_pack_long_long = struct.Struct("<q").pack
|
||||
|
||||
|
||||
def _get_more_impl(collection_name, num_to_return, cursor_id):
|
||||
def _get_more_impl(collection_name: str, num_to_return: int, cursor_id: int) -> bytes:
|
||||
"""Get an OP_GET_MORE message."""
|
||||
return b"".join(
|
||||
[
|
||||
@ -771,12 +874,19 @@ def _get_more_impl(collection_name, num_to_return, cursor_id):
|
||||
)
|
||||
|
||||
|
||||
def _get_more_compressed(collection_name, num_to_return, cursor_id, ctx):
|
||||
def _get_more_compressed(
|
||||
collection_name: str,
|
||||
num_to_return: int,
|
||||
cursor_id: int,
|
||||
ctx: Union[SnappyContext, ZlibContext, ZstdContext],
|
||||
) -> Tuple[int, bytes]:
|
||||
"""Internal compressed getMore message helper."""
|
||||
return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx)
|
||||
|
||||
|
||||
def _get_more_uncompressed(collection_name, num_to_return, cursor_id):
|
||||
def _get_more_uncompressed(
|
||||
collection_name: str, num_to_return: int, cursor_id: int
|
||||
) -> Tuple[int, bytes]:
|
||||
"""Internal getMore message helper."""
|
||||
return __pack_message(2005, _get_more_impl(collection_name, num_to_return, cursor_id))
|
||||
|
||||
@ -785,7 +895,12 @@ if _use_c:
|
||||
_get_more_uncompressed = _cmessage._get_more_message # noqa: F811
|
||||
|
||||
|
||||
def _get_more(collection_name, num_to_return, cursor_id, ctx=None):
|
||||
def _get_more(
|
||||
collection_name: str,
|
||||
num_to_return: int,
|
||||
cursor_id: int,
|
||||
ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
|
||||
) -> Tuple[int, bytes]:
|
||||
"""Get a **getMore** message."""
|
||||
if ctx:
|
||||
return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx)
|
||||
@ -811,7 +926,15 @@ class _BulkWriteContext:
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, database_name, cmd_name, conn, operation_id, listeners, session, op_type, codec
|
||||
self,
|
||||
database_name: str,
|
||||
cmd_name: str,
|
||||
conn: Connection,
|
||||
operation_id: int,
|
||||
listeners: _EventListeners,
|
||||
session: ClientSession,
|
||||
op_type: int,
|
||||
codec: CodecOptions,
|
||||
):
|
||||
self.db_name = database_name
|
||||
self.conn = conn
|
||||
@ -826,7 +949,9 @@ class _BulkWriteContext:
|
||||
self.op_type = op_type
|
||||
self.codec = codec
|
||||
|
||||
def _batch_command(self, cmd, docs):
|
||||
def __batch_command(
|
||||
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]]
|
||||
) -> Tuple[int, bytes, List[Mapping[str, Any]]]:
|
||||
namespace = self.db_name + ".$cmd"
|
||||
request_id, msg, to_send = _do_batched_op_msg(
|
||||
namespace, self.op_type, cmd, docs, self.codec, self
|
||||
@ -835,14 +960,18 @@ class _BulkWriteContext:
|
||||
raise InvalidOperation("cannot do an empty bulk write")
|
||||
return request_id, msg, to_send
|
||||
|
||||
def execute(self, cmd, docs, client):
|
||||
request_id, msg, to_send = self._batch_command(cmd, docs)
|
||||
def execute(
|
||||
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient
|
||||
) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]:
|
||||
request_id, msg, to_send = self.__batch_command(cmd, docs)
|
||||
result = self.write_command(cmd, request_id, msg, to_send)
|
||||
client._process_response(result, self.session)
|
||||
return result, to_send
|
||||
|
||||
def execute_unack(self, cmd, docs, client):
|
||||
request_id, msg, to_send = self._batch_command(cmd, docs)
|
||||
def execute_unack(
|
||||
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient
|
||||
) -> List[Mapping[str, Any]]:
|
||||
request_id, msg, to_send = self.__batch_command(cmd, docs)
|
||||
# Though this isn't strictly a "legacy" write, the helper
|
||||
# handles publishing commands and sending our message
|
||||
# without receiving a result. Send 0 for max_doc_size
|
||||
@ -852,12 +981,12 @@ class _BulkWriteContext:
|
||||
return to_send
|
||||
|
||||
@property
|
||||
def max_bson_size(self):
|
||||
def max_bson_size(self) -> int:
|
||||
"""A proxy for SockInfo.max_bson_size."""
|
||||
return self.conn.max_bson_size
|
||||
|
||||
@property
|
||||
def max_message_size(self):
|
||||
def max_message_size(self) -> int:
|
||||
"""A proxy for SockInfo.max_message_size."""
|
||||
if self.compress:
|
||||
# Subtract 16 bytes for the message header.
|
||||
@ -865,16 +994,23 @@ class _BulkWriteContext:
|
||||
return self.conn.max_message_size
|
||||
|
||||
@property
|
||||
def max_write_batch_size(self):
|
||||
def max_write_batch_size(self) -> int:
|
||||
"""A proxy for SockInfo.max_write_batch_size."""
|
||||
return self.conn.max_write_batch_size
|
||||
|
||||
@property
|
||||
def max_split_size(self):
|
||||
def max_split_size(self) -> int:
|
||||
"""The maximum size of a BSON command before batch splitting."""
|
||||
return self.max_bson_size
|
||||
|
||||
def unack_write(self, cmd, request_id, msg, max_doc_size, docs):
|
||||
def unack_write(
|
||||
self,
|
||||
cmd: MutableMapping[str, Any],
|
||||
request_id: int,
|
||||
msg: bytes,
|
||||
max_doc_size: int,
|
||||
docs: List[Mapping[str, Any]],
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""A proxy for Connection.unack_write that handles event publishing."""
|
||||
if self.publish:
|
||||
assert self.start_time is not None
|
||||
@ -896,9 +1032,9 @@ class _BulkWriteContext:
|
||||
assert self.start_time is not None
|
||||
duration = (datetime.datetime.now() - start) + duration
|
||||
if isinstance(exc, OperationFailure):
|
||||
failure = _convert_write_result(self.name, cmd, exc.details)
|
||||
failure: _DocumentOut = _convert_write_result(self.name, cmd, exc.details) # type: ignore[arg-type]
|
||||
elif isinstance(exc, NotPrimaryError):
|
||||
failure = exc.details
|
||||
failure = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
self._fail(request_id, failure, duration)
|
||||
@ -908,8 +1044,14 @@ class _BulkWriteContext:
|
||||
return result
|
||||
|
||||
@_handle_reauth
|
||||
def write_command(self, cmd, request_id, msg, docs):
|
||||
"""A proxy for Connection.write_command that handles event publishing."""
|
||||
def write_command(
|
||||
self,
|
||||
cmd: MutableMapping[str, Any],
|
||||
request_id: int,
|
||||
msg: bytes,
|
||||
docs: List[Mapping[str, Any]],
|
||||
) -> Mapping[str, Any]:
|
||||
"""A proxy for SocketInfo.write_command that handles event publishing."""
|
||||
if self.publish:
|
||||
assert self.start_time is not None
|
||||
duration = datetime.datetime.now() - self.start_time
|
||||
@ -924,7 +1066,7 @@ class _BulkWriteContext:
|
||||
if self.publish:
|
||||
duration = (datetime.datetime.now() - start) + duration
|
||||
if isinstance(exc, (NotPrimaryError, OperationFailure)):
|
||||
failure = exc.details
|
||||
failure: _DocumentOut = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
self._fail(request_id, failure, duration)
|
||||
@ -933,7 +1075,9 @@ class _BulkWriteContext:
|
||||
self.start_time = datetime.datetime.now()
|
||||
return reply
|
||||
|
||||
def _start(self, cmd, request_id, docs):
|
||||
def _start(
|
||||
self, cmd: MutableMapping[str, Any], request_id: int, docs: List[Mapping[str, Any]]
|
||||
) -> MutableMapping[str, Any]:
|
||||
"""Publish a CommandStartedEvent."""
|
||||
cmd[self.field] = docs
|
||||
self.listeners.publish_command_start(
|
||||
@ -946,7 +1090,7 @@ class _BulkWriteContext:
|
||||
)
|
||||
return cmd
|
||||
|
||||
def _succeed(self, request_id, reply, duration):
|
||||
def _succeed(self, request_id: int, reply: _DocumentOut, duration: timedelta) -> None:
|
||||
"""Publish a CommandSucceededEvent."""
|
||||
self.listeners.publish_command_success(
|
||||
duration,
|
||||
@ -958,7 +1102,7 @@ class _BulkWriteContext:
|
||||
self.conn.service_id,
|
||||
)
|
||||
|
||||
def _fail(self, request_id, failure, duration):
|
||||
def _fail(self, request_id: int, failure: _DocumentOut, duration: timedelta) -> None:
|
||||
"""Publish a CommandFailedEvent."""
|
||||
self.listeners.publish_command_failure(
|
||||
duration,
|
||||
@ -981,7 +1125,9 @@ _MAX_SPLIT_SIZE_ENC = 2097152
|
||||
class _EncryptedBulkWriteContext(_BulkWriteContext):
|
||||
__slots__ = ()
|
||||
|
||||
def _batch_command(self, cmd, docs):
|
||||
def __batch_command(
|
||||
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]]
|
||||
) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]:
|
||||
namespace = self.db_name + ".$cmd"
|
||||
msg, to_send = _encode_batched_write_command(
|
||||
namespace, self.op_type, cmd, docs, self.codec, self
|
||||
@ -994,15 +1140,19 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
|
||||
cmd = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS)
|
||||
return cmd, to_send
|
||||
|
||||
def execute(self, cmd, docs, client):
|
||||
batched_cmd, to_send = self._batch_command(cmd, docs)
|
||||
result = self.conn.command(
|
||||
def execute(
|
||||
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient
|
||||
) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]:
|
||||
batched_cmd, to_send = self.__batch_command(cmd, docs)
|
||||
result: Mapping[str, Any] = self.conn.command(
|
||||
self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client
|
||||
)
|
||||
return result, to_send
|
||||
|
||||
def execute_unack(self, cmd, docs, client):
|
||||
batched_cmd, to_send = self._batch_command(cmd, docs)
|
||||
def execute_unack(
|
||||
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]], client: MongoClient
|
||||
) -> List[Mapping[str, Any]]:
|
||||
batched_cmd, to_send = self.__batch_command(cmd, docs)
|
||||
self.conn.command(
|
||||
self.db_name,
|
||||
batched_cmd,
|
||||
@ -1013,7 +1163,7 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
|
||||
return to_send
|
||||
|
||||
@property
|
||||
def max_split_size(self):
|
||||
def max_split_size(self) -> int:
|
||||
"""Reduce the batch splitting size."""
|
||||
return _MAX_SPLIT_SIZE_ENC
|
||||
|
||||
@ -1043,7 +1193,15 @@ _OP_MSG_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf):
|
||||
def _batched_op_msg_impl(
|
||||
operation: int,
|
||||
command: Mapping[str, Any],
|
||||
docs: List[Mapping[str, Any]],
|
||||
ack: bool,
|
||||
opts: CodecOptions,
|
||||
ctx: _BulkWriteContext,
|
||||
buf: _BytesIO,
|
||||
) -> Tuple[List[Mapping[str, Any]], int]:
|
||||
"""Create a batched OP_MSG write."""
|
||||
max_bson_size = ctx.max_bson_size
|
||||
max_write_batch_size = ctx.max_write_batch_size
|
||||
@ -1103,7 +1261,14 @@ def _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf):
|
||||
return to_send, length
|
||||
|
||||
|
||||
def _encode_batched_op_msg(operation, command, docs, ack, opts, ctx):
|
||||
def _encode_batched_op_msg(
|
||||
operation: int,
|
||||
command: Mapping[str, Any],
|
||||
docs: List[Mapping[str, Any]],
|
||||
ack: bool,
|
||||
opts: CodecOptions,
|
||||
ctx: _BulkWriteContext,
|
||||
) -> Tuple[bytes, List[Mapping[str, Any]]]:
|
||||
"""Encode the next batched insert, update, or delete operation
|
||||
as OP_MSG.
|
||||
"""
|
||||
@ -1117,17 +1282,32 @@ if _use_c:
|
||||
_encode_batched_op_msg = _cmessage._encode_batched_op_msg # noqa: F811
|
||||
|
||||
|
||||
def _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx):
|
||||
def _batched_op_msg_compressed(
|
||||
operation: int,
|
||||
command: Mapping[str, Any],
|
||||
docs: List[Mapping[str, Any]],
|
||||
ack: bool,
|
||||
opts: CodecOptions,
|
||||
ctx: _BulkWriteContext,
|
||||
) -> Tuple[int, bytes, List[Mapping[str, Any]]]:
|
||||
"""Create the next batched insert, update, or delete operation
|
||||
with OP_MSG, compressed.
|
||||
"""
|
||||
data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx)
|
||||
|
||||
assert ctx.conn.compression_context is not None
|
||||
request_id, msg = _compress(2013, data, ctx.conn.compression_context)
|
||||
return request_id, msg, to_send
|
||||
|
||||
|
||||
def _batched_op_msg(operation, command, docs, ack, opts, ctx):
|
||||
def _batched_op_msg(
|
||||
operation: int,
|
||||
command: Mapping[str, Any],
|
||||
docs: List[Mapping[str, Any]],
|
||||
ack: bool,
|
||||
opts: CodecOptions,
|
||||
ctx: _BulkWriteContext,
|
||||
) -> Tuple[int, bytes, List[Mapping[str, Any]]]:
|
||||
"""OP_MSG implementation entry point."""
|
||||
buf = _BytesIO()
|
||||
|
||||
@ -1152,7 +1332,14 @@ if _use_c:
|
||||
_batched_op_msg = _cmessage._batched_op_msg # noqa: F811
|
||||
|
||||
|
||||
def _do_batched_op_msg(namespace, operation, command, docs, opts, ctx):
|
||||
def _do_batched_op_msg(
|
||||
namespace: str,
|
||||
operation: int,
|
||||
command: MutableMapping[str, Any],
|
||||
docs: List[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
ctx: _BulkWriteContext,
|
||||
) -> Tuple[int, bytes, List[Mapping[str, Any]]]:
|
||||
"""Create the next batched insert, update, or delete operation
|
||||
using OP_MSG.
|
||||
"""
|
||||
@ -1169,7 +1356,14 @@ def _do_batched_op_msg(namespace, operation, command, docs, opts, ctx):
|
||||
# End OP_MSG -----------------------------------------------------
|
||||
|
||||
|
||||
def _encode_batched_write_command(namespace, operation, command, docs, opts, ctx):
|
||||
def _encode_batched_write_command(
|
||||
namespace: str,
|
||||
operation: int,
|
||||
command: MutableMapping[str, Any],
|
||||
docs: List[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
ctx: _BulkWriteContext,
|
||||
) -> Tuple[bytes, List[Mapping[str, Any]]]:
|
||||
"""Encode the next batched insert, update, or delete command."""
|
||||
buf = _BytesIO()
|
||||
|
||||
@ -1181,7 +1375,15 @@ if _use_c:
|
||||
_encode_batched_write_command = _cmessage._encode_batched_write_command # noqa: F811
|
||||
|
||||
|
||||
def _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, buf):
|
||||
def _batched_write_command_impl(
|
||||
namespace: str,
|
||||
operation: int,
|
||||
command: MutableMapping[str, Any],
|
||||
docs: List[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
ctx: _BulkWriteContext,
|
||||
buf: _BytesIO,
|
||||
) -> Tuple[List[Mapping[str, Any]], int]:
|
||||
"""Create a batched OP_QUERY write command."""
|
||||
max_bson_size = ctx.max_bson_size
|
||||
max_write_batch_size = ctx.max_write_batch_size
|
||||
@ -1258,13 +1460,15 @@ class _OpReply:
|
||||
UNPACK_FROM = struct.Struct("<iqii").unpack_from
|
||||
OP_CODE = 1
|
||||
|
||||
def __init__(self, flags, cursor_id, number_returned, documents):
|
||||
def __init__(self, flags: int, cursor_id: int, number_returned: int, documents: bytes):
|
||||
self.flags = flags
|
||||
self.cursor_id = Int64(cursor_id)
|
||||
self.number_returned = number_returned
|
||||
self.documents = documents
|
||||
|
||||
def raw_response(self, cursor_id=None, user_fields=None):
|
||||
def raw_response(
|
||||
self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = None
|
||||
) -> List[bytes]:
|
||||
"""Check the response header from the database, without decoding BSON.
|
||||
|
||||
Check the response for errors and unpack.
|
||||
@ -1309,11 +1513,11 @@ class _OpReply:
|
||||
|
||||
def unpack_response(
|
||||
self,
|
||||
cursor_id=None,
|
||||
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
user_fields=None,
|
||||
legacy_response=False,
|
||||
):
|
||||
cursor_id: Optional[int] = None,
|
||||
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> List[_DocumentOut]:
|
||||
"""Unpack a response from the database and decode the BSON document(s).
|
||||
|
||||
Check the response for errors and unpack, returning a dictionary
|
||||
@ -1337,24 +1541,24 @@ class _OpReply:
|
||||
return bson.decode_all(self.documents, codec_options)
|
||||
return bson._decode_all_selective(self.documents, codec_options, user_fields)
|
||||
|
||||
def command_response(self, codec_options):
|
||||
def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]:
|
||||
"""Unpack a command response."""
|
||||
docs = self.unpack_response(codec_options=codec_options)
|
||||
assert self.number_returned == 1
|
||||
return docs[0]
|
||||
|
||||
def raw_command_response(self):
|
||||
def raw_command_response(self) -> NoReturn:
|
||||
"""Return the bytes of the command response."""
|
||||
# This should never be called on _OpReply.
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def more_to_come(self):
|
||||
def more_to_come(self) -> bool:
|
||||
"""Is the moreToCome bit set on this response?"""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def unpack(cls, msg):
|
||||
def unpack(cls, msg: bytes) -> _OpReply:
|
||||
"""Construct an _OpReply from raw bytes."""
|
||||
# PYTHON-945: ignore starting_from field.
|
||||
flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg)
|
||||
@ -1376,11 +1580,15 @@ class _OpMsg:
|
||||
MORE_TO_COME = 1 << 1
|
||||
EXHAUST_ALLOWED = 1 << 16 # Only present on requests.
|
||||
|
||||
def __init__(self, flags, payload_document):
|
||||
def __init__(self, flags: int, payload_document: bytes):
|
||||
self.flags = flags
|
||||
self.payload_document = payload_document
|
||||
|
||||
def raw_response(self, cursor_id=None, user_fields={}): # noqa: B006
|
||||
def raw_response(
|
||||
self,
|
||||
cursor_id: Optional[int] = None,
|
||||
user_fields: Optional[Mapping[str, Any]] = {}, # noqa: B006
|
||||
) -> List[Mapping[str, Any]]:
|
||||
"""
|
||||
cursor_id is ignored
|
||||
user_fields is used to determine which fields must not be decoded
|
||||
@ -1392,11 +1600,11 @@ class _OpMsg:
|
||||
|
||||
def unpack_response(
|
||||
self,
|
||||
cursor_id=None,
|
||||
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
user_fields=None,
|
||||
legacy_response=False,
|
||||
):
|
||||
cursor_id: Optional[int] = None,
|
||||
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> List[_DocumentOut]:
|
||||
"""Unpack a OP_MSG command response.
|
||||
|
||||
:Parameters:
|
||||
@ -1411,21 +1619,21 @@ class _OpMsg:
|
||||
assert not legacy_response
|
||||
return bson._decode_all_selective(self.payload_document, codec_options, user_fields)
|
||||
|
||||
def command_response(self, codec_options):
|
||||
def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]:
|
||||
"""Unpack a command response."""
|
||||
return self.unpack_response(codec_options=codec_options)[0]
|
||||
|
||||
def raw_command_response(self):
|
||||
def raw_command_response(self) -> bytes:
|
||||
"""Return the bytes of the command response."""
|
||||
return self.payload_document
|
||||
|
||||
@property
|
||||
def more_to_come(self):
|
||||
def more_to_come(self) -> bool:
|
||||
"""Is the moreToCome bit set on this response?"""
|
||||
return self.flags & self.MORE_TO_COME
|
||||
return bool(self.flags & self.MORE_TO_COME)
|
||||
|
||||
@classmethod
|
||||
def unpack(cls, msg):
|
||||
def unpack(cls, msg: bytes) -> _OpMsg:
|
||||
"""Construct an _OpMsg from raw bytes."""
|
||||
flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg)
|
||||
if flags != 0:
|
||||
@ -1444,7 +1652,7 @@ class _OpMsg:
|
||||
return cls(flags, payload_document)
|
||||
|
||||
|
||||
_UNPACK_REPLY = {
|
||||
_UNPACK_REPLY: Dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = {
|
||||
_OpReply.OP_CODE: _OpReply.unpack,
|
||||
_OpMsg.OP_CODE: _OpMsg.unpack,
|
||||
}
|
||||
|
||||
@ -23,7 +23,6 @@ import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
@ -55,7 +54,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.typings import _Address, _DocumentOut
|
||||
from pymongo.typings import _Address, _DocumentOut, _DocumentType
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_UNPACK_HEADER = struct.Struct("<iiii").unpack
|
||||
@ -67,7 +66,7 @@ def command(
|
||||
spec: MutableMapping[str, Any],
|
||||
is_mongos: bool,
|
||||
read_preference: _ServerMode,
|
||||
codec_options: CodecOptions,
|
||||
codec_options: CodecOptions[_DocumentType],
|
||||
session: Optional[ClientSession],
|
||||
client: Optional[MongoClient],
|
||||
check: bool = True,
|
||||
@ -84,7 +83,7 @@ def command(
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
exhaust_allowed: bool = False,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
) -> Dict[str, Any]:
|
||||
) -> _DocumentType:
|
||||
"""Execute a command over the socket, or raise socket.error.
|
||||
|
||||
:Parameters:
|
||||
@ -177,7 +176,7 @@ def command(
|
||||
if use_op_msg and unacknowledged:
|
||||
# Unacknowledged, fake a successful command response.
|
||||
reply = None
|
||||
response_doc = {"ok": 1}
|
||||
response_doc: _DocumentOut = {"ok": 1}
|
||||
else:
|
||||
reply = receive_message(conn, request_id)
|
||||
conn.more_to_come = reply.more_to_come
|
||||
@ -226,7 +225,7 @@ def command(
|
||||
decrypted = client._encrypter.decrypt(reply.raw_command_response())
|
||||
response_doc = _decode_all_selective(decrypted, codec_options, user_fields)[0]
|
||||
|
||||
return response_doc
|
||||
return response_doc # type: ignore[return-value]
|
||||
|
||||
|
||||
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
|
||||
|
||||
@ -15,14 +15,14 @@
|
||||
"""Represent a response from the server."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from pymongo.message import _OpMsg, _OpReply
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.typings import _Address
|
||||
from pymongo.typings import _Address, _DocumentOut
|
||||
|
||||
|
||||
class Response:
|
||||
@ -35,7 +35,7 @@ class Response:
|
||||
request_id: int,
|
||||
duration: Optional[timedelta],
|
||||
from_command: bool,
|
||||
docs: List[Mapping[str, Any]],
|
||||
docs: Sequence[Mapping[str, Any]],
|
||||
):
|
||||
"""Represent a response from the server.
|
||||
|
||||
@ -79,7 +79,7 @@ class Response:
|
||||
return self._from_command
|
||||
|
||||
@property
|
||||
def docs(self) -> List[Mapping[str, Any]]:
|
||||
def docs(self) -> Sequence[Mapping[str, Any]]:
|
||||
"""The decoded document(s)."""
|
||||
return self._docs
|
||||
|
||||
@ -95,7 +95,7 @@ class PinnedResponse(Response):
|
||||
request_id: int,
|
||||
duration: Optional[timedelta],
|
||||
from_command: bool,
|
||||
docs: List[Mapping[str, Any]],
|
||||
docs: List[_DocumentOut],
|
||||
more_to_come: bool,
|
||||
):
|
||||
"""Represent a response to an exhaust cursor's initial query.
|
||||
|
||||
@ -22,7 +22,6 @@ from typing import (
|
||||
Callable,
|
||||
ContextManager,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
@ -111,7 +110,7 @@ class Server:
|
||||
operation: Union[_Query, _GetMore],
|
||||
read_preference: _ServerMode,
|
||||
listeners: _EventListeners,
|
||||
unpack_res: Callable[..., List[Mapping[str, Any]]],
|
||||
unpack_res: Callable[..., List[_DocumentOut]],
|
||||
) -> Response:
|
||||
"""Run a _Query or _GetMore operation and return a Response object.
|
||||
|
||||
@ -193,9 +192,9 @@ class Server:
|
||||
# Must publish in find / getMore / explain command response
|
||||
# format.
|
||||
if use_cmd:
|
||||
res: _DocumentOut = docs[0] # type: ignore[assignment]
|
||||
res: _DocumentOut = docs[0]
|
||||
elif operation.name == "explain":
|
||||
res = docs[0] if docs else {} # type: ignore[assignment]
|
||||
res = docs[0] if docs else {}
|
||||
else:
|
||||
res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1}
|
||||
if operation.name == "find":
|
||||
|
||||
@ -111,7 +111,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
|
||||
# Insert record and verify failure.
|
||||
with self.assertRaises(NotPrimaryError) as exc:
|
||||
self.coll.insert_one({"test": 1})
|
||||
self.assertEqual(exc.exception.details["code"], error_code) # type: ignore
|
||||
self.assertEqual(exc.exception.details["code"], error_code) # type: ignore[call-overload]
|
||||
# Retry before CMAPListener assertion if retry_before=True.
|
||||
if retry:
|
||||
self.coll.insert_one({"test": 1})
|
||||
|
||||
@ -754,6 +754,7 @@ Bye"""
|
||||
# Kill the cursor to simulate the cursor timing out on the server
|
||||
# when an application spends a long time between two calls to
|
||||
# readchunk().
|
||||
assert client.address is not None
|
||||
client._close_cursor_now(
|
||||
outfile._GridOut__chunk_iter._cursor.cursor_id,
|
||||
_CursorAddress(client.address, db.fs.chunks.full_name),
|
||||
|
||||
@ -720,7 +720,8 @@ def server_started_with_auth(client):
|
||||
try:
|
||||
command_line = get_command_line(client)
|
||||
except OperationFailure as e:
|
||||
msg = e.details.get("errmsg", "") # type: ignore
|
||||
assert e.details is not None
|
||||
msg = e.details.get("errmsg", "")
|
||||
if e.code == 13 or "unauthorized" in msg or "login" in msg:
|
||||
# Unauthorized.
|
||||
return True
|
||||
|
||||
Loading…
Reference in New Issue
Block a user