diff --git a/doc/examples/custom_type.rst b/doc/examples/custom_type.rst index cbb2f8515..acf706deb 100644 --- a/doc/examples/custom_type.rst +++ b/doc/examples/custom_type.rst @@ -135,7 +135,7 @@ Now, we can seamlessly encode and decode instances of .. doctest:: >>> collection.insert_one({"num": Decimal("45.321")}) - + InsertOneResult(ObjectId('...'), acknowledged=True) >>> mydoc = collection.find_one() >>> import pprint >>> pprint.pprint(mydoc) @@ -217,7 +217,7 @@ object, we can seamlessly encode instances of ``DecimalInt``: >>> collection = db.get_collection("test", codec_options=codec_options) >>> collection.drop() >>> collection.insert_one({"num": DecimalInt("45.321")}) - + InsertOneResult(ObjectId('...'), acknowledged=True) >>> mydoc = collection.find_one() >>> pprint.pprint(mydoc) {'_id': ObjectId('...'), 'num': Decimal('45.321')} @@ -311,7 +311,7 @@ We can now seamlessly encode instances of :py:class:`~decimal.Decimal`: .. doctest:: >>> collection.insert_one({"num": Decimal("45.321")}) - + InsertOneResult(ObjectId('...'), acknowledged=True) >>> mydoc = collection.find_one() >>> pprint.pprint(mydoc) {'_id': ObjectId('...'), 'num': Decimal128('45.321')} diff --git a/doc/faq.rst b/doc/faq.rst index 9a1400277..b9b5bc4b9 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -195,7 +195,7 @@ instance of :class:`~bson.objectid.ObjectId`. For example:: >>> my_doc = {'x': 1} >>> collection.insert_one(my_doc) - + InsertOneResult(ObjectId('560db337fba522189f171720'), acknowledged=True) >>> my_doc {'x': 1, '_id': ObjectId('560db337fba522189f171720')} @@ -531,9 +531,9 @@ objects as before: >>> from pymongo import MongoClient >>> client = MongoClient(datetime_conversion=DatetimeConversion.DATETIME_AUTO) >>> client.db.collection.insert_one({"x": datetime(1970, 1, 1)}) - + InsertOneResult(ObjectId('...'), acknowledged=True) >>> client.db.collection.insert_one({"x": DatetimeMS(2**62)}) - + InsertOneResult(ObjectId('...'), acknowledged=True) >>> for x in client.db.collection.find(): ... print(x) ... diff --git a/pymongo/encryption.py b/pymongo/encryption.py index 7a390e50f..cdaf2358d 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -322,6 +322,9 @@ class RewrapManyDataKeyResult: """ return self._bulk_write_result + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._bulk_write_result!r})" + class _Encrypter: """Encrypts and decrypts MongoDB commands. diff --git a/pymongo/results.py b/pymongo/results.py index 266101921..20c6023cd 100644 --- a/pymongo/results.py +++ b/pymongo/results.py @@ -28,6 +28,9 @@ class _WriteResult: def __init__(self, acknowledged: bool) -> None: self.__acknowledged = acknowledged + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__acknowledged})" + def _raise_if_unacknowledged(self, property_name: str) -> None: """Raise an exception on property access if unacknowledged.""" if not self.__acknowledged: @@ -67,6 +70,11 @@ class InsertOneResult(_WriteResult): self.__inserted_id = inserted_id super().__init__(acknowledged) + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.__inserted_id!r}, acknowledged={self.acknowledged})" + ) + @property def inserted_id(self) -> Any: """The inserted document's _id.""" @@ -82,6 +90,11 @@ class InsertManyResult(_WriteResult): self.__inserted_ids = inserted_ids super().__init__(acknowledged) + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.__inserted_ids!r}, acknowledged={self.acknowledged})" + ) + @property def inserted_ids(self) -> list[Any]: """A list of _ids of the inserted documents, in the order provided. @@ -106,6 +119,9 @@ class UpdateResult(_WriteResult): self.__raw_result = raw_result super().__init__(acknowledged) + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__raw_result!r}, acknowledged={self.acknowledged})" + @property def raw_result(self) -> Optional[Mapping[str, Any]]: """The raw result document returned by the server.""" @@ -148,6 +164,9 @@ class DeleteResult(_WriteResult): self.__raw_result = raw_result super().__init__(acknowledged) + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__raw_result!r}, acknowledged={self.acknowledged})" + @property def raw_result(self) -> Mapping[str, Any]: """The raw result document returned by the server.""" @@ -177,6 +196,9 @@ class BulkWriteResult(_WriteResult): self.__bulk_api_result = bulk_api_result super().__init__(acknowledged) + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__bulk_api_result!r}, acknowledged={self.acknowledged})" + @property def bulk_api_result(self) -> dict[str, Any]: """The raw bulk API result.""" diff --git a/test/test_results.py b/test/test_results.py new file mode 100644 index 000000000..19e086a9a --- /dev/null +++ b/test/test_results.py @@ -0,0 +1,138 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test results module.""" +from __future__ import annotations + +import sys + +sys.path[0:0] = [""] + +from test import unittest + +from pymongo.errors import InvalidOperation +from pymongo.results import ( + BulkWriteResult, + DeleteResult, + InsertManyResult, + InsertOneResult, + UpdateResult, +) + + +class TestResults(unittest.TestCase): + def repr_test(self, cls, result_arg): + for acknowledged in (True, False): + result = cls(result_arg, acknowledged) + expected_repr = "%s(%r, acknowledged=%r)" % (cls.__name__, result_arg, acknowledged) + self.assertEqual(acknowledged, result.acknowledged) + self.assertEqual(expected_repr, repr(result)) + + def test_bulk_write_result(self): + raw_result = { + "writeErrors": [], + "writeConcernErrors": [], + "nInserted": 1, + "nUpserted": 2, + "nMatched": 2, + "nModified": 2, + "nRemoved": 2, + "upserted": [ + {"index": 5, "_id": 1}, + {"index": 9, "_id": 2}, + ], + } + self.repr_test(BulkWriteResult, raw_result) + + result = BulkWriteResult(raw_result, True) + self.assertEqual(raw_result, result.bulk_api_result) + self.assertEqual(raw_result["nInserted"], result.inserted_count) + self.assertEqual(raw_result["nMatched"], result.matched_count) + self.assertEqual(raw_result["nModified"], result.modified_count) + self.assertEqual(raw_result["nRemoved"], result.deleted_count) + self.assertEqual(raw_result["nUpserted"], result.upserted_count) + self.assertEqual({5: 1, 9: 2}, result.upserted_ids) + + result = BulkWriteResult(raw_result, False) + self.assertEqual(raw_result, result.bulk_api_result) + error_msg = "A value for .* is not available when" + with self.assertRaisesRegex(InvalidOperation, error_msg): + result.inserted_count + with self.assertRaisesRegex(InvalidOperation, error_msg): + result.matched_count + with self.assertRaisesRegex(InvalidOperation, error_msg): + result.modified_count + with self.assertRaisesRegex(InvalidOperation, error_msg): + result.deleted_count + with self.assertRaisesRegex(InvalidOperation, error_msg): + result.upserted_count + with self.assertRaisesRegex(InvalidOperation, error_msg): + result.upserted_ids + + def test_delete_result(self): + raw_result = {"n": 5} + self.repr_test(DeleteResult, {"n": 0}) + + result = DeleteResult(raw_result, True) + self.assertEqual(raw_result, result.raw_result) + self.assertEqual(raw_result["n"], result.deleted_count) + + result = DeleteResult(raw_result, False) + self.assertEqual(raw_result, result.raw_result) + error_msg = "A value for .* is not available when" + with self.assertRaisesRegex(InvalidOperation, error_msg): + result.deleted_count + + def test_insert_many_result(self): + inserted_ids = [1, 2, 3] + self.repr_test(InsertManyResult, inserted_ids) + + for acknowledged in (True, False): + result = InsertManyResult(inserted_ids, acknowledged) + self.assertEqual(inserted_ids, result.inserted_ids) + + def test_insert_one_result(self): + self.repr_test(InsertOneResult, 0) + + for acknowledged in (True, False): + result = InsertOneResult(0, acknowledged) + self.assertEqual(0, result.inserted_id) + + def test_update_result(self): + raw_result = { + "n": 1, + "nModified": 1, + "upserted": None, + } + self.repr_test(UpdateResult, raw_result) + + result = UpdateResult(raw_result, True) + self.assertEqual(raw_result, result.raw_result) + self.assertEqual(raw_result["n"], result.matched_count) + self.assertEqual(raw_result["nModified"], result.modified_count) + self.assertEqual(raw_result["upserted"], result.upserted_id) + + result = UpdateResult(raw_result, False) + self.assertEqual(raw_result, result.raw_result) + error_msg = "A value for .* is not available when" + with self.assertRaisesRegex(InvalidOperation, error_msg): + result.matched_count + with self.assertRaisesRegex(InvalidOperation, error_msg): + result.modified_count + with self.assertRaisesRegex(InvalidOperation, error_msg): + result.upserted_id + + +if __name__ == "__main__": + unittest.main()