From 82a8a60af64a6b3fedcfcf7e7e002e8adb2ecc38 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 29 Jan 2025 14:05:59 -0500 Subject: [PATCH 1/4] PYTHON-5077 - Convert test.test_data_lake to async (#2091) --- test/asynchronous/test_data_lake.py | 111 ++++++++++++++++++++++++++++ test/test_data_lake.py | 22 +++--- tools/synchro.py | 1 + 3 files changed, 124 insertions(+), 10 deletions(-) create mode 100644 test/asynchronous/test_data_lake.py diff --git a/test/asynchronous/test_data_lake.py b/test/asynchronous/test_data_lake.py new file mode 100644 index 000000000..0b259fb0d --- /dev/null +++ b/test/asynchronous/test_data_lake.py @@ -0,0 +1,111 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test Atlas Data Lake.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +import pytest + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, AsyncUnitTest, async_client_context, unittest +from test.asynchronous.unified_format import generate_test_classes +from test.utils import ( + OvertCommandListener, +) + +from pymongo.asynchronous.helpers import anext + +_IS_SYNC = False + +pytestmark = pytest.mark.data_lake + + +class TestDataLakeMustConnect(AsyncUnitTest): + async def test_connected_to_data_lake(self): + data_lake = os.environ.get("TEST_DATA_LAKE") + if not data_lake: + self.skipTest("TEST_DATA_LAKE is not set") + + self.assertTrue( + async_client_context.is_data_lake and async_client_context.connected, + "client context must be connected to data lake when DATA_LAKE is set. Failed attempts:\n{}".format( + async_client_context.connection_attempt_info() + ), + ) + + +class TestDataLakeProse(AsyncIntegrationTest): + # Default test database and collection names. + TEST_DB = "test" + TEST_COLLECTION = "driverdata" + + @async_client_context.require_data_lake + async def asyncSetUp(self): + await super().asyncSetUp() + + # Test killCursors + async def test_1(self): + listener = OvertCommandListener() + client = await self.async_rs_or_single_client(event_listeners=[listener]) + cursor = client[self.TEST_DB][self.TEST_COLLECTION].find({}, batch_size=2) + await anext(cursor) + + # find command assertions + find_cmd = listener.succeeded_events[-1] + self.assertEqual(find_cmd.command_name, "find") + cursor_id = find_cmd.reply["cursor"]["id"] + cursor_ns = find_cmd.reply["cursor"]["ns"] + + # killCursors command assertions + await cursor.close() + started = listener.started_events[-1] + self.assertEqual(started.command_name, "killCursors") + succeeded = listener.succeeded_events[-1] + self.assertEqual(succeeded.command_name, "killCursors") + + self.assertIn(cursor_id, started.command["cursors"]) + target_ns = ".".join([started.command["$db"], started.command["killCursors"]]) + self.assertEqual(cursor_ns, target_ns) + + self.assertIn(cursor_id, succeeded.reply["cursorsKilled"]) + + # Test no auth + async def test_2(self): + client = await self.async_rs_client_noauth() + await client.admin.command("ping") + + # Test with auth + async def test_3(self): + for mechanism in ["SCRAM-SHA-1", "SCRAM-SHA-256"]: + client = await self.async_rs_or_single_client(authMechanism=mechanism) + await client[self.TEST_DB][self.TEST_COLLECTION].find_one() + + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = Path(__file__).parent / "data_lake/unified" +else: + TEST_PATH = Path(__file__).parent.parent / "data_lake/unified" + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_data_lake.py b/test/test_data_lake.py index a374db550..797ef8500 100644 --- a/test/test_data_lake.py +++ b/test/test_data_lake.py @@ -23,20 +23,20 @@ import pytest sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import IntegrationTest, UnitTest, client_context, unittest from test.unified_format import generate_test_classes from test.utils import ( OvertCommandListener, ) +from pymongo.synchronous.helpers import next + +_IS_SYNC = True + pytestmark = pytest.mark.data_lake -# Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data_lake") - - -class TestDataLakeMustConnect(unittest.TestCase): +class TestDataLakeMustConnect(UnitTest): def test_connected_to_data_lake(self): data_lake = os.environ.get("TEST_DATA_LAKE") if not data_lake: @@ -55,10 +55,9 @@ class TestDataLakeProse(IntegrationTest): TEST_DB = "test" TEST_COLLECTION = "driverdata" - @classmethod @client_context.require_data_lake - def setUpClass(cls): - super().setUpClass() + def setUp(self): + super().setUp() # Test killCursors def test_1(self): @@ -99,7 +98,10 @@ class TestDataLakeProse(IntegrationTest): # Location of JSON test specifications. -TEST_PATH = Path(__file__).parent / "data_lake/unified" +if _IS_SYNC: + TEST_PATH = Path(__file__).parent / "data_lake/unified" +else: + TEST_PATH = Path(__file__).parent.parent / "data_lake/unified" # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/tools/synchro.py b/tools/synchro.py index dbcbbd135..cbac5752c 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -203,6 +203,7 @@ converted_tests = [ "test_crud_unified.py", "test_cursor.py", "test_database.py", + "test_data_lake.py", "test_encryption.py", "test_grid_file.py", "test_logger.py", From cbc3af704f022622def68e1a7752b12e671d6df9 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 29 Jan 2025 14:06:09 -0500 Subject: [PATCH 2/4] PYTHON-5076 - Convert test.test_custom_types to async (#2090) --- test/asynchronous/test_custom_types.py | 989 +++++++++++++++++++++++++ test/test_custom_types.py | 13 +- tools/synchro.py | 1 + 3 files changed, 998 insertions(+), 5 deletions(-) create mode 100644 test/asynchronous/test_custom_types.py diff --git a/test/asynchronous/test_custom_types.py b/test/asynchronous/test_custom_types.py new file mode 100644 index 000000000..0f9d737af --- /dev/null +++ b/test/asynchronous/test_custom_types.py @@ -0,0 +1,989 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test support for callbacks to encode/decode custom types.""" +from __future__ import annotations + +import datetime +import sys +import tempfile +from collections import OrderedDict +from decimal import Decimal +from random import random +from typing import Any, Tuple, Type, no_type_check + +from gridfs.asynchronous.grid_file import AsyncGridIn, AsyncGridOut + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest + +from bson import ( + _BUILT_IN_TYPES, + RE_TYPE, + Decimal128, + _bson_to_dict, + _dict_to_bson, + decode, + decode_all, + decode_file_iter, + decode_iter, + encode, +) +from bson.codec_options import ( + CodecOptions, + TypeCodec, + TypeDecoder, + TypeEncoder, + TypeRegistry, +) +from bson.errors import InvalidDocument +from bson.int64 import Int64 +from bson.raw_bson import RawBSONDocument +from pymongo.asynchronous.collection import ReturnDocument +from pymongo.asynchronous.helpers import anext +from pymongo.errors import DuplicateKeyError +from pymongo.message import _CursorAddress + +_IS_SYNC = False + + +class DecimalEncoder(TypeEncoder): + @property + def python_type(self): + return Decimal + + def transform_python(self, value): + return Decimal128(value) + + +class DecimalDecoder(TypeDecoder): + @property + def bson_type(self): + return Decimal128 + + def transform_bson(self, value): + return value.to_decimal() + + +class DecimalCodec(DecimalDecoder, DecimalEncoder): + pass + + +DECIMAL_CODECOPTS = CodecOptions(type_registry=TypeRegistry([DecimalCodec()])) + + +class UndecipherableInt64Type: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + if isinstance(other, type(self)): + return self.value == other.value + # Does not compare equal to integers. + return False + + +class UndecipherableIntDecoder(TypeDecoder): + bson_type = Int64 + + def transform_bson(self, value): + return UndecipherableInt64Type(value) + + +class UndecipherableIntEncoder(TypeEncoder): + python_type = UndecipherableInt64Type + + def transform_python(self, value): + return Int64(value.value) + + +UNINT_DECODER_CODECOPTS = CodecOptions( + type_registry=TypeRegistry( + [ + UndecipherableIntDecoder(), + ] + ) +) + + +UNINT_CODECOPTS = CodecOptions( + type_registry=TypeRegistry([UndecipherableIntDecoder(), UndecipherableIntEncoder()]) +) + + +class UppercaseTextDecoder(TypeDecoder): + bson_type = str + + def transform_bson(self, value): + return value.upper() + + +UPPERSTR_DECODER_CODECOPTS = CodecOptions( + type_registry=TypeRegistry( + [ + UppercaseTextDecoder(), + ] + ) +) + + +def type_obfuscating_decoder_factory(rt_type): + class ResumeTokenToNanDecoder(TypeDecoder): + bson_type = rt_type + + def transform_bson(self, value): + return "NaN" + + return ResumeTokenToNanDecoder + + +class CustomBSONTypeTests: + @no_type_check + def roundtrip(self, doc): + bsonbytes = encode(doc, codec_options=self.codecopts) + rt_document = decode(bsonbytes, codec_options=self.codecopts) + self.assertEqual(doc, rt_document) + + def test_encode_decode_roundtrip(self): + self.roundtrip({"average": Decimal("56.47")}) + self.roundtrip({"average": {"b": Decimal("56.47")}}) + self.roundtrip({"average": [Decimal("56.47")]}) + self.roundtrip({"average": [[Decimal("56.47")]]}) + self.roundtrip({"average": [{"b": Decimal("56.47")}]}) + + @no_type_check + def test_decode_all(self): + documents = [] + for dec in range(3): + documents.append({"average": Decimal(f"56.4{dec}")}) + + bsonstream = b"" + for doc in documents: + bsonstream += encode(doc, codec_options=self.codecopts) + + self.assertEqual(decode_all(bsonstream, self.codecopts), documents) + + @no_type_check + def test__bson_to_dict(self): + document = {"average": Decimal("56.47")} + rawbytes = encode(document, codec_options=self.codecopts) + decoded_document = _bson_to_dict(rawbytes, self.codecopts) + self.assertEqual(document, decoded_document) + + @no_type_check + def test__dict_to_bson(self): + document = {"average": Decimal("56.47")} + rawbytes = encode(document, codec_options=self.codecopts) + encoded_document = _dict_to_bson(document, False, self.codecopts) + self.assertEqual(encoded_document, rawbytes) + + def _generate_multidocument_bson_stream(self): + inp_num = [str(random() * 100)[:4] for _ in range(10)] + docs = [{"n": Decimal128(dec)} for dec in inp_num] + edocs = [{"n": Decimal(dec)} for dec in inp_num] + bsonstream = b"" + for doc in docs: + bsonstream += encode(doc) + return edocs, bsonstream + + @no_type_check + def test_decode_iter(self): + expected, bson_data = self._generate_multidocument_bson_stream() + for expected_doc, decoded_doc in zip(expected, decode_iter(bson_data, self.codecopts)): + self.assertEqual(expected_doc, decoded_doc) + + @no_type_check + def test_decode_file_iter(self): + expected, bson_data = self._generate_multidocument_bson_stream() + fileobj = tempfile.TemporaryFile() + fileobj.write(bson_data) + fileobj.seek(0) + + for expected_doc, decoded_doc in zip(expected, decode_file_iter(fileobj, self.codecopts)): + self.assertEqual(expected_doc, decoded_doc) + + fileobj.close() + + +class TestCustomPythonBSONTypeToBSONMonolithicCodec(CustomBSONTypeTests, unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.codecopts = DECIMAL_CODECOPTS + + +class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests, unittest.TestCase): + @classmethod + def setUpClass(cls): + codec_options = CodecOptions( + type_registry=TypeRegistry((DecimalEncoder(), DecimalDecoder())) + ) + cls.codecopts = codec_options + + +class TestBSONFallbackEncoder(unittest.TestCase): + def _get_codec_options(self, fallback_encoder): + type_registry = TypeRegistry(fallback_encoder=fallback_encoder) + return CodecOptions(type_registry=type_registry) + + def test_simple(self): + codecopts = self._get_codec_options(lambda x: Decimal128(x)) + document = {"average": Decimal("56.47")} + bsonbytes = encode(document, codec_options=codecopts) + + exp_document = {"average": Decimal128("56.47")} + exp_bsonbytes = encode(exp_document) + self.assertEqual(bsonbytes, exp_bsonbytes) + + def test_erroring_fallback_encoder(self): + codecopts = self._get_codec_options(lambda _: 1 / 0) + + # fallback converter should not be invoked when encoding known types. + encode( + {"a": 1, "b": Decimal128("1.01"), "c": {"arr": ["abc", 3.678]}}, codec_options=codecopts + ) + + # expect an error when encoding a custom type. + document = {"average": Decimal("56.47")} + with self.assertRaises(ZeroDivisionError): + encode(document, codec_options=codecopts) + + def test_noop_fallback_encoder(self): + codecopts = self._get_codec_options(lambda x: x) + document = {"average": Decimal("56.47")} + with self.assertRaises(InvalidDocument): + encode(document, codec_options=codecopts) + + def test_type_unencodable_by_fallback_encoder(self): + def fallback_encoder(value): + try: + return Decimal128(value) + except: + raise TypeError("cannot encode type %s" % (type(value))) + + codecopts = self._get_codec_options(fallback_encoder) + document = {"average": Decimal} + with self.assertRaises(TypeError): + encode(document, codec_options=codecopts) + + def test_call_only_once_for_not_handled_big_integers(self): + called_with = [] + + def fallback_encoder(value): + called_with.append(value) + return value + + codecopts = self._get_codec_options(fallback_encoder) + document = {"a": {"b": {"c": 2 << 65}}} + + msg = "MongoDB can only handle up to 8-byte ints" + with self.assertRaises(OverflowError, msg=msg): + encode(document, codec_options=codecopts) + + self.assertEqual(called_with, [2 << 65]) + + +class TestBSONTypeEnDeCodecs(unittest.TestCase): + def test_instantiation(self): + msg = "Can't instantiate abstract class" + + def run_test(base, attrs, fail): + codec = type("testcodec", (base,), attrs) + if fail: + with self.assertRaisesRegex(TypeError, msg): + codec() + else: + codec() + + class MyType: + pass + + run_test( + TypeEncoder, + { + "python_type": MyType, + }, + fail=True, + ) + run_test(TypeEncoder, {"transform_python": lambda s, x: x}, fail=True) + run_test( + TypeEncoder, {"transform_python": lambda s, x: x, "python_type": MyType}, fail=False + ) + + run_test( + TypeDecoder, + { + "bson_type": Decimal128, + }, + fail=True, + ) + run_test(TypeDecoder, {"transform_bson": lambda s, x: x}, fail=True) + run_test( + TypeDecoder, {"transform_bson": lambda s, x: x, "bson_type": Decimal128}, fail=False + ) + + run_test(TypeCodec, {"bson_type": Decimal128, "python_type": MyType}, fail=True) + run_test( + TypeCodec, + {"transform_bson": lambda s, x: x, "transform_python": lambda s, x: x}, + fail=True, + ) + run_test( + TypeCodec, + { + "python_type": MyType, + "transform_python": lambda s, x: x, + "transform_bson": lambda s, x: x, + "bson_type": Decimal128, + }, + fail=False, + ) + + def test_type_checks(self): + self.assertTrue(issubclass(TypeCodec, TypeEncoder)) + self.assertTrue(issubclass(TypeCodec, TypeDecoder)) + self.assertFalse(issubclass(TypeDecoder, TypeEncoder)) + self.assertFalse(issubclass(TypeEncoder, TypeDecoder)) + + +class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase): + TypeA: Any + TypeB: Any + fallback_encoder_A2B: Any + fallback_encoder_A2BSON: Any + B2BSON: Type[TypeEncoder] + B2A: Type[TypeEncoder] + A2B: Type[TypeEncoder] + + @classmethod + def setUpClass(cls): + class TypeA: + def __init__(self, x): + self.value = x + + class TypeB: + def __init__(self, x): + self.value = x + + # transforms A, and only A into B + def fallback_encoder_A2B(value): + assert isinstance(value, TypeA) + return TypeB(value.value) + + # transforms A, and only A into something encodable + def fallback_encoder_A2BSON(value): + assert isinstance(value, TypeA) + return value.value + + # transforms B into something encodable + class B2BSON(TypeEncoder): + python_type = TypeB + + def transform_python(self, value): + return value.value + + # transforms A into B + # technically, this isn't a proper type encoder as the output is not + # BSON-encodable. + class A2B(TypeEncoder): + python_type = TypeA + + def transform_python(self, value): + return TypeB(value.value) + + # transforms B into A + # technically, this isn't a proper type encoder as the output is not + # BSON-encodable. + class B2A(TypeEncoder): + python_type = TypeB + + def transform_python(self, value): + return TypeA(value.value) + + cls.TypeA = TypeA + cls.TypeB = TypeB + cls.fallback_encoder_A2B = staticmethod(fallback_encoder_A2B) + cls.fallback_encoder_A2BSON = staticmethod(fallback_encoder_A2BSON) + cls.B2BSON = B2BSON + cls.B2A = B2A + cls.A2B = A2B + + def test_encode_fallback_then_custom(self): + codecopts = CodecOptions( + type_registry=TypeRegistry([self.B2BSON()], fallback_encoder=self.fallback_encoder_A2B) + ) + testdoc = {"x": self.TypeA(123)} + expected_bytes = encode({"x": 123}) + + self.assertEqual(encode(testdoc, codec_options=codecopts), expected_bytes) + + def test_encode_custom_then_fallback(self): + codecopts = CodecOptions( + type_registry=TypeRegistry([self.B2A()], fallback_encoder=self.fallback_encoder_A2BSON) + ) + testdoc = {"x": self.TypeB(123)} + expected_bytes = encode({"x": 123}) + + self.assertEqual(encode(testdoc, codec_options=codecopts), expected_bytes) + + def test_chaining_encoders_fails(self): + codecopts = CodecOptions(type_registry=TypeRegistry([self.A2B(), self.B2BSON()])) + + with self.assertRaises(InvalidDocument): + encode({"x": self.TypeA(123)}, codec_options=codecopts) + + def test_infinite_loop_exceeds_max_recursion_depth(self): + codecopts = CodecOptions( + type_registry=TypeRegistry([self.B2A()], fallback_encoder=self.fallback_encoder_A2B) + ) + + # Raises max recursion depth exceeded error + with self.assertRaises(RuntimeError): + encode({"x": self.TypeA(100)}, codec_options=codecopts) + + +class TestTypeRegistry(unittest.TestCase): + types: Tuple[object, object] + codecs: Tuple[Type[TypeCodec], Type[TypeCodec]] + fallback_encoder: Any + + @classmethod + def setUpClass(cls): + class MyIntType: + def __init__(self, x): + assert isinstance(x, int) + self.x = x + + class MyStrType: + def __init__(self, x): + assert isinstance(x, str) + self.x = x + + class MyIntCodec(TypeCodec): + @property + def python_type(self): + return MyIntType + + @property + def bson_type(self): + return int + + def transform_python(self, value): + return value.x + + def transform_bson(self, value): + return MyIntType(value) + + class MyStrCodec(TypeCodec): + @property + def python_type(self): + return MyStrType + + @property + def bson_type(self): + return str + + def transform_python(self, value): + return value.x + + def transform_bson(self, value): + return MyStrType(value) + + def fallback_encoder(value): + return value + + cls.types = (MyIntType, MyStrType) + cls.codecs = (MyIntCodec, MyStrCodec) + cls.fallback_encoder = fallback_encoder + + def test_simple(self): + codec_instances = [codec() for codec in self.codecs] + + def assert_proper_initialization(type_registry, codec_instances): + self.assertEqual( + type_registry._encoder_map, + { + self.types[0]: codec_instances[0].transform_python, + self.types[1]: codec_instances[1].transform_python, + }, + ) + self.assertEqual( + type_registry._decoder_map, + {int: codec_instances[0].transform_bson, str: codec_instances[1].transform_bson}, + ) + self.assertEqual(type_registry._fallback_encoder, self.fallback_encoder) + + type_registry = TypeRegistry(codec_instances, self.fallback_encoder) + assert_proper_initialization(type_registry, codec_instances) + + type_registry = TypeRegistry( + fallback_encoder=self.fallback_encoder, type_codecs=codec_instances + ) + assert_proper_initialization(type_registry, codec_instances) + + # Ensure codec list held by the type registry doesn't change if we + # mutate the initial list. + codec_instances_copy = list(codec_instances) + codec_instances.pop(0) + self.assertListEqual(type_registry._TypeRegistry__type_codecs, codec_instances_copy) + + def test_simple_separate_codecs(self): + class MyIntEncoder(TypeEncoder): + python_type = self.types[0] + + def transform_python(self, value): + return value.x + + class MyIntDecoder(TypeDecoder): + bson_type = int + + def transform_bson(self, value): + return self.types[0](value) + + codec_instances: list = [MyIntDecoder(), MyIntEncoder()] + type_registry = TypeRegistry(codec_instances) + + self.assertEqual( + type_registry._encoder_map, + {MyIntEncoder.python_type: codec_instances[1].transform_python}, + ) + self.assertEqual( + type_registry._decoder_map, + {MyIntDecoder.bson_type: codec_instances[0].transform_bson}, + ) + + def test_initialize_fail(self): + err_msg = "Expected an instance of TypeEncoder, TypeDecoder, or TypeCodec, got .* instead" + with self.assertRaisesRegex(TypeError, err_msg): + TypeRegistry(self.codecs) # type: ignore[arg-type] + + with self.assertRaisesRegex(TypeError, err_msg): + TypeRegistry([type("AnyType", (object,), {})()]) + + err_msg = f"fallback_encoder {True!r} is not a callable" + with self.assertRaisesRegex(TypeError, err_msg): + TypeRegistry([], True) # type: ignore[arg-type] + + err_msg = "fallback_encoder {!r} is not a callable".format("hello") + with self.assertRaisesRegex(TypeError, err_msg): + TypeRegistry(fallback_encoder="hello") # type: ignore[arg-type] + + def test_type_registry_repr(self): + codec_instances = [codec() for codec in self.codecs] + type_registry = TypeRegistry(codec_instances) + r = f"TypeRegistry(type_codecs={codec_instances!r}, fallback_encoder={None!r})" + self.assertEqual(r, repr(type_registry)) + + def test_type_registry_eq(self): + codec_instances = [codec() for codec in self.codecs] + self.assertEqual(TypeRegistry(codec_instances), TypeRegistry(codec_instances)) + + codec_instances_2 = [codec() for codec in self.codecs] + self.assertNotEqual(TypeRegistry(codec_instances), TypeRegistry(codec_instances_2)) + + def test_builtin_types_override_fails(self): + def run_test(base, attrs): + msg = ( + r"TypeEncoders cannot change how built-in types " + r"are encoded \(encoder .* transforms type .*\)" + ) + for pytype in _BUILT_IN_TYPES: + attrs.update({"python_type": pytype, "transform_python": lambda x: x}) + codec = type("testcodec", (base,), attrs) + codec_instance = codec() + with self.assertRaisesRegex(TypeError, msg): + TypeRegistry( + [ + codec_instance, + ] + ) + + # Test only some subtypes as not all can be subclassed. + if pytype in [ + bool, + type(None), + RE_TYPE, + ]: + continue + + class MyType(pytype): # type: ignore + pass + + attrs.update({"python_type": MyType, "transform_python": lambda x: x}) + codec = type("testcodec", (base,), attrs) + codec_instance = codec() + with self.assertRaisesRegex(TypeError, msg): + TypeRegistry( + [ + codec_instance, + ] + ) + + run_test(TypeEncoder, {}) + run_test(TypeCodec, {"bson_type": Decimal128, "transform_bson": lambda x: x}) + + +class TestCollectionWCustomType(AsyncIntegrationTest): + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.test.drop() + + async def asyncTearDown(self): + await self.db.test.drop() + + async def test_overflow_int_w_custom_decoder(self): + type_registry = TypeRegistry(fallback_encoder=lambda val: str(val)) + codec_options = CodecOptions(type_registry=type_registry) + collection = self.db.get_collection("test", codec_options=codec_options) + + await collection.insert_one({"_id": 1, "data": 2**520}) + ret = await collection.find_one() + self.assertEqual(ret["data"], str(2**520)) + + async def test_command_errors_w_custom_type_decoder(self): + db = self.db + test_doc = {"_id": 1, "data": "a"} + test = db.get_collection("test", codec_options=UNINT_DECODER_CODECOPTS) + + result = await test.insert_one(test_doc) + self.assertEqual(result.inserted_id, test_doc["_id"]) + with self.assertRaises(DuplicateKeyError): + await test.insert_one(test_doc) + + async def test_find_w_custom_type_decoder(self): + db = self.db + input_docs = [{"x": Int64(k)} for k in [1, 2, 3]] + for doc in input_docs: + await db.test.insert_one(doc) + + test = db.get_collection("test", codec_options=UNINT_DECODER_CODECOPTS) + async for doc in test.find({}, batch_size=1): + self.assertIsInstance(doc["x"], UndecipherableInt64Type) + + async def test_find_w_custom_type_decoder_and_document_class(self): + async def run_test(doc_cls): + db = self.db + input_docs = [{"x": Int64(k)} for k in [1, 2, 3]] + for doc in input_docs: + await db.test.insert_one(doc) + + test = db.get_collection( + "test", + codec_options=CodecOptions( + type_registry=TypeRegistry([UndecipherableIntDecoder()]), document_class=doc_cls + ), + ) + async for doc in test.find({}, batch_size=1): + self.assertIsInstance(doc, doc_cls) + self.assertIsInstance(doc["x"], UndecipherableInt64Type) + + for doc_cls in [RawBSONDocument, OrderedDict]: + await run_test(doc_cls) + + async def test_aggregate_w_custom_type_decoder(self): + db = self.db + await db.test.insert_many( + [ + {"status": "in progress", "qty": Int64(1)}, + {"status": "complete", "qty": Int64(10)}, + {"status": "in progress", "qty": Int64(1)}, + {"status": "complete", "qty": Int64(10)}, + {"status": "in progress", "qty": Int64(1)}, + ] + ) + test = db.get_collection("test", codec_options=UNINT_DECODER_CODECOPTS) + + pipeline: list = [ + {"$match": {"status": "complete"}}, + {"$group": {"_id": "$status", "total_qty": {"$sum": "$qty"}}}, + ] + result = await test.aggregate(pipeline) + + res = (await result.to_list())[0] + self.assertEqual(res["_id"], "complete") + self.assertIsInstance(res["total_qty"], UndecipherableInt64Type) + self.assertEqual(res["total_qty"].value, 20) + + async def test_distinct_w_custom_type(self): + await self.db.drop_collection("test") + + test = self.db.get_collection("test", codec_options=UNINT_CODECOPTS) + values = [ + UndecipherableInt64Type(1), + UndecipherableInt64Type(2), + UndecipherableInt64Type(3), + {"b": UndecipherableInt64Type(3)}, + ] + await test.insert_many({"a": val} for val in values) + + self.assertEqual(values, await test.distinct("a")) + + async def test_find_one_and__w_custom_type_decoder(self): + db = self.db + c = db.get_collection("test", codec_options=UNINT_DECODER_CODECOPTS) + await c.insert_one({"_id": 1, "x": Int64(1)}) + + doc = await c.find_one_and_update( + {"_id": 1}, {"$inc": {"x": 1}}, return_document=ReturnDocument.AFTER + ) + self.assertEqual(doc["_id"], 1) + self.assertIsInstance(doc["x"], UndecipherableInt64Type) + self.assertEqual(doc["x"].value, 2) + + doc = await c.find_one_and_replace( + {"_id": 1}, {"x": Int64(3), "y": True}, return_document=ReturnDocument.AFTER + ) + self.assertEqual(doc["_id"], 1) + self.assertIsInstance(doc["x"], UndecipherableInt64Type) + self.assertEqual(doc["x"].value, 3) + self.assertEqual(doc["y"], True) + + doc = await c.find_one_and_delete({"y": True}) + self.assertEqual(doc["_id"], 1) + self.assertIsInstance(doc["x"], UndecipherableInt64Type) + self.assertEqual(doc["x"].value, 3) + self.assertIsNone(await c.find_one()) + + +class TestGridFileCustomType(AsyncIntegrationTest): + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.drop_collection("fs.files") + await self.db.drop_collection("fs.chunks") + + async def test_grid_out_custom_opts(self): + db = self.db.with_options(codec_options=UPPERSTR_DECODER_CODECOPTS) + one = AsyncGridIn( + db.fs, + _id=5, + filename="my_file", + chunkSize=1000, + metadata={"foo": "red", "bar": "blue"}, + bar=3, + baz="hello", + ) + await one.write(b"hello world") + await one.close() + + two = AsyncGridOut(db.fs, 5) + await two.open() + + self.assertEqual("my_file", two.name) + self.assertEqual("my_file", two.filename) + self.assertEqual(5, two._id) + self.assertEqual(11, two.length) + self.assertEqual(1000, two.chunk_size) + self.assertTrue(isinstance(two.upload_date, datetime.datetime)) + self.assertEqual({"foo": "red", "bar": "blue"}, two.metadata) + self.assertEqual(3, two.bar) + + for attr in [ + "_id", + "name", + "content_type", + "length", + "chunk_size", + "upload_date", + "aliases", + "metadata", + "md5", + ]: + self.assertRaises(AttributeError, setattr, two, attr, 5) + + +class ChangeStreamsWCustomTypesTestMixin: + @no_type_check + async def change_stream(self, *args, **kwargs): + stream = await self.watched_target.watch(*args, max_await_time_ms=1, **kwargs) + self.addAsyncCleanup(stream.close) + return stream + + @no_type_check + async def insert_and_check(self, change_stream, insert_doc, expected_doc): + await self.input_target.insert_one(insert_doc) + change = await anext(change_stream) + self.assertEqual(change["fullDocument"], expected_doc) + + @no_type_check + async def kill_change_stream_cursor(self, change_stream): + # Cause a cursor not found error on the next getMore. + cursor = change_stream._cursor + address = _CursorAddress(cursor.address, cursor._ns) + client = self.input_target.database.client + await client._close_cursor_now(cursor.cursor_id, address) + + @no_type_check + async def test_simple(self): + codecopts = CodecOptions( + type_registry=TypeRegistry([UndecipherableIntEncoder(), UppercaseTextDecoder()]) + ) + await self.create_targets(codec_options=codecopts) + + input_docs = [ + {"_id": UndecipherableInt64Type(1), "data": "hello"}, + {"_id": 2, "data": "world"}, + {"_id": UndecipherableInt64Type(3), "data": "!"}, + ] + expected_docs = [ + {"_id": 1, "data": "HELLO"}, + {"_id": 2, "data": "WORLD"}, + {"_id": 3, "data": "!"}, + ] + + change_stream = await self.change_stream() + + await self.insert_and_check(change_stream, input_docs[0], expected_docs[0]) + await self.kill_change_stream_cursor(change_stream) + await self.insert_and_check(change_stream, input_docs[1], expected_docs[1]) + await self.kill_change_stream_cursor(change_stream) + await self.insert_and_check(change_stream, input_docs[2], expected_docs[2]) + + @no_type_check + async def test_custom_type_in_pipeline(self): + codecopts = CodecOptions( + type_registry=TypeRegistry([UndecipherableIntEncoder(), UppercaseTextDecoder()]) + ) + await self.create_targets(codec_options=codecopts) + + input_docs = [ + {"_id": UndecipherableInt64Type(1), "data": "hello"}, + {"_id": 2, "data": "world"}, + {"_id": UndecipherableInt64Type(3), "data": "!"}, + ] + expected_docs = [{"_id": 2, "data": "WORLD"}, {"_id": 3, "data": "!"}] + + # UndecipherableInt64Type should be encoded with the TypeRegistry. + change_stream = await self.change_stream( + [{"$match": {"documentKey._id": {"$gte": UndecipherableInt64Type(2)}}}] + ) + + await self.input_target.insert_one(input_docs[0]) + await self.insert_and_check(change_stream, input_docs[1], expected_docs[0]) + await self.kill_change_stream_cursor(change_stream) + await self.insert_and_check(change_stream, input_docs[2], expected_docs[1]) + + @no_type_check + async def test_break_resume_token(self): + # Get one document from a change stream to determine resumeToken type. + await self.create_targets() + change_stream = await self.change_stream() + await self.input_target.insert_one({"data": "test"}) + change = await anext(change_stream) + resume_token_decoder = type_obfuscating_decoder_factory(type(change["_id"]["_data"])) + + # Custom-decoding the resumeToken type breaks resume tokens. + codecopts = CodecOptions( + type_registry=TypeRegistry([resume_token_decoder(), UndecipherableIntEncoder()]) + ) + + # Re-create targets, change stream and proceed. + await self.create_targets(codec_options=codecopts) + + docs = [{"_id": 1}, {"_id": 2}, {"_id": 3}] + + change_stream = await self.change_stream() + await self.insert_and_check(change_stream, docs[0], docs[0]) + await self.kill_change_stream_cursor(change_stream) + await self.insert_and_check(change_stream, docs[1], docs[1]) + await self.kill_change_stream_cursor(change_stream) + await self.insert_and_check(change_stream, docs[2], docs[2]) + + @no_type_check + async def test_document_class(self): + async def run_test(doc_cls): + codecopts = CodecOptions( + type_registry=TypeRegistry([UppercaseTextDecoder(), UndecipherableIntEncoder()]), + document_class=doc_cls, + ) + + await self.create_targets(codec_options=codecopts) + change_stream = await self.change_stream() + + doc = {"a": UndecipherableInt64Type(101), "b": "xyz"} + await self.input_target.insert_one(doc) + change = await anext(change_stream) + + self.assertIsInstance(change, doc_cls) + self.assertEqual(change["fullDocument"]["a"], 101) + self.assertEqual(change["fullDocument"]["b"], "XYZ") + + for doc_cls in [OrderedDict, RawBSONDocument]: + await run_test(doc_cls) + + +class TestCollectionChangeStreamsWCustomTypes( + AsyncIntegrationTest, ChangeStreamsWCustomTypesTestMixin +): + @async_client_context.require_change_streams + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.test.delete_many({}) + + async def asyncTearDown(self): + await self.input_target.drop() + + async def create_targets(self, *args, **kwargs): + self.watched_target = self.db.get_collection("test", *args, **kwargs) + self.input_target = self.watched_target + # Ensure the collection exists and is empty. + await self.input_target.insert_one({}) + await self.input_target.delete_many({}) + + +class TestDatabaseChangeStreamsWCustomTypes( + AsyncIntegrationTest, ChangeStreamsWCustomTypesTestMixin +): + @async_client_context.require_version_min(4, 0, 0) + @async_client_context.require_change_streams + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.test.delete_many({}) + + async def asyncTearDown(self): + await self.input_target.drop() + await self.client.drop_database(self.watched_target) + + async def create_targets(self, *args, **kwargs): + self.watched_target = self.client.get_database(self.db.name, *args, **kwargs) + self.input_target = self.watched_target.test + # Insert a record to ensure db, coll are created. + await self.input_target.insert_one({"data": "dummy"}) + + +class TestClusterChangeStreamsWCustomTypes( + AsyncIntegrationTest, ChangeStreamsWCustomTypesTestMixin +): + @async_client_context.require_version_min(4, 0, 0) + @async_client_context.require_change_streams + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.test.delete_many({}) + + async def asyncTearDown(self): + await self.input_target.drop() + await self.client.drop_database(self.db) + + async def create_targets(self, *args, **kwargs): + codec_options = kwargs.pop("codec_options", None) + if codec_options: + kwargs["type_registry"] = codec_options.type_registry + kwargs["document_class"] = codec_options.document_class + self.watched_target = await self.async_rs_client(*args, **kwargs) + self.input_target = self.watched_target[self.db.name].test + # Insert a record to ensure db, coll are created. + await self.input_target.insert_one({"data": "dummy"}) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_custom_types.py b/test/test_custom_types.py index 6771ea25f..08e2a46f8 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -23,10 +23,11 @@ from decimal import Decimal from random import random from typing import Any, Tuple, Type, no_type_check +from gridfs.synchronous.grid_file import GridIn, GridOut + sys.path[0:0] = [""] -from test import client_context, unittest -from test.test_client import IntegrationTest +from test import IntegrationTest, client_context, unittest from bson import ( _BUILT_IN_TYPES, @@ -50,10 +51,12 @@ from bson.codec_options import ( from bson.errors import InvalidDocument from bson.int64 import Int64 from bson.raw_bson import RawBSONDocument -from gridfs import GridIn, GridOut from pymongo.errors import DuplicateKeyError from pymongo.message import _CursorAddress from pymongo.synchronous.collection import ReturnDocument +from pymongo.synchronous.helpers import next + +_IS_SYNC = True class DecimalEncoder(TypeEncoder): @@ -707,7 +710,7 @@ class TestCollectionWCustomType(IntegrationTest): ] result = test.aggregate(pipeline) - res = list(result)[0] + res = (result.to_list())[0] self.assertEqual(res["_id"], "complete") self.assertIsInstance(res["total_qty"], UndecipherableInt64Type) self.assertEqual(res["total_qty"].value, 20) @@ -774,6 +777,7 @@ class TestGridFileCustomType(IntegrationTest): one.close() two = GridOut(db.fs, 5) + two.open() self.assertEqual("my_file", two.name) self.assertEqual("my_file", two.filename) @@ -970,7 +974,6 @@ class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustom kwargs["type_registry"] = codec_options.type_registry kwargs["document_class"] = codec_options.document_class self.watched_target = self.rs_client(*args, **kwargs) - self.addCleanup(self.watched_target.close) self.input_target = self.watched_target[self.db.name].test # Insert a record to ensure db, coll are created. self.input_target.insert_one({"data": "dummy"}) diff --git a/tools/synchro.py b/tools/synchro.py index cbac5752c..897e5e801 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -202,6 +202,7 @@ converted_tests = [ "test_create_entities.py", "test_crud_unified.py", "test_cursor.py", + "test_custom_types.py", "test_database.py", "test_data_lake.py", "test_encryption.py", From b4e32a1d8388fe5bf731c0c866b8bb96bbf19870 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 29 Jan 2025 13:27:07 -0600 Subject: [PATCH 3/4] PYTHON-5047 Fix dry run logic in releases again (#2092) --- .github/workflows/release-python.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml index bcf37d1a2..0801d12f5 100644 --- a/.github/workflows/release-python.yml +++ b/.github/workflows/release-python.yml @@ -23,7 +23,7 @@ env: SILK_ASSET_GROUP: mongodb-python-driver EVERGREEN_PROJECT: mongo-python-driver # Constant - DRY_RUN: ${{ inputs.dry_run == 'true' }} + DRY_RUN: ${{ github.event_name == 'workflow_dispatch' && inputs.dry_run || 'true' }} FOLLOWING_VERSION: ${{ inputs.following_version || '' }} VERSION: ${{ inputs.version || '10.10.10.10' }} From 1784e2c4b9c7e5efbed1796e81e37fa49f8845f0 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 29 Jan 2025 15:35:00 -0500 Subject: [PATCH 4/4] PYTHON-5112 - Fix just install (#2095) --- .evergreen/scripts/setup-dev-env.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.evergreen/scripts/setup-dev-env.sh b/.evergreen/scripts/setup-dev-env.sh index ae4b44c62..b56897961 100755 --- a/.evergreen/scripts/setup-dev-env.sh +++ b/.evergreen/scripts/setup-dev-env.sh @@ -30,8 +30,8 @@ if [ ! -d $BIN_DIR ]; then fi export UV_PYTHON=${PYTHON_BINARY} echo "export UV_PYTHON=$UV_PYTHON" >> $HERE/env.sh + echo "Using python $UV_PYTHON" fi -echo "Using python $UV_PYTHON" uv sync --frozen uv run --frozen --with pip pip install -e . echo "Setting up python environment... done."