diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index 15dc3986c..1dc5914e2 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -129,7 +129,7 @@ static PyObject* _error(char* name) { /* Safely downcast from Py_ssize_t to int, setting an * exception and returning -1 on error. */ static int -_downcast_and_check(Py_ssize_t size, unsigned extra) { +_downcast_and_check(Py_ssize_t size, int extra) { if (size > BSON_MAX_SIZE || ((BSON_MAX_SIZE - extra) < size)) { PyObject* InvalidStringData = _error("InvalidStringData"); if (InvalidStringData) { @@ -1382,9 +1382,9 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position, case 3: { PyObject* collection; - unsigned size; + int size; memcpy(&size, buffer + *position, 4); - if (max < size) { + if (size < 0 || max < size) { goto invalid; } value = elements_to_dict(self, buffer + *position + 4, @@ -1667,7 +1667,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position, int flags; size_t flags_length, i; size_t pattern_length = strlen(buffer + *position); - if (max < pattern_length || pattern_length > BSON_MAX_SIZE) { + if (pattern_length > BSON_MAX_SIZE || max < (int)pattern_length) { goto invalid; } pattern = PyUnicode_DecodeUTF8(buffer + *position, pattern_length, "strict"); @@ -1679,7 +1679,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position, Py_DECREF(pattern); goto invalid; } - if (max < pattern_length + flags_length) { + if (max < (int)(pattern_length + flags_length)) { Py_DECREF(pattern); goto invalid; } @@ -1706,24 +1706,22 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position, } case 12: { - size_t collection_length; + size_t coll_length; PyObject* collection; PyObject* id; *position += 4; - collection_length = strlen(buffer + *position); - if (max < collection_length || collection_length > BSON_MAX_SIZE) { + coll_length = strlen(buffer + *position); + if (coll_length > BSON_MAX_SIZE || max < (int)coll_length + 12) { goto invalid; } - collection = PyUnicode_DecodeUTF8(buffer + *position, collection_length, "strict"); + collection = PyUnicode_DecodeUTF8(buffer + *position, + coll_length, "strict"); if (!collection) { return NULL; } - *position += (int)collection_length + 1; - if (max < collection_length + 12) { - Py_DECREF(collection); - goto invalid; - } + *position += (int)coll_length + 1; + id = PyObject_CallFunction(state->ObjectId, "s#", buffer + *position, 12); if (!id) { Py_DECREF(collection); @@ -1761,7 +1759,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position, *position += 8; code_length = strlen(buffer + *position); - if (max < 8 + code_length || code_length > BSON_MAX_SIZE) { + if (code_length > BSON_MAX_SIZE || max < 8 + (int)code_length) { goto invalid; } code = PyUnicode_DecodeUTF8(buffer + *position, code_length, "strict"); @@ -1877,7 +1875,7 @@ static PyObject* elements_to_dict(PyObject* self, const char* string, int max, PyObject* value; int type = (int)string[position++]; size_t name_length = strlen(string + position); - if (name_length > BSON_MAX_SIZE || position + name_length >= max) { + if (name_length > BSON_MAX_SIZE || position + (int)name_length >= max) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetNone(InvalidBSON); @@ -1908,7 +1906,7 @@ static PyObject* elements_to_dict(PyObject* self, const char* string, int max, } static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) { - unsigned int size; + int size; Py_ssize_t total_size; const char* string; PyObject* bson; @@ -1955,7 +1953,16 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) { if (!string) { return NULL; } + memcpy(&size, string, 4); + if (size < 0) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "invalid message size"); + Py_DECREF(InvalidBSON); + } + return NULL; + } if (total_size < size) { PyObject* InvalidBSON = _error("InvalidBSON"); @@ -1995,7 +2002,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) { } static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) { - unsigned int size; + int size; Py_ssize_t total_size; const char* string; PyObject* bson; @@ -2045,6 +2052,15 @@ static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) { } memcpy(&size, string, 4); + if (size < 0) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "invalid message size"); + Py_DECREF(InvalidBSON); + } + Py_DECREF(result); + return NULL; + } if (total_size < size) { PyObject* InvalidBSON = _error("InvalidBSON");