From fd0787a57b72ed30e367f519bebb0ba505ddc121 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 27 Aug 2024 19:05:15 -0500 Subject: [PATCH] PYTHON-4615 Address sign-compare warning, improve array_of_documents_to_buffer validation (#1804) --- bson/__init__.py | 7 ++++--- bson/_cbsonmodule.c | 50 ++++++++++++++++++++++++++++----------------- test/test_bson.py | 18 ++++++++++++++++ 3 files changed, 53 insertions(+), 22 deletions(-) diff --git a/bson/__init__.py b/bson/__init__.py index a7c9ddc50..48fffd745 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -284,10 +284,10 @@ def _get_object_size(data: Any, position: int, obj_end: int) -> Tuple[int, int]: except struct.error as exc: raise InvalidBSON(str(exc)) from None end = position + obj_size - 1 - if data[end] != 0: - raise InvalidBSON("bad eoo") if end >= obj_end: raise InvalidBSON("invalid object length") + if data[end] != 0: + raise InvalidBSON("bad eoo") # If this is the top-level document, validate the total size too. if position == 0 and obj_size != obj_end: raise InvalidBSON("invalid object length") @@ -1180,9 +1180,10 @@ def _decode_selective( return doc -def _array_of_documents_to_buffer(view: memoryview) -> bytes: +def _array_of_documents_to_buffer(data: Union[memoryview, bytes]) -> bytes: # Extract the raw bytes of each document. 
position = 0 + view = memoryview(data) _, end = _get_object_size(view, position, len(view)) position += 4 buffers: list[memoryview] = [] diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index 3b3aecc44..68ec9fe45 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -2901,11 +2901,31 @@ static PyObject* _cbson_array_of_documents_to_buffer(PyObject* self, PyObject* a "not enough data for a BSON document"); Py_DECREF(InvalidBSON); } - goto done; + goto fail; } memcpy(&size, string, 4); size = BSON_UINT32_FROM_LE(size); + + /* validate the size of the array */ + if (view.len != (int32_t)size || (int32_t)size < BSON_MIN_SIZE) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "objsize too large"); + Py_DECREF(InvalidBSON); + } + goto fail; + } + + if (string[size - 1]) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "bad eoo"); + Py_DECREF(InvalidBSON); + } + goto fail; + } + /* save space for length */ if (pymongo_buffer_save_space(buffer, size) == -1) { goto fail; @@ -2948,30 +2968,22 @@ static PyObject* _cbson_array_of_documents_to_buffer(PyObject* self, PyObject* a goto fail; } - if (view.len < size) { - PyObject* InvalidBSON = _error("InvalidBSON"); - if (InvalidBSON) { - PyErr_SetString(InvalidBSON, "objsize too large"); - Py_DECREF(InvalidBSON); - } - goto fail; - } - - if (string[size - 1]) { - PyObject* InvalidBSON = _error("InvalidBSON"); - if (InvalidBSON) { - PyErr_SetString(InvalidBSON, "bad eoo"); - Py_DECREF(InvalidBSON); - } - goto fail; - } - if (pymongo_buffer_write(buffer, string + position, value_length) == 1) { goto fail; } position += value_length; } + if (position != size - 1) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, + "bad object or element length"); + Py_DECREF(InvalidBSON); + } + goto fail; + } + /* objectify buffer */ result = Py_BuildValue("y#", 
pymongo_buffer_get_buffer(buffer), (Py_ssize_t)pymongo_buffer_get_position(buffer)); diff --git a/test/test_bson.py b/test/test_bson.py index 79a7fa061..8c8fe6018 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -41,6 +41,7 @@ from bson import ( EPOCH_AWARE, DatetimeMS, Regex, + _array_of_documents_to_buffer, _datetime_to_millis, decode, decode_all, @@ -1366,6 +1367,23 @@ class TestDatetimeConversion(unittest.TestCase): with self.assertRaisesRegex(InvalidBSON, re.compile(re.escape(_DATETIME_ERROR_SUGGESTION))): decode(encode({"a": DatetimeMS(small_ms)})) + def test_array_of_documents_to_buffer(self): + doc = dict(a=1) + buf = _array_of_documents_to_buffer(encode({"0": doc})) + self.assertEqual(buf, encode(doc)) + buf = _array_of_documents_to_buffer(encode({"0": doc, "1": doc})) + self.assertEqual(buf, encode(doc) + encode(doc)) + with self.assertRaises(InvalidBSON): + _array_of_documents_to_buffer(encode({"0": doc, "1": doc}) + b"1") + buf = encode({"0": doc, "1": doc}) + buf = buf[:-1] + b"1" + with self.assertRaises(InvalidBSON): + _array_of_documents_to_buffer(buf) + # We replace the size of the array with \xff\xff\xff\x00, which is 16777215 as a little-endian int32 (larger than the buffer). + buf = b"\x14\x00\x00\x00\x04a\x00\xff\xff\xff\x00\x100\x00\x01\x00\x00\x00\x00\x00" + with self.assertRaises(InvalidBSON): + _array_of_documents_to_buffer(buf) + class TestLongLongToString(unittest.TestCase): def test_long_long_to_string(self):