Better handling of corrupt/invalid BSON PYTHON-571

This commit is contained in:
behackett 2013-10-08 13:41:32 -07:00
parent 12283dcffa
commit ab4d63b658
3 changed files with 156 additions and 43 deletions

View File

@ -139,13 +139,19 @@ def _get_number(data, position, as_class, tz_aware, uuid_subtype):
def _get_string(data, position, as_class, tz_aware, uuid_subtype):
length = struct.unpack("<i", data[position:position + 4])[0] - 1
length = struct.unpack("<i", data[position:position + 4])[0]
if (len(data) - position - 4) < length:
raise InvalidBSON("invalid string length")
position += 4
return _get_c_string(data, position, length)
if data[position + length - 1] != ZERO:
raise InvalidBSON("invalid end of string")
return _get_c_string(data, position, length - 1)
def _get_object(data, position, as_class, tz_aware, uuid_subtype):
obj_size = struct.unpack("<i", data[position:position + 4])[0]
if data[position + obj_size - 1:position + obj_size] != ZERO:
raise InvalidBSON("bad eoo")
encoded = data[position + 4:position + obj_size - 1]
object = _elements_to_dict(encoded, as_class, tz_aware, uuid_subtype)
position += obj_size

View File

@ -97,6 +97,8 @@ static struct module_state _state;
#define JAVA_LEGACY 5
#define CSHARP_LEGACY 6
#define BSON_MAX_SIZE 2147483647
/* The smallest possible BSON document, i.e. "{}" */
#define BSON_MIN_SIZE 5
/* Get an error class from the bson.errors module.
*
@ -1430,7 +1432,7 @@ static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* args) {
return result;
}
static PyObject* get_value(PyObject* self, const char* buffer, int* position,
static PyObject* get_value(PyObject* self, const char* buffer, unsigned* position,
int type, int max, PyObject* as_class,
unsigned char tz_aware, unsigned char uuid_subtype) {
struct module_state *state = GETSTATE(self);
@ -1455,28 +1457,44 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
case 2:
case 14:
{
int value_length = ((int*)(buffer + *position))[0] - 1;
if (max < value_length) {
unsigned value_length;
if (max < 4) {
goto invalid;
}
memcpy(&value_length, buffer + *position, 4);
/* Encoded string length + string */
if (max < 4 + value_length) {
goto invalid;
}
*position += 4;
value = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict");
/* Strings must end in \0 */
if (buffer[*position + value_length - 1]) {
goto invalid;
}
value = PyUnicode_DecodeUTF8(buffer + *position, value_length - 1, "strict");
if (!value) {
return NULL;
}
*position += value_length + 1;
*position += value_length;
break;
}
case 3:
{
PyObject* collection;
int size;
unsigned size;
if (max < 4) {
goto invalid;
}
memcpy(&size, buffer + *position, 4);
if (size < 0 || max < size) {
if (size < BSON_MIN_SIZE || max < size) {
goto invalid;
}
/* Check for bad eoo */
if (buffer[*position + size - 1]) {
goto invalid;
}
value = elements_to_dict(self, buffer + *position + 4,
size - 5, as_class, tz_aware, uuid_subtype);
(int)size - 5, as_class, tz_aware, uuid_subtype);
if (!value) {
return NULL;
}
@ -1530,14 +1548,20 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
}
case 4:
{
int size,
end;
unsigned size, end;
if (max < 4) {
goto invalid;
}
memcpy(&size, buffer + *position, 4);
if (max < size) {
goto invalid;
}
end = *position + size - 1;
/* Check for bad eoo */
if (buffer[end]) {
goto invalid;
}
*position += 4;
value = PyList_New(0);
@ -1549,14 +1573,19 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
int bson_type = (int)buffer[(*position)++];
size_t key_size = strlen(buffer + *position);
if (key_size > BSON_MAX_SIZE) {
if (max < (int)key_size) {
Py_DECREF(value);
goto invalid;
}
/* just skip the key, they're in order. */
*position += (int)key_size + 1;
*position += (unsigned)key_size + 1;
if (Py_EnterRecursiveCall(" while decoding a list value")) {
Py_DECREF(value);
return NULL;
}
to_append = get_value(self, buffer, position, bson_type,
max - (int)key_size, as_class, tz_aware, uuid_subtype);
Py_LeaveRecursiveCall();
if (!to_append) {
Py_DECREF(value);
return NULL;
@ -1572,8 +1601,11 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
PyObject* data;
PyObject* st;
PyObject* type_to_create;
int length, subtype;
unsigned length, subtype;
if (max < 4) {
goto invalid;
}
memcpy(&length, buffer + *position, 4);
if (max < length) {
goto invalid;
@ -1779,7 +1811,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
if (!pattern) {
return NULL;
}
*position += (int)pattern_length + 1;
*position += (unsigned)pattern_length + 1;
if ((flags_length = strlen(buffer + *position)) > BSON_MAX_SIZE) {
Py_DECREF(pattern);
goto invalid;
@ -1804,7 +1836,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
flags |= 64;
}
}
*position += (int)flags_length + 1;
*position += (unsigned)flags_length + 1;
if ((compile_func = _get_object(state->RECompile, "re", "compile"))) {
value = PyObject_CallFunction(compile_func, "Oi", pattern, flags);
Py_DECREF(compile_func);
@ -1814,23 +1846,32 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
}
case 12:
{
size_t coll_length;
unsigned coll_length;
PyObject* collection;
PyObject* id = NULL;
PyObject* objectid_type;
PyObject* dbref_type;
*position += 4;
coll_length = strlen(buffer + *position);
if (coll_length > BSON_MAX_SIZE || max < (int)coll_length + 12) {
if (max < 4) {
goto invalid;
}
memcpy(&coll_length, buffer + *position, 4);
/* Encoded string length + string + 12 byte ObjectId */
if (max < 4 + coll_length + 12) {
goto invalid;
}
*position += 4;
/* Strings must end in \0 */
if (buffer[*position + coll_length - 1]) {
goto invalid;
}
collection = PyUnicode_DecodeUTF8(buffer + *position,
coll_length, "strict");
coll_length - 1, "strict");
if (!collection) {
return NULL;
}
*position += (int)coll_length + 1;
*position += coll_length;
if ((objectid_type = _get_object(state->ObjectId, "bson.objectid", "ObjectId"))) {
id = PyObject_CallFunction(objectid_type, "s#", buffer + *position, 12);
@ -1853,16 +1894,25 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
{
PyObject* code;
PyObject* code_type;
int value_length = ((int*)(buffer + *position))[0] - 1;
if (max < value_length) {
unsigned value_length;
if (max < 4) {
goto invalid;
}
memcpy(&value_length, buffer + *position, 4);
/* Encoded string length + string */
if (max < 4 + value_length) {
goto invalid;
}
*position += 4;
code = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict");
/* Strings must end in \0 */
if (buffer[*position + value_length - 1]) {
goto invalid;
}
code = PyUnicode_DecodeUTF8(buffer + *position, value_length - 1, "strict");
if (!code) {
return NULL;
}
*position += value_length + 1;
*position += value_length;
if ((code_type = _get_object(state->Code, "bson.code", "Code"))) {
value = PyObject_CallFunctionObjArgs(code_type, code, NULL, NULL);
Py_DECREF(code_type);
@ -1872,25 +1922,56 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
}
case 15:
{
size_t code_length;
int scope_size;
unsigned c_w_s_size;
unsigned code_size;
unsigned scope_size;
PyObject* code;
PyObject* scope;
PyObject* code_type;
*position += 8;
code_length = strlen(buffer + *position);
if (code_length > BSON_MAX_SIZE || max < 8 + (int)code_length) {
if (max < 8) {
goto invalid;
}
code = PyUnicode_DecodeUTF8(buffer + *position, code_length, "strict");
memcpy(&c_w_s_size, buffer + *position, 4);
*position += 4;
if (max < c_w_s_size) {
goto invalid;
}
memcpy(&code_size, buffer + *position, 4);
/* code_w_scope length + code length + code + scope length */
if (max < 4 + 4 + code_size + 4) {
goto invalid;
}
*position += 4;
/* Strings must end in \0 */
if (buffer[*position + code_size - 1]) {
goto invalid;
}
code = PyUnicode_DecodeUTF8(buffer + *position, code_size - 1, "strict");
if (!code) {
return NULL;
}
*position += (int)code_length + 1;
*position += code_size;
memcpy(&scope_size, buffer + *position, 4);
scope = elements_to_dict(self, buffer + *position + 4, scope_size - 5,
if (scope_size < BSON_MIN_SIZE) {
Py_DECREF(code);
goto invalid;
}
/* code length + code + scope length + scope */
if ((4 + code_size + 4 + scope_size) != c_w_s_size) {
Py_DECREF(code);
goto invalid;
}
/* Check for bad eoo */
if (buffer[*position + scope_size - 1]) {
goto invalid;
}
scope = elements_to_dict(self, buffer + *position + 4, (int)scope_size - 5,
(PyObject*)&PyDict_Type, tz_aware, uuid_subtype);
if (!scope) {
Py_DECREF(code);
@ -1989,16 +2070,17 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
error = _error("InvalidBSON");
if (error) {
PyErr_SetNone(error);
PyErr_SetString(error,
"invalid length or type code");
Py_DECREF(error);
}
return NULL;
}
static PyObject* elements_to_dict(PyObject* self, const char* string, int max,
static PyObject* _elements_to_dict(PyObject* self, const char* string, int max,
PyObject* as_class, unsigned char tz_aware,
unsigned char uuid_subtype) {
int position = 0;
unsigned position = 0;
PyObject* dict = PyObject_CallObject(as_class, NULL);
if (!dict) {
return NULL;
@ -2038,6 +2120,18 @@ static PyObject* elements_to_dict(PyObject* self, const char* string, int max,
return dict;
}
/* Recursion-guarded entry point for decoding a run of BSON elements.
 *
 * Brackets a call to _elements_to_dict() with Py_EnterRecursiveCall() /
 * Py_LeaveRecursiveCall(). All arguments are forwarded unchanged.
 *
 * Returns NULL without calling _elements_to_dict() when
 * Py_EnterRecursiveCall() reports failure; otherwise returns whatever
 * _elements_to_dict() returned.
 */
static PyObject* elements_to_dict(PyObject* self, const char* string, int max,
                                  PyObject* as_class, unsigned char tz_aware,
                                  unsigned char uuid_subtype) {
    PyObject* decoded = NULL;

    if (!Py_EnterRecursiveCall(" while decoding a BSON document")) {
        decoded = _elements_to_dict(self, string, max,
                                    as_class, tz_aware, uuid_subtype);
        /* Only leave if the matching enter succeeded. */
        Py_LeaveRecursiveCall();
    }
    return decoded;
}
static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) {
int size;
Py_ssize_t total_size;
@ -2068,7 +2162,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) {
#else
total_size = PyString_Size(bson);
#endif
if (total_size < 5) {
if (total_size < BSON_MIN_SIZE) {
PyObject* InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) {
PyErr_SetString(InvalidBSON,
@ -2088,7 +2182,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) {
}
memcpy(&size, string, 4);
if (size < 0) {
if (size < BSON_MIN_SIZE) {
PyObject* InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) {
PyErr_SetString(InvalidBSON, "invalid message size");
@ -2097,7 +2191,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) {
return NULL;
}
if (total_size < size) {
if (total_size < size || total_size > BSON_MAX_SIZE) {
PyObject* InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) {
PyErr_SetString(InvalidBSON, "objsize too large");
@ -2173,7 +2267,7 @@ static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) {
return NULL;
while (total_size > 0) {
if (total_size < 5) {
if (total_size < BSON_MIN_SIZE) {
PyObject* InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) {
PyErr_SetString(InvalidBSON,
@ -2185,7 +2279,7 @@ static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) {
}
memcpy(&size, string, 4);
if (size < 0) {
if (size < BSON_MIN_SIZE) {
PyObject* InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) {
PyErr_SetString(InvalidBSON, "invalid message size");

View File

@ -63,6 +63,8 @@ class TestBSON(unittest.TestCase):
# the simplest valid BSON document
self.assertTrue(is_valid(b("\x05\x00\x00\x00\x00")))
self.assertTrue(is_valid(BSON(b("\x05\x00\x00\x00\x00"))))
# failure cases
self.assertFalse(is_valid(b("\x04\x00\x00\x00\x00")))
self.assertFalse(is_valid(b("\x05\x00\x00\x00\x01")))
self.assertFalse(is_valid(b("\x05\x00\x00\x00")))
@ -70,6 +72,17 @@ class TestBSON(unittest.TestCase):
self.assertFalse(is_valid(b("\x07\x00\x00\x00\x02a\x00\x78\x56\x34\x12")))
self.assertFalse(is_valid(b("\x09\x00\x00\x00\x10a\x00\x05\x00")))
self.assertFalse(is_valid(b("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")))
self.assertFalse(is_valid(b("\x13\x00\x00\x00\x02foo\x00"
"\x04\x00\x00\x00bar\x00\x00")))
self.assertFalse(is_valid(b("\x18\x00\x00\x00\x03foo\x00\x0f\x00\x00"
"\x00\x10bar\x00\xff\xff\xff\x7f\x00\x00")))
self.assertFalse(is_valid(b("\x15\x00\x00\x00\x03foo\x00\x0c"
"\x00\x00\x00\x08bar\x00\x01\x00\x00")))
self.assertFalse(is_valid(b("\x1c\x00\x00\x00\x03foo\x00"
"\x12\x00\x00\x00\x02bar\x00"
"\x05\x00\x00\x00baz\x00\x00\x00")))
self.assertFalse(is_valid(b("\x10\x00\x00\x00\x02a\x00"
"\x04\x00\x00\x00abc\xff\x00")))
def test_random_data_is_not_bson(self):
qcheck.check_unittest(self, qcheck.isnt(is_valid),