From 568a3b1294711cbeeb7bff36718d981e4ad103e2 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 8 Dec 2023 10:08:41 -0800 Subject: [PATCH] PYTHON-4084 Fix BSON inflation for RawBSONDocument (#1456) --- bson/_cbsonmodule.c | 45 ++++++++++++------------------------------- doc/changelog.rst | 1 + test/test_raw_bson.py | 8 +++++++- 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index 4e1881a27..7e691d8be 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -1878,19 +1878,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, goto invalid; } - if (options->is_raw_bson) { - value = PyObject_CallFunction( - options->document_class, "y#O", - buffer + *position, (Py_ssize_t)size, options->options_obj); - if (!value) { - goto invalid; - } - *position += size; - break; - } - - value = elements_to_dict(self, buffer + *position + 4, - size - 5, options); + value = elements_to_dict(self, buffer + *position, + size, options); if (!value) { goto invalid; } @@ -2456,8 +2445,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, if (buffer[*position + scope_size - 1]) { goto invalid; } - scope = elements_to_dict(self, buffer + *position + 4, - scope_size - 5, options); + scope = elements_to_dict(self, buffer + *position, + scope_size, options); if (!scope) { Py_DECREF(code); goto invalid; @@ -2809,9 +2798,14 @@ static PyObject* elements_to_dict(PyObject* self, const char* string, unsigned max, const codec_options_t* options) { PyObject* result; + if (options->is_raw_bson) { + return PyObject_CallFunction( + options->document_class, "y#O", + string, max, options->options_obj); + } if (Py_EnterRecursiveCall(" while decoding a BSON document")) return NULL; - result = _elements_to_dict(self, string, max, options); + result = _elements_to_dict(self, string + 4, max - 5, options); Py_LeaveRecursiveCall(); return result; } @@ -2902,15 +2896,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) { goto done; } - /* No need to decode fields if using RawBSONDocument */ - if (options.is_raw_bson) { - result = PyObject_CallFunction( - options.document_class, "y#O", string, (Py_ssize_t)size, - options_obj); - } - else { - result = elements_to_dict(self, string + 4, (unsigned)size - 5, &options); - } + result = elements_to_dict(self, string, (unsigned)size, &options); done: PyBuffer_Release(&view); destroy_codec_options(&options); @@ -2988,14 +2974,7 @@ static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) { goto fail; } - /* No need to decode fields if using RawBSONDocument. */ - if (options.is_raw_bson) { - dict = PyObject_CallFunction( - options.document_class, "y#O", string, (Py_ssize_t)size, - options_obj); - } else { - dict = elements_to_dict(self, string + 4, (unsigned)size - 5, &options); - } + dict = elements_to_dict(self, string, (unsigned)size, &options); if (!dict) { Py_DECREF(result); goto fail; diff --git a/doc/changelog.rst b/doc/changelog.rst index 898ab5187..38f4c531f 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -10,6 +10,7 @@ PyMongo 4.7 brings a number of improvements including: :attr:`pymongo.monitoring.CommandStartedEvent.server_connection_id`, :attr:`pymongo.monitoring.CommandSucceededEvent.server_connection_id`, and :attr:`pymongo.monitoring.CommandFailedEvent.server_connection_id` properties. +- Fixed a bug where inflating a :class:`~bson.raw_bson.RawBSONDocument` containing a :class:`~bson.code.Code` would cause an error. Changes in Version 4.6.1 ------------------------ diff --git a/test/test_raw_bson.py b/test/test_raw_bson.py index 93012dff4..105eab7d9 100644 --- a/test/test_raw_bson.py +++ b/test/test_raw_bson.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test import client_context, unittest from test.test_client import IntegrationTest -from bson import decode, encode +from bson import Code, decode, encode from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation from bson.codec_options import CodecOptions from bson.errors import InvalidBSON @@ -199,6 +199,12 @@ class TestRawBSONDocument(IntegrationTest): for rkey, elt in zip(rawdoc, keyvaluepairs): self.assertEqual(rkey, elt[0]) + def test_contains_code_with_scope(self): + doc = RawBSONDocument(encode({"value": Code("x=1", scope={})})) + + self.assertEqual(decode(encode(doc)), {"value": Code("x=1", {})}) + self.assertEqual(doc["value"].scope, RawBSONDocument(encode({}))) + if __name__ == "__main__": unittest.main()