Avoid unsigned overflow wrapping PYTHON-571
This commit is contained in:
parent
2b1ea1fdeb
commit
989a0f4f4d
@ -140,7 +140,7 @@ def _get_number(data, position, as_class, tz_aware, uuid_subtype):
|
||||
|
||||
def _get_string(data, position, as_class, tz_aware, uuid_subtype):
|
||||
length = struct.unpack("<i", data[position:position + 4])[0]
|
||||
if (len(data) - position - 4) < length:
|
||||
if length <= 0 or (len(data) - position - 4) < length:
|
||||
raise InvalidBSON("invalid string length")
|
||||
position += 4
|
||||
if data[position + length - 1:position + length] != ZERO:
|
||||
|
||||
@ -1464,7 +1464,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
|
||||
}
|
||||
memcpy(&value_length, buffer + *position, 4);
|
||||
/* Encoded string length + string */
|
||||
if (max < 4 + value_length) {
|
||||
if (!value_length || max < value_length || max < 4 + value_length) {
|
||||
goto invalid;
|
||||
}
|
||||
*position += 4;
|
||||
@ -1555,7 +1555,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
|
||||
goto invalid;
|
||||
}
|
||||
memcpy(&size, buffer + *position, 4);
|
||||
if (max < size) {
|
||||
if (size < BSON_MIN_SIZE || max < size) {
|
||||
goto invalid;
|
||||
}
|
||||
end = *position + size - 1;
|
||||
@ -1607,7 +1607,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
|
||||
unsigned length;
|
||||
unsigned char subtype;
|
||||
|
||||
if (max < 4) {
|
||||
if (max < 5) {
|
||||
goto invalid;
|
||||
}
|
||||
memcpy(&length, buffer + *position, 4);
|
||||
@ -1616,32 +1616,38 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
|
||||
}
|
||||
|
||||
subtype = (unsigned char)buffer[*position + 4];
|
||||
*position += 5;
|
||||
if (subtype == 2 && length < 4) {
|
||||
goto invalid;
|
||||
}
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
/* Python3 special case. Decode BSON binary subtype 0 to bytes. */
|
||||
if (subtype == 0) {
|
||||
value = PyBytes_FromStringAndSize(buffer + *position + 5, length);
|
||||
*position += length + 5;
|
||||
value = PyBytes_FromStringAndSize(buffer + *position, length);
|
||||
*position += length;
|
||||
break;
|
||||
}
|
||||
if (subtype == 2) {
|
||||
data = PyBytes_FromStringAndSize(buffer + *position + 9, length - 4);
|
||||
data = PyBytes_FromStringAndSize(buffer + *position + 4, length - 4);
|
||||
} else {
|
||||
data = PyBytes_FromStringAndSize(buffer + *position + 5, length);
|
||||
data = PyBytes_FromStringAndSize(buffer + *position, length);
|
||||
}
|
||||
#else
|
||||
if (subtype == 2) {
|
||||
data = PyString_FromStringAndSize(buffer + *position + 9, length - 4);
|
||||
data = PyString_FromStringAndSize(buffer + *position + 4, length - 4);
|
||||
} else {
|
||||
data = PyString_FromStringAndSize(buffer + *position + 5, length);
|
||||
data = PyString_FromStringAndSize(buffer + *position, length);
|
||||
}
|
||||
#endif
|
||||
if (!data) {
|
||||
return NULL;
|
||||
}
|
||||
if ((subtype == 3 || subtype == 4) && state->UUID) { // Encode as UUID, not Binary
|
||||
/* Encode as UUID, not Binary */
|
||||
if ((subtype == 3 || subtype == 4) && state->UUID) {
|
||||
PyObject* kwargs;
|
||||
PyObject* args = PyTuple_New(0);
|
||||
if (!args) {
|
||||
/* UUID should always be 16 bytes */
|
||||
if (!args || length != 16) {
|
||||
Py_DECREF(data);
|
||||
return NULL;
|
||||
}
|
||||
@ -1652,8 +1658,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
|
||||
return NULL;
|
||||
}
|
||||
|
||||
assert(length == 16); // UUID should always be 16 bytes
|
||||
|
||||
if (uuid_subtype == CSHARP_LEGACY) {
|
||||
/* Legacy C# byte order */
|
||||
if ((PyDict_SetItemString(kwargs, "bytes_le", data)) == -1)
|
||||
@ -1663,7 +1667,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
|
||||
if (uuid_subtype == JAVA_LEGACY) {
|
||||
/* Convert from legacy java byte order */
|
||||
char big_endian[16];
|
||||
_fix_java(buffer + *position + 5, big_endian);
|
||||
_fix_java(buffer + *position, big_endian);
|
||||
/* Free the previously created PyString object */
|
||||
Py_DECREF(data);
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
@ -1690,7 +1694,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
|
||||
return NULL;
|
||||
}
|
||||
|
||||
*position += length + 5;
|
||||
*position += length;
|
||||
break;
|
||||
|
||||
uuiderror:
|
||||
@ -1718,7 +1722,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
|
||||
if (!value) {
|
||||
return NULL;
|
||||
}
|
||||
*position += length + 5;
|
||||
*position += length;
|
||||
break;
|
||||
}
|
||||
case 6:
|
||||
@ -1866,7 +1870,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
|
||||
}
|
||||
memcpy(&coll_length, buffer + *position, 4);
|
||||
/* Encoded string length + string + 12 byte ObjectId */
|
||||
if (max < 4 + coll_length + 12) {
|
||||
if (!coll_length || max < coll_length || max < 4 + coll_length + 12) {
|
||||
goto invalid;
|
||||
}
|
||||
*position += 4;
|
||||
@ -1913,7 +1917,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
|
||||
}
|
||||
memcpy(&value_length, buffer + *position, 4);
|
||||
/* Encoded string length + string */
|
||||
if (max < 4 + value_length) {
|
||||
if (!value_length || max < value_length || max < 4 + value_length) {
|
||||
goto invalid;
|
||||
}
|
||||
*position += 4;
|
||||
@ -1955,7 +1959,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
|
||||
|
||||
memcpy(&code_size, buffer + *position, 4);
|
||||
/* code_w_scope length + code length + code + scope length */
|
||||
if (max < 4 + 4 + code_size + 4) {
|
||||
if (!code_size || max < code_size || max < 4 + 4 + code_size + 4) {
|
||||
goto invalid;
|
||||
}
|
||||
*position += 4;
|
||||
|
||||
@ -40,7 +40,8 @@ from bson.dbref import DBRef
|
||||
from bson.py3compat import b
|
||||
from bson.son import SON
|
||||
from bson.timestamp import Timestamp
|
||||
from bson.errors import (InvalidDocument,
|
||||
from bson.errors import (InvalidBSON,
|
||||
InvalidDocument,
|
||||
InvalidStringData)
|
||||
from bson.max_key import MaxKey
|
||||
from bson.min_key import MinKey
|
||||
@ -84,6 +85,61 @@ class TestBSON(unittest.TestCase):
|
||||
self.assertFalse(is_valid(b("\x10\x00\x00\x00\x02a\x00"
|
||||
"\x04\x00\x00\x00abc\xff\x00")))
|
||||
|
||||
def test_bad_string_lengths(self):
    """Check that documents with bogus embedded lengths fail to decode.

    Every payload below carries a 4-byte little-endian length field that
    is either zero or 0xffffffff; decoding each one must raise
    InvalidBSON.  The payloads are grouped by the element type byte that
    follows the document length prefix.
    """
    def _decode_document(raw):
        # Construct and decode inside the callable so that any error is
        # raised under assertRaises.
        bson.BSON(raw).decode()

    malformed_documents = (
        # Element type 0x02: length 0, then length 0xffffffff.
        b("\x0c\x00\x00\x00\x02\x00"
          "\x00\x00\x00\x00\x00\x00"),
        b("\x12\x00\x00\x00\x02\x00"
          "\xff\xff\xff\xfffoobar\x00\x00"),
        # Element type 0x0e: length 0, then length 0xffffffff.
        b("\x0c\x00\x00\x00\x0e\x00"
          "\x00\x00\x00\x00\x00\x00"),
        b("\x12\x00\x00\x00\x0e\x00"
          "\xff\xff\xff\xfffoobar\x00\x00"),
        # Element type 0x0c (length-prefixed string followed by 12
        # trailing bytes): length 0, then length 0xffffffff.
        b("\x18\x00\x00\x00\x0c\x00"
          "\x00\x00\x00\x00\x00RY\xb5j"
          "\xfa[\xd8A\xd6X]\x99\x00"),
        b("\x1e\x00\x00\x00\x0c\x00"
          "\xff\xff\xff\xfffoobar\x00"
          "RY\xb5j\xfa[\xd8A\xd6X]\x99\x00"),
        # Element type 0x0d: length 0, then length 0xffffffff.
        b("\x0c\x00\x00\x00\r\x00"
          "\x00\x00\x00\x00\x00\x00"),
        b("\x0c\x00\x00\x00\r\x00"
          "\xff\xff\xff\xff\x00\x00"),
        # Element type 0x0f: bad length on the leading string
        # (0, then 0xffffffff) ...
        b("\x1c\x00\x00\x00\x0f\x00"
          "\x15\x00\x00\x00\x00\x00"
          "\x00\x00\x00\x0c\x00\x00"
          "\x00\x02\x00\x01\x00\x00"
          "\x00\x00\x00\x00"),
        b("\x1c\x00\x00\x00\x0f\x00"
          "\x15\x00\x00\x00\xff\xff"
          "\xff\xff\x00\x0c\x00\x00"
          "\x00\x02\x00\x01\x00\x00"
          "\x00\x00\x00\x00"),
        # ... and bad length on the 0x02 element nested inside the
        # trailing embedded document (0, then 0xffffffff).
        b("\x1c\x00\x00\x00\x0f\x00"
          "\x15\x00\x00\x00\x01\x00"
          "\x00\x00\x00\x0c\x00\x00"
          "\x00\x02\x00\x00\x00\x00"
          "\x00\x00\x00\x00"),
        b("\x1c\x00\x00\x00\x0f\x00"
          "\x15\x00\x00\x00\x01\x00"
          "\x00\x00\x00\x0c\x00\x00"
          "\x00\x02\x00\xff\xff\xff"
          "\xff\x00\x00\x00"),
    )

    for payload in malformed_documents:
        self.assertRaises(InvalidBSON, _decode_document, payload)
|
||||
|
||||
def test_random_data_is_not_bson(self):
|
||||
qcheck.check_unittest(self, qcheck.isnt(is_valid),
|
||||
qcheck.gen_string(qcheck.gen_range(0, 40)))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user