Merge 7b31987278 into 8dc7efade2
This commit is contained in:
commit
4ef38681df
349
test/test_message.py
Normal file
349
test/test_message.py
Normal file
@ -0,0 +1,349 @@
|
||||
# Copyright 2026-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for message.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
|
||||
from bson import CodecOptions, encode
|
||||
from pymongo.compression_support import ZlibContext
|
||||
from pymongo.errors import DocumentTooLarge, OperationFailure
|
||||
from pymongo.message import (
|
||||
_compress,
|
||||
_convert_client_bulk_exception,
|
||||
_convert_exception,
|
||||
_convert_write_result,
|
||||
_gen_find_command,
|
||||
_gen_get_more_command,
|
||||
_get_more_compressed,
|
||||
_get_more_uncompressed,
|
||||
_maybe_add_read_preference,
|
||||
_op_msg,
|
||||
_query_compressed,
|
||||
_query_uncompressed,
|
||||
_raise_document_too_large,
|
||||
)
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference, SecondaryPreferred
|
||||
|
||||
_OPTS = CodecOptions()
|
||||
_ZLIB_CTX = ZlibContext(-1)
|
||||
|
||||
|
||||
class TestMaybeAddReadPreference(unittest.TestCase):
|
||||
def test_primary_no_read_preference_added(self):
|
||||
spec: dict = {"find": "col"}
|
||||
result = _maybe_add_read_preference(spec, ReadPreference.PRIMARY)
|
||||
self.assertNotIn("$readPreference", result)
|
||||
self.assertNotIn("$query", result)
|
||||
|
||||
def test_secondary_adds_read_preference(self):
|
||||
spec: dict = {"find": "col"}
|
||||
result = _maybe_add_read_preference(spec, ReadPreference.SECONDARY)
|
||||
self.assertIn("$readPreference", result)
|
||||
self.assertEqual(result["$readPreference"]["mode"], "secondary")
|
||||
self.assertIn("$query", result)
|
||||
|
||||
def test_secondary_preferred_no_tags_does_not_add(self):
|
||||
spec: dict = {"find": "col"}
|
||||
result = _maybe_add_read_preference(spec, ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertNotIn("$readPreference", result)
|
||||
|
||||
def test_secondary_preferred_with_tags_adds_read_preference(self):
|
||||
pref = SecondaryPreferred(tag_sets=[{"dc": "east"}])
|
||||
spec: dict = {"find": "col"}
|
||||
result = _maybe_add_read_preference(spec, pref)
|
||||
self.assertIn("$readPreference", result)
|
||||
|
||||
def test_existing_query_wrapper_preserved(self):
|
||||
spec: dict = {"$query": {"x": 1}, "other": 2}
|
||||
result = _maybe_add_read_preference(spec, ReadPreference.SECONDARY)
|
||||
self.assertIn("$readPreference", result)
|
||||
self.assertEqual(result["$query"], {"x": 1})
|
||||
|
||||
|
||||
class TestConvertException(unittest.TestCase):
|
||||
def test_basic_exception(self):
|
||||
exc = ValueError("bad value")
|
||||
doc = _convert_exception(exc)
|
||||
self.assertEqual(doc["errmsg"], "bad value")
|
||||
self.assertEqual(doc["errtype"], "ValueError")
|
||||
|
||||
def test_client_bulk_exception_includes_code(self):
|
||||
exc = OperationFailure("failed", code=11000)
|
||||
doc = _convert_client_bulk_exception(exc)
|
||||
self.assertEqual(doc["errmsg"], "failed")
|
||||
self.assertEqual(doc["code"], 11000)
|
||||
self.assertEqual(doc["errtype"], "OperationFailure")
|
||||
|
||||
|
||||
class TestConvertWriteResult(unittest.TestCase):
|
||||
"""Tests for _convert_write_result.
|
||||
|
||||
In the update command spec, `q` is the query/filter and `u` is the update document.
|
||||
"""
|
||||
|
||||
def test_insert_basic(self):
|
||||
cmd = {"documents": [{"_id": 1}, {"_id": 2}]}
|
||||
result = _convert_write_result("insert", cmd, {"n": 0})
|
||||
self.assertEqual(result["ok"], 1)
|
||||
self.assertEqual(result["n"], 2)
|
||||
|
||||
def test_update_basic(self):
|
||||
cmd = {"updates": [{"q": {}, "u": {"$set": {"x": 1}}}]}
|
||||
result = _convert_write_result("update", cmd, {"n": 1, "updatedExisting": True})
|
||||
self.assertEqual(result["ok"], 1)
|
||||
self.assertNotIn("upserted", result)
|
||||
|
||||
def test_update_with_upserted_id(self):
|
||||
cmd = {"updates": [{"q": {}, "u": {"_id": 42}}]}
|
||||
result = _convert_write_result("update", cmd, {"n": 1, "upserted": 42})
|
||||
self.assertIn("upserted", result)
|
||||
self.assertEqual(result["upserted"][0]["_id"], 42)
|
||||
|
||||
def test_update_upsert_id_precedence(self):
|
||||
# When _id is in both the update document and the query spec,
|
||||
# the update document's _id wins.
|
||||
cmd = {"updates": [{"q": {"_id": 99}, "u": {"_id": 42}}]}
|
||||
result = _convert_write_result("update", cmd, {"n": 1, "updatedExisting": False})
|
||||
self.assertEqual(result["upserted"][0]["_id"], 42)
|
||||
|
||||
def test_update_upsert_no_upserted_id_from_query(self):
|
||||
cmd = {"updates": [{"q": {"_id": 77}, "u": {"$set": {"x": 1}}}]}
|
||||
result = _convert_write_result("update", cmd, {"n": 1, "updatedExisting": False})
|
||||
self.assertIn("upserted", result)
|
||||
self.assertEqual(result["upserted"][0]["_id"], 77)
|
||||
|
||||
def test_delete_basic(self):
|
||||
cmd = {"deletes": [{"q": {}, "limit": 1}]}
|
||||
result = _convert_write_result("delete", cmd, {"n": 1})
|
||||
self.assertEqual(result["ok"], 1)
|
||||
self.assertEqual(result["n"], 1)
|
||||
|
||||
def test_write_error(self):
|
||||
cmd = {"documents": [{"_id": 1}]}
|
||||
gle = {"n": 0, "err": "duplicate key error", "code": 11000}
|
||||
result = _convert_write_result("insert", cmd, gle)
|
||||
self.assertIn("writeErrors", result)
|
||||
self.assertEqual(result["writeErrors"][0]["code"], 11000)
|
||||
|
||||
def test_write_concern_timeout(self):
|
||||
cmd = {"documents": [{"_id": 1}]}
|
||||
gle = {"n": 1, "errmsg": "timeout", "wtimeout": True}
|
||||
result = _convert_write_result("insert", cmd, gle)
|
||||
self.assertIn("writeConcernError", result)
|
||||
self.assertEqual(result["writeConcernError"]["code"], 64)
|
||||
|
||||
def test_write_error_with_err_info(self):
|
||||
# Covers the `if "errInfo" in result:` branch, which test_write_error does not enter.
|
||||
cmd = {"documents": [{"_id": 1}]}
|
||||
gle = {"n": 0, "err": "err", "code": 123, "errInfo": {"detail": "x"}}
|
||||
result = _convert_write_result("insert", cmd, gle)
|
||||
self.assertIn("errInfo", result["writeErrors"][0])
|
||||
|
||||
|
||||
class TestCompress(unittest.TestCase):
|
||||
def test_compressed_message_has_op_compressed_header(self):
|
||||
msg = _compress(2013, b"hello world", _ZLIB_CTX)[1]
|
||||
op_code = struct.unpack("<i", msg[12:16])[0]
|
||||
self.assertEqual(op_code, 2012) # OP_COMPRESSED
|
||||
|
||||
|
||||
class TestOpMsg(unittest.TestCase):
|
||||
def test_uncompressed_op_code(self):
|
||||
msg = _op_msg(0, {"ping": 1}, "testdb", None, _OPTS)[1]
|
||||
op_code = struct.unpack("<i", msg[12:16])[0]
|
||||
self.assertEqual(op_code, 2013) # OP_MSG
|
||||
|
||||
def test_max_doc_size_zero_without_docs(self):
|
||||
max_doc_size = _op_msg(0, {"ping": 1}, "testdb", None, _OPTS)[3]
|
||||
self.assertEqual(max_doc_size, 0)
|
||||
|
||||
def test_max_doc_size_matches_largest_encoded_doc(self):
|
||||
docs = [{"_id": 1, "x": 2}, {"_id": 3, "x": 4}]
|
||||
cmd: dict = {"insert": "col", "documents": docs}
|
||||
max_doc_size = _op_msg(0, cmd, "testdb", None, _OPTS)[3]
|
||||
self.assertEqual(max_doc_size, max(len(encode(d)) for d in docs))
|
||||
|
||||
def test_read_preference_added_for_non_primary(self):
|
||||
cmd: dict = {"find": "col"}
|
||||
_op_msg(0, cmd, "testdb", ReadPreference.SECONDARY, _OPTS)
|
||||
self.assertIn("$readPreference", cmd)
|
||||
|
||||
def test_read_preference_skipped_if_already_present(self):
|
||||
cmd: dict = {"find": "col", "$readPreference": {"mode": "nearest"}}
|
||||
_op_msg(0, cmd, "testdb", ReadPreference.SECONDARY, _OPTS)
|
||||
self.assertEqual(cmd["$readPreference"]["mode"], "nearest")
|
||||
|
||||
def test_with_compression_context(self):
|
||||
msg = _op_msg(0, {"ping": 1}, "testdb", None, _OPTS, _ZLIB_CTX)[1]
|
||||
op_code = struct.unpack("<i", msg[12:16])[0]
|
||||
self.assertEqual(op_code, 2012) # OP_COMPRESSED
|
||||
|
||||
def test_command_with_documents_field_is_restored(self):
|
||||
docs = [{"_id": 1}]
|
||||
cmd: dict = {"insert": "col", "documents": docs}
|
||||
_op_msg(0, cmd, "testdb", None, _OPTS)
|
||||
self.assertIn("documents", cmd)
|
||||
self.assertEqual(cmd["documents"], docs)
|
||||
|
||||
|
||||
class TestLegacyWireOps(unittest.TestCase):
|
||||
"""Tests for pre-OP_MSG wire ops (OP_QUERY and OP_GET_MORE), compressed and uncompressed."""
|
||||
|
||||
def test_query_uncompressed_op_code(self):
|
||||
msg = _query_uncompressed(0, "db.col", 0, 0, {}, None, _OPTS)[1]
|
||||
op_code = struct.unpack("<i", msg[12:16])[0]
|
||||
self.assertEqual(op_code, 2004) # OP_QUERY
|
||||
|
||||
def test_query_compressed_op_code(self):
|
||||
msg = _query_compressed(0, "db.col", 0, 0, {}, None, _OPTS, _ZLIB_CTX)[1]
|
||||
op_code = struct.unpack("<i", msg[12:16])[0]
|
||||
self.assertEqual(op_code, 2012) # OP_COMPRESSED
|
||||
|
||||
def test_get_more_uncompressed_op_code(self):
|
||||
msg = _get_more_uncompressed("db.col", 0, 0)[1]
|
||||
op_code = struct.unpack("<i", msg[12:16])[0]
|
||||
self.assertEqual(op_code, 2005) # OP_GET_MORE
|
||||
|
||||
def test_get_more_compressed_op_code(self):
|
||||
msg = _get_more_compressed("db.col", 0, 0, _ZLIB_CTX)[1]
|
||||
op_code = struct.unpack("<i", msg[12:16])[0]
|
||||
self.assertEqual(op_code, 2012) # OP_COMPRESSED
|
||||
|
||||
|
||||
class TestRaiseDocumentTooLarge(unittest.TestCase):
|
||||
def test_insert_includes_sizes(self):
|
||||
with self.assertRaises(DocumentTooLarge) as ctx:
|
||||
_raise_document_too_large("insert", 2_000_000, 1_000_000)
|
||||
msg = str(ctx.exception)
|
||||
self.assertIn("2000000", msg)
|
||||
self.assertIn("1000000", msg)
|
||||
|
||||
def test_update_generic_message(self):
|
||||
with self.assertRaises(DocumentTooLarge) as ctx:
|
||||
_raise_document_too_large("update", 2_000_000, 1_000_000)
|
||||
self.assertIn("update", str(ctx.exception))
|
||||
|
||||
|
||||
class TestGenFindCommand(unittest.TestCase):
|
||||
def test_basic(self):
|
||||
cmd = _gen_find_command("col", {}, None, 0, 0, None, None, ReadConcern())
|
||||
self.assertEqual(cmd["find"], "col")
|
||||
self.assertEqual(cmd["filter"], {})
|
||||
|
||||
def test_with_projection(self):
|
||||
cmd = _gen_find_command("col", {}, {"x": 1}, 0, 0, None, None, ReadConcern())
|
||||
self.assertEqual(cmd["projection"], {"x": 1})
|
||||
|
||||
def test_with_skip(self):
|
||||
cmd = _gen_find_command("col", {}, None, 5, 0, None, None, ReadConcern())
|
||||
self.assertEqual(cmd["skip"], 5)
|
||||
|
||||
def test_with_positive_limit(self):
|
||||
cmd = _gen_find_command("col", {}, None, 0, 10, None, None, ReadConcern())
|
||||
self.assertEqual(cmd["limit"], 10)
|
||||
self.assertNotIn("singleBatch", cmd)
|
||||
|
||||
def test_with_negative_limit_sets_single_batch(self):
|
||||
cmd = _gen_find_command("col", {}, None, 0, -5, None, None, ReadConcern())
|
||||
self.assertEqual(cmd["limit"], 5)
|
||||
self.assertTrue(cmd["singleBatch"])
|
||||
|
||||
def test_batch_size_adjusted_when_equal_to_limit(self):
|
||||
cmd = _gen_find_command("col", {}, None, 0, 10, 10, None, ReadConcern())
|
||||
self.assertEqual(cmd["batchSize"], 11)
|
||||
|
||||
def test_batch_size_not_adjusted_when_different(self):
|
||||
# Covers the False branch of `if limit == batch_size:` — distinct from the True branch above.
|
||||
cmd = _gen_find_command("col", {}, None, 0, 10, 5, None, ReadConcern())
|
||||
self.assertEqual(cmd["batchSize"], 5)
|
||||
|
||||
def test_read_concern_level_included(self):
|
||||
cmd = _gen_find_command("col", {}, None, 0, 0, None, None, ReadConcern("majority"))
|
||||
self.assertEqual(cmd["readConcern"], {"level": "majority"})
|
||||
|
||||
def test_query_with_dollar_query_modifier(self):
|
||||
spec = {"$query": {"x": 1}, "$orderby": {"x": 1}}
|
||||
cmd = _gen_find_command("col", spec, None, 0, 0, None, None, ReadConcern())
|
||||
self.assertIn("sort", cmd)
|
||||
self.assertNotIn("$orderby", cmd)
|
||||
self.assertNotIn("$query", cmd)
|
||||
|
||||
def test_allow_disk_use(self):
|
||||
cmd = _gen_find_command(
|
||||
"col", {}, None, 0, 0, None, None, ReadConcern(), allow_disk_use=True
|
||||
)
|
||||
self.assertTrue(cmd["allowDiskUse"])
|
||||
|
||||
def test_collation(self):
|
||||
cmd = _gen_find_command(
|
||||
"col", {}, None, 0, 0, None, None, ReadConcern(), collation={"locale": "en"}
|
||||
)
|
||||
self.assertEqual(cmd["collation"]["locale"], "en")
|
||||
|
||||
def test_options_tailable(self):
|
||||
cmd = _gen_find_command("col", {}, None, 0, 0, None, 2, ReadConcern())
|
||||
self.assertTrue(cmd.get("tailable"))
|
||||
|
||||
def test_dollar_query_with_explain_removed(self):
|
||||
spec = {"$query": {"x": 1}, "$explain": 1}
|
||||
cmd = _gen_find_command("col", spec, None, 0, 0, None, None, ReadConcern())
|
||||
self.assertNotIn("$explain", cmd)
|
||||
|
||||
def test_dollar_query_with_read_preference_removed(self):
|
||||
# Covers the separate `if "$readPreference" in cmd:` branch — not entered by test_dollar_query_with_explain_removed.
|
||||
spec = {"$query": {"x": 1}, "$readPreference": {"mode": "secondary"}}
|
||||
cmd = _gen_find_command("col", spec, None, 0, 0, None, None, ReadConcern())
|
||||
self.assertNotIn("$readPreference", cmd)
|
||||
|
||||
|
||||
class TestGenGetMoreCommand(unittest.TestCase):
|
||||
def _make_conn(self, max_wire_version=9):
|
||||
conn = MagicMock()
|
||||
conn.max_wire_version = max_wire_version
|
||||
return conn
|
||||
|
||||
def test_basic(self):
|
||||
cmd = _gen_get_more_command(12345, "col", None, None, None, self._make_conn())
|
||||
self.assertEqual(cmd["getMore"], 12345)
|
||||
self.assertEqual(cmd["collection"], "col")
|
||||
|
||||
def test_with_batch_size(self):
|
||||
cmd = _gen_get_more_command(1, "col", 100, None, None, self._make_conn())
|
||||
self.assertEqual(cmd["batchSize"], 100)
|
||||
|
||||
def test_with_max_await_time_ms(self):
|
||||
cmd = _gen_get_more_command(1, "col", None, 500, None, self._make_conn())
|
||||
self.assertEqual(cmd["maxTimeMS"], 500)
|
||||
|
||||
def test_comment_added_on_high_wire_version(self):
|
||||
cmd = _gen_get_more_command(1, "col", None, None, "my comment", self._make_conn(9))
|
||||
self.assertEqual(cmd["comment"], "my comment")
|
||||
|
||||
def test_comment_not_added_on_low_wire_version(self):
|
||||
cmd = _gen_get_more_command(1, "col", None, None, "my comment", self._make_conn(8))
|
||||
self.assertNotIn("comment", cmd)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue
Block a user