Better handling of corrupt/invalid BSON PYTHON-571
This commit is contained in:
parent
12283dcffa
commit
ab4d63b658
@ -139,13 +139,19 @@ def _get_number(data, position, as_class, tz_aware, uuid_subtype):
|
||||
|
||||
|
||||
def _get_string(data, position, as_class, tz_aware, uuid_subtype):
    """Decode a length-prefixed BSON string starting at ``position``.

    Reads a little-endian int32 length (which counts the trailing NUL
    byte), validates it against the remaining data, checks the
    terminator, and hands the payload to ``_get_c_string``, whose
    result is returned unchanged.

    ``as_class``, ``tz_aware`` and ``uuid_subtype`` are not used here.

    Raises InvalidBSON if the length is not positive, exceeds the
    remaining data, or the string is not NUL-terminated.
    """
    length = struct.unpack("<i", data[position:position + 4])[0]
    # The length includes the terminator, so anything below 1 is corrupt;
    # without this guard a zero/negative length would inspect the wrong byte.
    if length <= 0 or (len(data) - position - 4) < length:
        raise InvalidBSON("invalid string length")
    position += 4
    # Slice rather than index (as _get_object does): indexing bytes yields
    # an int on Python 3, which never compares equal to the bytes ZERO.
    end = position + length
    if data[end - 1:end] != ZERO:
        raise InvalidBSON("invalid end of string")
    return _get_c_string(data, position, length - 1)
|
||||
|
||||
|
||||
def _get_object(data, position, as_class, tz_aware, uuid_subtype):
|
||||
obj_size = struct.unpack("<i", data[position:position + 4])[0]
|
||||
if data[position + obj_size - 1:position + obj_size] != ZERO:
|
||||
raise InvalidBSON("bad eoo")
|
||||
encoded = data[position + 4:position + obj_size - 1]
|
||||
object = _elements_to_dict(encoded, as_class, tz_aware, uuid_subtype)
|
||||
position += obj_size
|
||||
|
||||
@ -97,6 +97,8 @@ static struct module_state _state;
|
||||
#define JAVA_LEGACY 5
|
||||
#define CSHARP_LEGACY 6
|
||||
#define BSON_MAX_SIZE 2147483647
|
||||
/* The smallest possible BSON document, i.e. "{}" */
|
||||
#define BSON_MIN_SIZE 5
|
||||
|
||||
/* Get an error class from the bson.errors module.
|
||||
*
|
||||
@ -1430,7 +1432,7 @@ static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* args) {
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
static PyObject* get_value(PyObject* self, const char* buffer, unsigned* position,
|
||||
int type, int max, PyObject* as_class,
|
||||
unsigned char tz_aware, unsigned char uuid_subtype) {
|
||||
struct module_state *state = GETSTATE(self);
|
||||
@ -1455,28 +1457,44 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
case 2:
|
||||
case 14:
|
||||
{
|
||||
int value_length = ((int*)(buffer + *position))[0] - 1;
|
||||
if (max < value_length) {
|
||||
unsigned value_length;
|
||||
if (max < 4) {
|
||||
goto invalid;
|
||||
}
|
||||
memcpy(&value_length, buffer + *position, 4);
|
||||
/* Encoded string length + string */
|
||||
if (max < 4 + value_length) {
|
||||
goto invalid;
|
||||
}
|
||||
*position += 4;
|
||||
value = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict");
|
||||
/* Strings must end in \0 */
|
||||
if (buffer[*position + value_length - 1]) {
|
||||
goto invalid;
|
||||
}
|
||||
value = PyUnicode_DecodeUTF8(buffer + *position, value_length - 1, "strict");
|
||||
if (!value) {
|
||||
return NULL;
|
||||
}
|
||||
*position += value_length + 1;
|
||||
*position += value_length;
|
||||
break;
|
||||
}
|
||||
case 3:
|
||||
{
|
||||
PyObject* collection;
|
||||
int size;
|
||||
unsigned size;
|
||||
if (max < 4) {
|
||||
goto invalid;
|
||||
}
|
||||
memcpy(&size, buffer + *position, 4);
|
||||
if (size < 0 || max < size) {
|
||||
if (size < BSON_MIN_SIZE || max < size) {
|
||||
goto invalid;
|
||||
}
|
||||
/* Check for bad eoo */
|
||||
if (buffer[*position + size - 1]) {
|
||||
goto invalid;
|
||||
}
|
||||
value = elements_to_dict(self, buffer + *position + 4,
|
||||
size - 5, as_class, tz_aware, uuid_subtype);
|
||||
(int)size - 5, as_class, tz_aware, uuid_subtype);
|
||||
if (!value) {
|
||||
return NULL;
|
||||
}
|
||||
@ -1530,14 +1548,20 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
}
|
||||
case 4:
|
||||
{
|
||||
int size,
|
||||
end;
|
||||
unsigned size, end;
|
||||
|
||||
if (max < 4) {
|
||||
goto invalid;
|
||||
}
|
||||
memcpy(&size, buffer + *position, 4);
|
||||
if (max < size) {
|
||||
goto invalid;
|
||||
}
|
||||
end = *position + size - 1;
|
||||
/* Check for bad eoo */
|
||||
if (buffer[end]) {
|
||||
goto invalid;
|
||||
}
|
||||
*position += 4;
|
||||
|
||||
value = PyList_New(0);
|
||||
@ -1549,14 +1573,19 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
|
||||
int bson_type = (int)buffer[(*position)++];
|
||||
size_t key_size = strlen(buffer + *position);
|
||||
if (key_size > BSON_MAX_SIZE) {
|
||||
if (max < (int)key_size) {
|
||||
Py_DECREF(value);
|
||||
goto invalid;
|
||||
}
|
||||
/* just skip the key, they're in order. */
|
||||
*position += (int)key_size + 1;
|
||||
*position += (unsigned)key_size + 1;
|
||||
if (Py_EnterRecursiveCall(" while decoding a list value")) {
|
||||
Py_DECREF(value);
|
||||
return NULL;
|
||||
}
|
||||
to_append = get_value(self, buffer, position, bson_type,
|
||||
max - (int)key_size, as_class, tz_aware, uuid_subtype);
|
||||
Py_LeaveRecursiveCall();
|
||||
if (!to_append) {
|
||||
Py_DECREF(value);
|
||||
return NULL;
|
||||
@ -1572,8 +1601,11 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
PyObject* data;
|
||||
PyObject* st;
|
||||
PyObject* type_to_create;
|
||||
int length, subtype;
|
||||
unsigned length, subtype;
|
||||
|
||||
if (max < 4) {
|
||||
goto invalid;
|
||||
}
|
||||
memcpy(&length, buffer + *position, 4);
|
||||
if (max < length) {
|
||||
goto invalid;
|
||||
@ -1779,7 +1811,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
if (!pattern) {
|
||||
return NULL;
|
||||
}
|
||||
*position += (int)pattern_length + 1;
|
||||
*position += (unsigned)pattern_length + 1;
|
||||
if ((flags_length = strlen(buffer + *position)) > BSON_MAX_SIZE) {
|
||||
Py_DECREF(pattern);
|
||||
goto invalid;
|
||||
@ -1804,7 +1836,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
flags |= 64;
|
||||
}
|
||||
}
|
||||
*position += (int)flags_length + 1;
|
||||
*position += (unsigned)flags_length + 1;
|
||||
if ((compile_func = _get_object(state->RECompile, "re", "compile"))) {
|
||||
value = PyObject_CallFunction(compile_func, "Oi", pattern, flags);
|
||||
Py_DECREF(compile_func);
|
||||
@ -1814,23 +1846,32 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
}
|
||||
case 12:
|
||||
{
|
||||
size_t coll_length;
|
||||
unsigned coll_length;
|
||||
PyObject* collection;
|
||||
PyObject* id = NULL;
|
||||
PyObject* objectid_type;
|
||||
PyObject* dbref_type;
|
||||
|
||||
*position += 4;
|
||||
coll_length = strlen(buffer + *position);
|
||||
if (coll_length > BSON_MAX_SIZE || max < (int)coll_length + 12) {
|
||||
if (max < 4) {
|
||||
goto invalid;
|
||||
}
|
||||
memcpy(&coll_length, buffer + *position, 4);
|
||||
/* Encoded string length + string + 12 byte ObjectId */
|
||||
if (max < 4 + coll_length + 12) {
|
||||
goto invalid;
|
||||
}
|
||||
*position += 4;
|
||||
/* Strings must end in \0 */
|
||||
if (buffer[*position + coll_length - 1]) {
|
||||
goto invalid;
|
||||
}
|
||||
|
||||
collection = PyUnicode_DecodeUTF8(buffer + *position,
|
||||
coll_length, "strict");
|
||||
coll_length - 1, "strict");
|
||||
if (!collection) {
|
||||
return NULL;
|
||||
}
|
||||
*position += (int)coll_length + 1;
|
||||
*position += coll_length;
|
||||
|
||||
if ((objectid_type = _get_object(state->ObjectId, "bson.objectid", "ObjectId"))) {
|
||||
id = PyObject_CallFunction(objectid_type, "s#", buffer + *position, 12);
|
||||
@ -1853,16 +1894,25 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
{
|
||||
PyObject* code;
|
||||
PyObject* code_type;
|
||||
int value_length = ((int*)(buffer + *position))[0] - 1;
|
||||
if (max < value_length) {
|
||||
unsigned value_length;
|
||||
if (max < 4) {
|
||||
goto invalid;
|
||||
}
|
||||
memcpy(&value_length, buffer + *position, 4);
|
||||
/* Encoded string length + string */
|
||||
if (max < 4 + value_length) {
|
||||
goto invalid;
|
||||
}
|
||||
*position += 4;
|
||||
code = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict");
|
||||
/* Strings must end in \0 */
|
||||
if (buffer[*position + value_length - 1]) {
|
||||
goto invalid;
|
||||
}
|
||||
code = PyUnicode_DecodeUTF8(buffer + *position, value_length - 1, "strict");
|
||||
if (!code) {
|
||||
return NULL;
|
||||
}
|
||||
*position += value_length + 1;
|
||||
*position += value_length;
|
||||
if ((code_type = _get_object(state->Code, "bson.code", "Code"))) {
|
||||
value = PyObject_CallFunctionObjArgs(code_type, code, NULL, NULL);
|
||||
Py_DECREF(code_type);
|
||||
@ -1872,25 +1922,56 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
}
|
||||
case 15:
|
||||
{
|
||||
size_t code_length;
|
||||
int scope_size;
|
||||
unsigned c_w_s_size;
|
||||
unsigned code_size;
|
||||
unsigned scope_size;
|
||||
PyObject* code;
|
||||
PyObject* scope;
|
||||
PyObject* code_type;
|
||||
|
||||
*position += 8;
|
||||
code_length = strlen(buffer + *position);
|
||||
if (code_length > BSON_MAX_SIZE || max < 8 + (int)code_length) {
|
||||
if (max < 8) {
|
||||
goto invalid;
|
||||
}
|
||||
code = PyUnicode_DecodeUTF8(buffer + *position, code_length, "strict");
|
||||
|
||||
memcpy(&c_w_s_size, buffer + *position, 4);
|
||||
*position += 4;
|
||||
|
||||
if (max < c_w_s_size) {
|
||||
goto invalid;
|
||||
}
|
||||
|
||||
memcpy(&code_size, buffer + *position, 4);
|
||||
/* code_w_scope length + code length + code + scope length */
|
||||
if (max < 4 + 4 + code_size + 4) {
|
||||
goto invalid;
|
||||
}
|
||||
*position += 4;
|
||||
/* Strings must end in \0 */
|
||||
if (buffer[*position + code_size - 1]) {
|
||||
goto invalid;
|
||||
}
|
||||
code = PyUnicode_DecodeUTF8(buffer + *position, code_size - 1, "strict");
|
||||
if (!code) {
|
||||
return NULL;
|
||||
}
|
||||
*position += (int)code_length + 1;
|
||||
*position += code_size;
|
||||
|
||||
memcpy(&scope_size, buffer + *position, 4);
|
||||
scope = elements_to_dict(self, buffer + *position + 4, scope_size - 5,
|
||||
if (scope_size < BSON_MIN_SIZE) {
|
||||
Py_DECREF(code);
|
||||
goto invalid;
|
||||
}
|
||||
/* code length + code + scope length + scope */
|
||||
if ((4 + code_size + 4 + scope_size) != c_w_s_size) {
|
||||
Py_DECREF(code);
|
||||
goto invalid;
|
||||
}
|
||||
|
||||
/* Check for bad eoo */
|
||||
if (buffer[*position + scope_size - 1]) {
|
||||
goto invalid;
|
||||
}
|
||||
scope = elements_to_dict(self, buffer + *position + 4, (int)scope_size - 5,
|
||||
(PyObject*)&PyDict_Type, tz_aware, uuid_subtype);
|
||||
if (!scope) {
|
||||
Py_DECREF(code);
|
||||
@ -1989,16 +2070,17 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
|
||||
error = _error("InvalidBSON");
|
||||
if (error) {
|
||||
PyErr_SetNone(error);
|
||||
PyErr_SetString(error,
|
||||
"invalid length or type code");
|
||||
Py_DECREF(error);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static PyObject* elements_to_dict(PyObject* self, const char* string, int max,
|
||||
static PyObject* _elements_to_dict(PyObject* self, const char* string, int max,
|
||||
PyObject* as_class, unsigned char tz_aware,
|
||||
unsigned char uuid_subtype) {
|
||||
int position = 0;
|
||||
unsigned position = 0;
|
||||
PyObject* dict = PyObject_CallObject(as_class, NULL);
|
||||
if (!dict) {
|
||||
return NULL;
|
||||
@ -2038,6 +2120,18 @@ static PyObject* elements_to_dict(PyObject* self, const char* string, int max,
|
||||
return dict;
|
||||
}
|
||||
|
||||
/* Recursion-guarded entry point for decoding BSON elements.
 *
 * Enters a Python recursive-call scope, forwards every argument to
 * _elements_to_dict(), and leaves the scope again before returning.
 * Returns whatever _elements_to_dict() returns, or NULL when
 * Py_EnterRecursiveCall() reports failure (the helper is then not called).
 */
static PyObject* elements_to_dict(PyObject* self, const char* string, int max,
                                  PyObject* as_class, unsigned char tz_aware,
                                  unsigned char uuid_subtype) {
    PyObject* decoded = NULL;

    if (!Py_EnterRecursiveCall(" while decoding a BSON document")) {
        decoded = _elements_to_dict(self, string, max,
                                    as_class, tz_aware, uuid_subtype);
        /* Only balance the guard when it was actually entered. */
        Py_LeaveRecursiveCall();
    }
    return decoded;
}
|
||||
|
||||
static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) {
|
||||
int size;
|
||||
Py_ssize_t total_size;
|
||||
@ -2068,7 +2162,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) {
|
||||
#else
|
||||
total_size = PyString_Size(bson);
|
||||
#endif
|
||||
if (total_size < 5) {
|
||||
if (total_size < BSON_MIN_SIZE) {
|
||||
PyObject* InvalidBSON = _error("InvalidBSON");
|
||||
if (InvalidBSON) {
|
||||
PyErr_SetString(InvalidBSON,
|
||||
@ -2088,7 +2182,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) {
|
||||
}
|
||||
|
||||
memcpy(&size, string, 4);
|
||||
if (size < 0) {
|
||||
if (size < BSON_MIN_SIZE) {
|
||||
PyObject* InvalidBSON = _error("InvalidBSON");
|
||||
if (InvalidBSON) {
|
||||
PyErr_SetString(InvalidBSON, "invalid message size");
|
||||
@ -2097,7 +2191,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (total_size < size) {
|
||||
if (total_size < size || total_size > BSON_MAX_SIZE) {
|
||||
PyObject* InvalidBSON = _error("InvalidBSON");
|
||||
if (InvalidBSON) {
|
||||
PyErr_SetString(InvalidBSON, "objsize too large");
|
||||
@ -2173,7 +2267,7 @@ static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) {
|
||||
return NULL;
|
||||
|
||||
while (total_size > 0) {
|
||||
if (total_size < 5) {
|
||||
if (total_size < BSON_MIN_SIZE) {
|
||||
PyObject* InvalidBSON = _error("InvalidBSON");
|
||||
if (InvalidBSON) {
|
||||
PyErr_SetString(InvalidBSON,
|
||||
@ -2185,7 +2279,7 @@ static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) {
|
||||
}
|
||||
|
||||
memcpy(&size, string, 4);
|
||||
if (size < 0) {
|
||||
if (size < BSON_MIN_SIZE) {
|
||||
PyObject* InvalidBSON = _error("InvalidBSON");
|
||||
if (InvalidBSON) {
|
||||
PyErr_SetString(InvalidBSON, "invalid message size");
|
||||
|
||||
@ -63,6 +63,8 @@ class TestBSON(unittest.TestCase):
|
||||
# the simplest valid BSON document
|
||||
self.assertTrue(is_valid(b("\x05\x00\x00\x00\x00")))
|
||||
self.assertTrue(is_valid(BSON(b("\x05\x00\x00\x00\x00"))))
|
||||
|
||||
# failure cases
|
||||
self.assertFalse(is_valid(b("\x04\x00\x00\x00\x00")))
|
||||
self.assertFalse(is_valid(b("\x05\x00\x00\x00\x01")))
|
||||
self.assertFalse(is_valid(b("\x05\x00\x00\x00")))
|
||||
@ -70,6 +72,17 @@ class TestBSON(unittest.TestCase):
|
||||
self.assertFalse(is_valid(b("\x07\x00\x00\x00\x02a\x00\x78\x56\x34\x12")))
|
||||
self.assertFalse(is_valid(b("\x09\x00\x00\x00\x10a\x00\x05\x00")))
|
||||
self.assertFalse(is_valid(b("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")))
|
||||
self.assertFalse(is_valid(b("\x13\x00\x00\x00\x02foo\x00"
|
||||
"\x04\x00\x00\x00bar\x00\x00")))
|
||||
self.assertFalse(is_valid(b("\x18\x00\x00\x00\x03foo\x00\x0f\x00\x00"
|
||||
"\x00\x10bar\x00\xff\xff\xff\x7f\x00\x00")))
|
||||
self.assertFalse(is_valid(b("\x15\x00\x00\x00\x03foo\x00\x0c"
|
||||
"\x00\x00\x00\x08bar\x00\x01\x00\x00")))
|
||||
self.assertFalse(is_valid(b("\x1c\x00\x00\x00\x03foo\x00"
|
||||
"\x12\x00\x00\x00\x02bar\x00"
|
||||
"\x05\x00\x00\x00baz\x00\x00\x00")))
|
||||
self.assertFalse(is_valid(b("\x10\x00\x00\x00\x02a\x00"
|
||||
"\x04\x00\x00\x00abc\xff\x00")))
|
||||
|
||||
def test_random_data_is_not_bson(self):
|
||||
qcheck.check_unittest(self, qcheck.isnt(is_valid),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user