Avoid unsigned overflow wrapping PYTHON-571

This commit is contained in:
Bernie Hackett 2013-10-15 17:25:16 -07:00
parent 2b1ea1fdeb
commit 989a0f4f4d
3 changed files with 81 additions and 21 deletions

View File

@ -140,7 +140,7 @@ def _get_number(data, position, as_class, tz_aware, uuid_subtype):
def _get_string(data, position, as_class, tz_aware, uuid_subtype):
length = struct.unpack("<i", data[position:position + 4])[0]
if (len(data) - position - 4) < length:
if length <= 0 or (len(data) - position - 4) < length:
raise InvalidBSON("invalid string length")
position += 4
if data[position + length - 1:position + length] != ZERO:

View File

@ -1464,7 +1464,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
}
memcpy(&value_length, buffer + *position, 4);
/* Encoded string length + string */
if (max < 4 + value_length) {
if (!value_length || max < value_length || max < 4 + value_length) {
goto invalid;
}
*position += 4;
@ -1555,7 +1555,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
goto invalid;
}
memcpy(&size, buffer + *position, 4);
if (max < size) {
if (size < BSON_MIN_SIZE || max < size) {
goto invalid;
}
end = *position + size - 1;
@ -1607,7 +1607,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
unsigned length;
unsigned char subtype;
if (max < 4) {
if (max < 5) {
goto invalid;
}
memcpy(&length, buffer + *position, 4);
@ -1616,32 +1616,38 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
}
subtype = (unsigned char)buffer[*position + 4];
*position += 5;
if (subtype == 2 && length < 4) {
goto invalid;
}
#if PY_MAJOR_VERSION >= 3
/* Python3 special case. Decode BSON binary subtype 0 to bytes. */
if (subtype == 0) {
value = PyBytes_FromStringAndSize(buffer + *position + 5, length);
*position += length + 5;
value = PyBytes_FromStringAndSize(buffer + *position, length);
*position += length;
break;
}
if (subtype == 2) {
data = PyBytes_FromStringAndSize(buffer + *position + 9, length - 4);
data = PyBytes_FromStringAndSize(buffer + *position + 4, length - 4);
} else {
data = PyBytes_FromStringAndSize(buffer + *position + 5, length);
data = PyBytes_FromStringAndSize(buffer + *position, length);
}
#else
if (subtype == 2) {
data = PyString_FromStringAndSize(buffer + *position + 9, length - 4);
data = PyString_FromStringAndSize(buffer + *position + 4, length - 4);
} else {
data = PyString_FromStringAndSize(buffer + *position + 5, length);
data = PyString_FromStringAndSize(buffer + *position, length);
}
#endif
if (!data) {
return NULL;
}
if ((subtype == 3 || subtype == 4) && state->UUID) { // Encode as UUID, not Binary
/* Encode as UUID, not Binary */
if ((subtype == 3 || subtype == 4) && state->UUID) {
PyObject* kwargs;
PyObject* args = PyTuple_New(0);
if (!args) {
/* UUID should always be 16 bytes */
if (!args || length != 16) {
Py_DECREF(data);
return NULL;
}
@ -1652,8 +1658,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
return NULL;
}
assert(length == 16); // UUID should always be 16 bytes
if (uuid_subtype == CSHARP_LEGACY) {
/* Legacy C# byte order */
if ((PyDict_SetItemString(kwargs, "bytes_le", data)) == -1)
@ -1663,7 +1667,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
if (uuid_subtype == JAVA_LEGACY) {
/* Convert from legacy java byte order */
char big_endian[16];
_fix_java(buffer + *position + 5, big_endian);
_fix_java(buffer + *position, big_endian);
/* Free the previously created PyString object */
Py_DECREF(data);
#if PY_MAJOR_VERSION >= 3
@ -1690,7 +1694,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
return NULL;
}
*position += length + 5;
*position += length;
break;
uuiderror:
@ -1718,7 +1722,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
if (!value) {
return NULL;
}
*position += length + 5;
*position += length;
break;
}
case 6:
@ -1866,7 +1870,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
}
memcpy(&coll_length, buffer + *position, 4);
/* Encoded string length + string + 12 byte ObjectId */
if (max < 4 + coll_length + 12) {
if (!coll_length || max < coll_length || max < 4 + coll_length + 12) {
goto invalid;
}
*position += 4;
@ -1913,7 +1917,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
}
memcpy(&value_length, buffer + *position, 4);
/* Encoded string length + string */
if (max < 4 + value_length) {
if (!value_length || max < value_length || max < 4 + value_length) {
goto invalid;
}
*position += 4;
@ -1955,7 +1959,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
memcpy(&code_size, buffer + *position, 4);
/* code_w_scope length + code length + code + scope length */
if (max < 4 + 4 + code_size + 4) {
if (!code_size || max < code_size || max < 4 + 4 + code_size + 4) {
goto invalid;
}
*position += 4;

View File

@ -40,7 +40,8 @@ from bson.dbref import DBRef
from bson.py3compat import b
from bson.son import SON
from bson.timestamp import Timestamp
from bson.errors import (InvalidDocument,
from bson.errors import (InvalidBSON,
InvalidDocument,
InvalidStringData)
from bson.max_key import MaxKey
from bson.min_key import MinKey
@ -84,6 +85,61 @@ class TestBSON(unittest.TestCase):
self.assertFalse(is_valid(b("\x10\x00\x00\x00\x02a\x00"
"\x04\x00\x00\x00abc\xff\x00")))
def test_bad_string_lengths(self):
def decode(bs):
bson.BSON(bs).decode()
self.assertRaises(InvalidBSON, decode,
b("\x0c\x00\x00\x00\x02\x00"
"\x00\x00\x00\x00\x00\x00"))
self.assertRaises(InvalidBSON, decode,
b("\x12\x00\x00\x00\x02\x00"
"\xff\xff\xff\xfffoobar\x00\x00"))
self.assertRaises(InvalidBSON, decode,
b("\x0c\x00\x00\x00\x0e\x00"
"\x00\x00\x00\x00\x00\x00"))
self.assertRaises(InvalidBSON, decode,
b("\x12\x00\x00\x00\x0e\x00"
"\xff\xff\xff\xfffoobar\x00\x00"))
self.assertRaises(InvalidBSON, decode,
b("\x18\x00\x00\x00\x0c\x00"
"\x00\x00\x00\x00\x00RY\xb5j"
"\xfa[\xd8A\xd6X]\x99\x00"))
self.assertRaises(InvalidBSON, decode,
b("\x1e\x00\x00\x00\x0c\x00"
"\xff\xff\xff\xfffoobar\x00"
"RY\xb5j\xfa[\xd8A\xd6X]\x99\x00"))
self.assertRaises(InvalidBSON, decode,
b("\x0c\x00\x00\x00\r\x00"
"\x00\x00\x00\x00\x00\x00"))
self.assertRaises(InvalidBSON, decode,
b("\x0c\x00\x00\x00\r\x00"
"\xff\xff\xff\xff\x00\x00"))
self.assertRaises(InvalidBSON, decode,
b("\x1c\x00\x00\x00\x0f\x00"
"\x15\x00\x00\x00\x00\x00"
"\x00\x00\x00\x0c\x00\x00"
"\x00\x02\x00\x01\x00\x00"
"\x00\x00\x00\x00"))
self.assertRaises(InvalidBSON, decode,
b("\x1c\x00\x00\x00\x0f\x00"
"\x15\x00\x00\x00\xff\xff"
"\xff\xff\x00\x0c\x00\x00"
"\x00\x02\x00\x01\x00\x00"
"\x00\x00\x00\x00"))
self.assertRaises(InvalidBSON, decode,
b("\x1c\x00\x00\x00\x0f\x00"
"\x15\x00\x00\x00\x01\x00"
"\x00\x00\x00\x0c\x00\x00"
"\x00\x02\x00\x00\x00\x00"
"\x00\x00\x00\x00"))
self.assertRaises(InvalidBSON, decode,
b("\x1c\x00\x00\x00\x0f\x00"
"\x15\x00\x00\x00\x01\x00"
"\x00\x00\x00\x0c\x00\x00"
"\x00\x02\x00\xff\xff\xff"
"\xff\x00\x00\x00"))
def test_random_data_is_not_bson(self):
qcheck.check_unittest(self, qcheck.isnt(is_valid),
qcheck.gen_string(qcheck.gen_range(0, 40)))