diff --git a/bson/__init__.py b/bson/__init__.py index d0a8daa27..fd11c9952 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -896,12 +896,21 @@ def _name_value_to_bson( in_fallback_call: bool = False, ) -> bytes: """Encode a single name, value pair.""" + + was_integer_overflow = False + # First see if the type is already cached. KeyError will only ever # happen once per subtype. try: return _ENCODERS[type(value)](name, value, check_keys, opts) # type: ignore except KeyError: pass + except OverflowError: + if not isinstance(value, int): + raise + + # Give the fallback_encoder a chance + was_integer_overflow = True # Second, fall back to trying _type_marker. This has to be done # before the loop below since users could subclass one of our @@ -927,7 +936,7 @@ def _name_value_to_bson( # is done after trying the custom type encoder because checking for each # subtype is expensive. for base in _BUILT_IN_TYPES: - if isinstance(value, base): + if not was_integer_overflow and isinstance(value, base): func = _ENCODERS[base] # Cache this type for faster subsequent lookup. _ENCODERS[type(value)] = func @@ -941,6 +950,8 @@ def _name_value_to_bson( name, fallback_encoder(value), check_keys, opts, in_fallback_call=True ) + if was_integer_overflow: + raise OverflowError("BSON can only handle up to 8-byte ints") raise InvalidDocument(f"cannot encode object: {value!r}, of type: {type(value)!r}") diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index 5918a678c..68ea6b63c 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -662,6 +662,13 @@ static int write_element_to_buffer(PyObject* self, buffer_t buffer, static void _set_cannot_encode(PyObject* value) { + if (PyLong_Check(value)) { + if ((PyLong_AsLongLong(value) == -1) && PyErr_Occurred()) { + return PyErr_SetString(PyExc_OverflowError, + "MongoDB can only handle up to 8-byte ints"); + } + } + PyObject* type = NULL; PyObject* InvalidDocument = _error("InvalidDocument"); if (InvalidDocument == NULL) { @@ -1069,16 +1076,17 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer, long long long_long_value; PyErr_Clear(); long_long_value = PyLong_AsLongLong(value); - if (PyErr_Occurred()) { /* Overflow AGAIN */ - PyErr_SetString(PyExc_OverflowError, - "MongoDB can only handle up to 8-byte ints"); - return 0; + if (PyErr_Occurred()) { + /* Ignore error and give the fallback_encoder a chance. */ + PyErr_Clear(); + } else { + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x12; + return buffer_write_int64(buffer, (int64_t)long_long_value); } - *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x12; - return buffer_write_int64(buffer, (int64_t)long_long_value); + } else { + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x10; + return buffer_write_int32(buffer, (int32_t)int_value); } - *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x10; - return buffer_write_int32(buffer, (int32_t)int_value); } else if (PyFloat_Check(value)) { const double d = PyFloat_AsDouble(value); *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x01; diff --git a/doc/contributors.rst b/doc/contributors.rst index e6d5e5310..2a4ca1ea4 100644 --- a/doc/contributors.rst +++ b/doc/contributors.rst @@ -97,3 +97,4 @@ The following is a list of people who have contributed to - Sean Cheah (thalassemia) - Dainis Gorbunovs (DainisGorbunovs) - Iris Ho (sleepyStick) +- Stephan Hof (stephan-hof) diff --git a/test/test_custom_types.py b/test/test_custom_types.py index 14d7b4b05..7e190483a 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -274,6 +274,22 @@ class TestBSONFallbackEncoder(unittest.TestCase): with self.assertRaises(TypeError): encode(document, codec_options=codecopts) + def test_call_only_once_for_not_handled_big_integers(self): + called_with = [] + + def fallback_encoder(value): + called_with.append(value) + return value + + codecopts = self._get_codec_options(fallback_encoder) + document = {"a": {"b": {"c": 2 << 65}}} + + msg = "MongoDB can only handle up to 8-byte ints" + with self.assertRaises(OverflowError, msg=msg): + encode(document, codec_options=codecopts) + + self.assertEqual(called_with, [2 << 65]) + class TestBSONTypeEnDeCodecs(unittest.TestCase): def test_instantiation(self): @@ -623,6 +639,15 @@ class TestCollectionWCustomType(IntegrationTest): def tearDown(self): self.db.test.drop() + def test_overflow_int_w_custom_decoder(self): + type_registry = TypeRegistry(fallback_encoder=lambda val: str(val)) + codec_options = CodecOptions(type_registry=type_registry) + collection = self.db.get_collection("test", codec_options=codec_options) + + collection.insert_one({"_id": 1, "data": 2**520}) + ret = collection.find_one() + self.assertEqual(ret["data"], str(2**520)) + def test_command_errors_w_custom_type_decoder(self): db = self.db test_doc = {"_id": 1, "data": "a"}