More invalid BSON handling PYTHON-252

This commit is contained in:
behackett 2011-06-14 11:53:06 -07:00
parent 7444dc765b
commit 1e100405be
2 changed files with 64 additions and 8 deletions

View File

@ -857,12 +857,16 @@ static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* args) {
}
static PyObject* get_value(const char* buffer, int* position, int type,
PyObject* as_class, unsigned char tz_aware) {
int max, PyObject* as_class, unsigned char tz_aware) {
PyObject* value;
PyObject* error;
switch (type) {
case 1:
{
double d;
if (max < 8) {
goto invalid;
}
memcpy(&d, buffer + *position, 8);
value = PyFloat_FromDouble(d);
if (!value) {
@ -876,6 +880,9 @@ static PyObject* get_value(const char* buffer, int* position, int type,
case 14:
{
int value_length = ((int*)(buffer + *position))[0] - 1;
if (max < value_length) {
goto invalid;
}
*position += 4;
value = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict");
if (!value) {
@ -888,6 +895,9 @@ static PyObject* get_value(const char* buffer, int* position, int type,
{
int size;
memcpy(&size, buffer + *position, 4);
if (max < size) {
goto invalid;
}
value = elements_to_dict(buffer + *position + 4, size - 5, as_class, tz_aware);
if (!value) {
return NULL;
@ -934,6 +944,9 @@ static PyObject* get_value(const char* buffer, int* position, int type,
end;
memcpy(&size, buffer + *position, 4);
if (max < size) {
goto invalid;
}
end = *position + size - 1;
*position += 4;
@ -947,7 +960,7 @@ static PyObject* get_value(const char* buffer, int* position, int type,
int type = (int)buffer[(*position)++];
int key_size = strlen(buffer + *position);
*position += key_size + 1; /* just skip the key, they're in order. */
to_append = get_value(buffer, position, type, as_class, tz_aware);
to_append = get_value(buffer, position, type, max - key_size, as_class, tz_aware);
if (!to_append) {
return NULL;
}
@ -965,6 +978,9 @@ static PyObject* get_value(const char* buffer, int* position, int type,
subtype;
memcpy(&length, buffer + *position, 4);
if (max < length) {
goto invalid;
}
subtype = (unsigned char)buffer[*position + 4];
if (subtype == 2) {
@ -1027,6 +1043,9 @@ static PyObject* get_value(const char* buffer, int* position, int type,
}
case 7:
{
if (max < 12) {
goto invalid;
}
value = PyObject_CallFunction(ObjectId, "s#", buffer + *position, 12);
if (!value) {
return NULL;
@ -1042,10 +1061,14 @@ static PyObject* get_value(const char* buffer, int* position, int type,
}
case 9:
{
PyObject* naive;
PyObject* replace;
PyObject* args;
PyObject* kwargs;
PyObject* naive = datetime_from_millis(*(long long*)(buffer + *position));
if (max < 8) {
goto invalid;
}
naive = datetime_from_millis(*(long long*)(buffer + *position));
*position += 8;
if (!tz_aware) { /* In the naive case, we're done here. */
value = naive;
@ -1085,17 +1108,23 @@ static PyObject* get_value(const char* buffer, int* position, int type,
}
case 11:
{
PyObject* pattern;
int flags_length,
flags,
i;
int pattern_length = strlen(buffer + *position);
PyObject* pattern = PyUnicode_DecodeUTF8(buffer + *position, pattern_length, "strict");
if (max < pattern_length) {
goto invalid;
}
pattern = PyUnicode_DecodeUTF8(buffer + *position, pattern_length, "strict");
if (!pattern) {
return NULL;
}
*position += pattern_length + 1;
flags_length = strlen(buffer + *position);
if (max < pattern_length + flags_length) {
goto invalid;
}
flags = 0;
for (i = 0; i < flags_length; i++) {
if (buffer[*position + i] == 'i') {
@ -1125,11 +1154,17 @@ static PyObject* get_value(const char* buffer, int* position, int type,
*position += 4;
collection_length = strlen(buffer + *position);
if (max < collection_length) {
goto invalid;
}
collection = PyUnicode_DecodeUTF8(buffer + *position, collection_length, "strict");
if (!collection) {
return NULL;
}
*position += collection_length + 1;
if (max < collection_length + 12) {
goto invalid;
}
id = PyObject_CallFunction(ObjectId, "s#", buffer + *position, 12);
if (!id) {
Py_DECREF(collection);
@ -1150,6 +1185,9 @@ static PyObject* get_value(const char* buffer, int* position, int type,
*position += 8;
code_length = strlen(buffer + *position);
if (max < 8 + code_length) {
goto invalid;
}
code = PyUnicode_DecodeUTF8(buffer + *position, code_length, "strict");
if (!code) {
return NULL;
@ -1173,6 +1211,9 @@ static PyObject* get_value(const char* buffer, int* position, int type,
case 16:
{
int i;
if (max < 4) {
goto invalid;
}
memcpy(&i, buffer + *position, 4);
value = PyInt_FromLong(i);
if (!value) {
@ -1184,6 +1225,9 @@ static PyObject* get_value(const char* buffer, int* position, int type,
case 17:
{
unsigned int time, inc;
if (max < 8) {
goto invalid;
}
memcpy(&inc, buffer + *position, 4);
memcpy(&time, buffer + *position + 4, 4);
value = PyObject_CallFunction(Timestamp, "II", time, inc);
@ -1196,6 +1240,9 @@ static PyObject* get_value(const char* buffer, int* position, int type,
case 18:
{
long long ll;
if (max < 8) {
goto invalid;
}
memcpy(&ll, buffer + *position, 8);
value = PyLong_FromLongLong(ll);
if (!value) {
@ -1223,6 +1270,13 @@ static PyObject* get_value(const char* buffer, int* position, int type,
}
}
return value;
invalid:
error = _error("InvalidBSON");
PyErr_SetNone(error);
Py_DECREF(error);
return NULL;
}
static PyObject* elements_to_dict(const char* string, int max,
@ -1233,6 +1287,8 @@ static PyObject* elements_to_dict(const char* string, int max,
return NULL;
}
while (position < max) {
PyObject* name;
PyObject* value;
int type = (int)string[position++];
int name_length = strlen(string + position);
if (position + name_length >= max) {
@ -1241,13 +1297,12 @@ static PyObject* elements_to_dict(const char* string, int max,
Py_DECREF(InvalidBSON);
return NULL;
}
PyObject* name = PyUnicode_DecodeUTF8(string + position, name_length, "strict");
PyObject* value;
name = PyUnicode_DecodeUTF8(string + position, name_length, "strict");
if (!name) {
return NULL;
}
position += name_length + 1;
value = get_value(string, &position, type, as_class, tz_aware);
value = get_value(string, &position, type, max - position, as_class, tz_aware);
if (!value) {
return NULL;
}

View File

@ -69,6 +69,7 @@ class TestBSON(unittest.TestCase):
self.assertFalse(is_valid("\x05\x00\x00\x00"))
self.assertFalse(is_valid("\x05\x00\x00\x00\x00\x00"))
self.assertFalse(is_valid("\x07\x00\x00\x00\x02a\x00\x78\x56\x34\x12"))
self.assertFalse(is_valid("\x09\x00\x00\x00\x10a\x00\x05\x00"))
def test_random_data_is_not_bson(self):
qcheck.check_unittest(self, qcheck.isnt(is_valid),