diff --git a/bson/__init__.py b/bson/__init__.py index 919772761..e7cbe0982 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -139,13 +139,19 @@ def _get_number(data, position, as_class, tz_aware, uuid_subtype): def _get_string(data, position, as_class, tz_aware, uuid_subtype): - length = struct.unpack(" BSON_MAX_SIZE) { + if (max < (int)key_size) { Py_DECREF(value); goto invalid; } /* just skip the key, they're in order. */ - *position += (int)key_size + 1; + *position += (unsigned)key_size + 1; + if (Py_EnterRecursiveCall(" while decoding a list value")) { + Py_DECREF(value); + return NULL; + } to_append = get_value(self, buffer, position, bson_type, max - (int)key_size, as_class, tz_aware, uuid_subtype); + Py_LeaveRecursiveCall(); if (!to_append) { Py_DECREF(value); return NULL; @@ -1572,8 +1601,11 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position, PyObject* data; PyObject* st; PyObject* type_to_create; - int length, subtype; + unsigned length, subtype; + if (max < 4) { + goto invalid; + } memcpy(&length, buffer + *position, 4); if (max < length) { goto invalid; @@ -1779,7 +1811,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position, if (!pattern) { return NULL; } - *position += (int)pattern_length + 1; + *position += (unsigned)pattern_length + 1; if ((flags_length = strlen(buffer + *position)) > BSON_MAX_SIZE) { Py_DECREF(pattern); goto invalid; @@ -1804,7 +1836,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position, flags |= 64; } } - *position += (int)flags_length + 1; + *position += (unsigned)flags_length + 1; if ((compile_func = _get_object(state->RECompile, "re", "compile"))) { value = PyObject_CallFunction(compile_func, "Oi", pattern, flags); Py_DECREF(compile_func); @@ -1814,23 +1846,32 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position, } case 12: { - size_t coll_length; + unsigned coll_length; PyObject* collection; PyObject* id = NULL; PyObject* objectid_type; PyObject* dbref_type; - *position += 4; - coll_length = strlen(buffer + *position); - if (coll_length > BSON_MAX_SIZE || max < (int)coll_length + 12) { + if (max < 4) { goto invalid; } + memcpy(&coll_length, buffer + *position, 4); + /* Encoded string length + string + 12 byte ObjectId */ + if (max < 4 + coll_length + 12) { + goto invalid; + } + *position += 4; + /* Strings must end in \0 */ + if (buffer[*position + coll_length - 1]) { + goto invalid; + } + collection = PyUnicode_DecodeUTF8(buffer + *position, - coll_length, "strict"); + coll_length - 1, "strict"); if (!collection) { return NULL; } - *position += (int)coll_length + 1; + *position += coll_length; if ((objectid_type = _get_object(state->ObjectId, "bson.objectid", "ObjectId"))) { id = PyObject_CallFunction(objectid_type, "s#", buffer + *position, 12); @@ -1853,16 +1894,25 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position, { PyObject* code; PyObject* code_type; - int value_length = ((int*)(buffer + *position))[0] - 1; - if (max < value_length) { + unsigned value_length; + if (max < 4) { + goto invalid; + } + memcpy(&value_length, buffer + *position, 4); + /* Encoded string length + string */ + if (max < 4 + value_length) { goto invalid; } *position += 4; - code = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict"); + /* Strings must end in \0 */ + if (buffer[*position + value_length - 1]) { + goto invalid; + } + code = PyUnicode_DecodeUTF8(buffer + *position, value_length - 1, "strict"); if (!code) { return NULL; } - *position += value_length + 1; + *position += value_length; if ((code_type = _get_object(state->Code, "bson.code", "Code"))) { value = PyObject_CallFunctionObjArgs(code_type, code, NULL, NULL); Py_DECREF(code_type); @@ -1872,25 +1922,56 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position, } case 15: { - size_t code_length; - int scope_size; + unsigned c_w_s_size; + unsigned code_size; + unsigned scope_size; PyObject* code; PyObject* scope; PyObject* code_type; - *position += 8; - code_length = strlen(buffer + *position); - if (code_length > BSON_MAX_SIZE || max < 8 + (int)code_length) { + if (max < 8) { goto invalid; } - code = PyUnicode_DecodeUTF8(buffer + *position, code_length, "strict"); + + memcpy(&c_w_s_size, buffer + *position, 4); + *position += 4; + + if (max < c_w_s_size) { + goto invalid; + } + + memcpy(&code_size, buffer + *position, 4); + /* code_w_scope length + code length + code + scope length */ + if (max < 4 + 4 + code_size + 4) { + goto invalid; + } + *position += 4; + /* Strings must end in \0 */ + if (buffer[*position + code_size - 1]) { + goto invalid; + } + code = PyUnicode_DecodeUTF8(buffer + *position, code_size - 1, "strict"); if (!code) { return NULL; } - *position += (int)code_length + 1; + *position += code_size; memcpy(&scope_size, buffer + *position, 4); - scope = elements_to_dict(self, buffer + *position + 4, scope_size - 5, + if (scope_size < BSON_MIN_SIZE) { + Py_DECREF(code); + goto invalid; + } + /* code length + code + scope length + scope */ + if ((4 + code_size + 4 + scope_size) != c_w_s_size) { + Py_DECREF(code); + goto invalid; + } + + /* Check for bad eoo */ + if (buffer[*position + scope_size - 1]) { + goto invalid; + } + scope = elements_to_dict(self, buffer + *position + 4, (int)scope_size - 5, (PyObject*)&PyDict_Type, tz_aware, uuid_subtype); if (!scope) { Py_DECREF(code); @@ -1989,16 +2070,17 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position, error = _error("InvalidBSON"); if (error) { - PyErr_SetNone(error); + PyErr_SetString(error, + "invalid length or type code"); Py_DECREF(error); } return NULL; } -static PyObject* elements_to_dict(PyObject* self, const char* string, int max, +static PyObject* _elements_to_dict(PyObject* self, const char* string, int max, PyObject* as_class, unsigned char tz_aware, unsigned char uuid_subtype) { - int position = 0; + unsigned position = 0; PyObject* dict = PyObject_CallObject(as_class, NULL); if (!dict) { return NULL; @@ -2038,6 +2120,18 @@ static PyObject* elements_to_dict(PyObject* self, const char* string, int max, return dict; } +static PyObject* elements_to_dict(PyObject* self, const char* string, int max, + PyObject* as_class, unsigned char tz_aware, + unsigned char uuid_subtype) { + PyObject* result; + if (Py_EnterRecursiveCall(" while decoding a BSON document")) + return NULL; + result = _elements_to_dict(self, string, max, + as_class, tz_aware, uuid_subtype); + Py_LeaveRecursiveCall(); + return result; +} + static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) { int size; Py_ssize_t total_size; @@ -2068,7 +2162,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) { #else total_size = PyString_Size(bson); #endif - if (total_size < 5) { + if (total_size < BSON_MIN_SIZE) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, @@ -2088,7 +2182,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) { } memcpy(&size, string, 4); - if (size < 0) { + if (size < BSON_MIN_SIZE) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, "invalid message size"); @@ -2097,7 +2191,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) { return NULL; } - if (total_size < size) { + if (total_size < size || total_size > BSON_MAX_SIZE) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, "objsize too large"); @@ -2173,7 +2267,7 @@ static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) { return NULL; while (total_size > 0) { - if (total_size < 5) { + if (total_size < BSON_MIN_SIZE) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, @@ -2185,7 +2279,7 @@ static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) { } memcpy(&size, string, 4); - if (size < 0) { + if (size < BSON_MIN_SIZE) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, "invalid message size"); diff --git a/test/test_bson.py b/test/test_bson.py index 26497faa4..57039e69c 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -63,6 +63,8 @@ class TestBSON(unittest.TestCase): # the simplest valid BSON document self.assertTrue(is_valid(b("\x05\x00\x00\x00\x00"))) self.assertTrue(is_valid(BSON(b("\x05\x00\x00\x00\x00")))) + + # failure cases self.assertFalse(is_valid(b("\x04\x00\x00\x00\x00"))) self.assertFalse(is_valid(b("\x05\x00\x00\x00\x01"))) self.assertFalse(is_valid(b("\x05\x00\x00\x00"))) @@ -70,6 +72,17 @@ class TestBSON(unittest.TestCase): self.assertFalse(is_valid(b("\x07\x00\x00\x00\x02a\x00\x78\x56\x34\x12"))) self.assertFalse(is_valid(b("\x09\x00\x00\x00\x10a\x00\x05\x00"))) self.assertFalse(is_valid(b("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"))) + self.assertFalse(is_valid(b("\x13\x00\x00\x00\x02foo\x00" + "\x04\x00\x00\x00bar\x00\x00"))) + self.assertFalse(is_valid(b("\x18\x00\x00\x00\x03foo\x00\x0f\x00\x00" + "\x00\x10bar\x00\xff\xff\xff\x7f\x00\x00"))) + self.assertFalse(is_valid(b("\x15\x00\x00\x00\x03foo\x00\x0c" + "\x00\x00\x00\x08bar\x00\x01\x00\x00"))) + self.assertFalse(is_valid(b("\x1c\x00\x00\x00\x03foo\x00" + "\x12\x00\x00\x00\x02bar\x00" + "\x05\x00\x00\x00baz\x00\x00\x00"))) + self.assertFalse(is_valid(b("\x10\x00\x00\x00\x02a\x00" + "\x04\x00\x00\x00abc\xff\x00"))) def test_random_data_is_not_bson(self): qcheck.check_unittest(self, qcheck.isnt(is_valid),