From 989a0f4f4d5a609b10cf6a9df9ac7bc834070988 Mon Sep 17 00:00:00 2001 From: Bernie Hackett Date: Tue, 15 Oct 2013 17:25:16 -0700 Subject: [PATCH] Avoid unsigned overflow wrapping PYTHON-571 --- bson/__init__.py | 2 +- bson/_cbsonmodule.c | 42 +++++++++++++++++--------------- test/test_bson.py | 58 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 81 insertions(+), 21 deletions(-) diff --git a/bson/__init__.py b/bson/__init__.py index 650220f9c..3ad2cd3f9 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -140,7 +140,7 @@ def _get_number(data, position, as_class, tz_aware, uuid_subtype): def _get_string(data, position, as_class, tz_aware, uuid_subtype): length = struct.unpack("= 3 /* Python3 special case. Decode BSON binary subtype 0 to bytes. */ if (subtype == 0) { - value = PyBytes_FromStringAndSize(buffer + *position + 5, length); - *position += length + 5; + value = PyBytes_FromStringAndSize(buffer + *position, length); + *position += length; break; } if (subtype == 2) { - data = PyBytes_FromStringAndSize(buffer + *position + 9, length - 4); + data = PyBytes_FromStringAndSize(buffer + *position + 4, length - 4); } else { - data = PyBytes_FromStringAndSize(buffer + *position + 5, length); + data = PyBytes_FromStringAndSize(buffer + *position, length); } #else if (subtype == 2) { - data = PyString_FromStringAndSize(buffer + *position + 9, length - 4); + data = PyString_FromStringAndSize(buffer + *position + 4, length - 4); } else { - data = PyString_FromStringAndSize(buffer + *position + 5, length); + data = PyString_FromStringAndSize(buffer + *position, length); } #endif if (!data) { return NULL; } - if ((subtype == 3 || subtype == 4) && state->UUID) { // Encode as UUID, not Binary + /* Encode as UUID, not Binary */ + if ((subtype == 3 || subtype == 4) && state->UUID) { PyObject* kwargs; PyObject* args = PyTuple_New(0); - if (!args) { + /* UUID should always be 16 bytes */ + if (!args || length != 16) { Py_DECREF(data); return NULL; } @@ -1652,8 +1658,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio return NULL; } - assert(length == 16); // UUID should always be 16 bytes - if (uuid_subtype == CSHARP_LEGACY) { /* Legacy C# byte order */ if ((PyDict_SetItemString(kwargs, "bytes_le", data)) == -1) @@ -1663,7 +1667,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio if (uuid_subtype == JAVA_LEGACY) { /* Convert from legacy java byte order */ char big_endian[16]; - _fix_java(buffer + *position + 5, big_endian); + _fix_java(buffer + *position, big_endian); /* Free the previously created PyString object */ Py_DECREF(data); #if PY_MAJOR_VERSION >= 3 @@ -1690,7 +1694,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio return NULL; } - *position += length + 5; + *position += length; break; uuiderror: @@ -1718,7 +1722,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio if (!value) { return NULL; } - *position += length + 5; + *position += length; break; } case 6: @@ -1866,7 +1870,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio } memcpy(&coll_length, buffer + *position, 4); /* Encoded string length + string + 12 byte ObjectId */ - if (max < 4 + coll_length + 12) { + if (!coll_length || max < coll_length || max < 4 + coll_length + 12) { goto invalid; } *position += 4; @@ -1913,7 +1917,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio } memcpy(&value_length, buffer + *position, 4); /* Encoded string length + string */ - if (max < 4 + value_length) { + if (!value_length || max < value_length || max < 4 + value_length) { goto invalid; } *position += 4; @@ -1955,7 +1959,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio memcpy(&code_size, buffer + *position, 4); /* code_w_scope length + code length + code + scope length */ - if (max < 4 + 4 + code_size + 4) { + if (!code_size || max < code_size || max < 4 + 4 + code_size + 4) { goto invalid; } *position += 4; diff --git a/test/test_bson.py b/test/test_bson.py index 99ef945d7..392f6f541 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -40,7 +40,8 @@ from bson.dbref import DBRef from bson.py3compat import b from bson.son import SON from bson.timestamp import Timestamp -from bson.errors import (InvalidDocument, +from bson.errors import (InvalidBSON, + InvalidDocument, InvalidStringData) from bson.max_key import MaxKey from bson.min_key import MinKey @@ -84,6 +85,61 @@ class TestBSON(unittest.TestCase): self.assertFalse(is_valid(b("\x10\x00\x00\x00\x02a\x00" "\x04\x00\x00\x00abc\xff\x00"))) + def test_bad_string_lengths(self): + def decode(bs): + bson.BSON(bs).decode() + + self.assertRaises(InvalidBSON, decode, + b("\x0c\x00\x00\x00\x02\x00" + "\x00\x00\x00\x00\x00\x00")) + self.assertRaises(InvalidBSON, decode, + b("\x12\x00\x00\x00\x02\x00" + "\xff\xff\xff\xfffoobar\x00\x00")) + self.assertRaises(InvalidBSON, decode, + b("\x0c\x00\x00\x00\x0e\x00" + "\x00\x00\x00\x00\x00\x00")) + self.assertRaises(InvalidBSON, decode, + b("\x12\x00\x00\x00\x0e\x00" + "\xff\xff\xff\xfffoobar\x00\x00")) + self.assertRaises(InvalidBSON, decode, + b("\x18\x00\x00\x00\x0c\x00" + "\x00\x00\x00\x00\x00RY\xb5j" + "\xfa[\xd8A\xd6X]\x99\x00")) + self.assertRaises(InvalidBSON, decode, + b("\x1e\x00\x00\x00\x0c\x00" + "\xff\xff\xff\xfffoobar\x00" + "RY\xb5j\xfa[\xd8A\xd6X]\x99\x00")) + self.assertRaises(InvalidBSON, decode, + b("\x0c\x00\x00\x00\r\x00" + "\x00\x00\x00\x00\x00\x00")) + self.assertRaises(InvalidBSON, decode, + b("\x0c\x00\x00\x00\r\x00" + "\xff\xff\xff\xff\x00\x00")) + self.assertRaises(InvalidBSON, decode, + b("\x1c\x00\x00\x00\x0f\x00" + "\x15\x00\x00\x00\x00\x00" + "\x00\x00\x00\x0c\x00\x00" + "\x00\x02\x00\x01\x00\x00" + "\x00\x00\x00\x00")) + self.assertRaises(InvalidBSON, decode, + b("\x1c\x00\x00\x00\x0f\x00" + "\x15\x00\x00\x00\xff\xff" + "\xff\xff\x00\x0c\x00\x00" + "\x00\x02\x00\x01\x00\x00" + "\x00\x00\x00\x00")) + self.assertRaises(InvalidBSON, decode, + b("\x1c\x00\x00\x00\x0f\x00" + "\x15\x00\x00\x00\x01\x00" + "\x00\x00\x00\x0c\x00\x00" + "\x00\x02\x00\x00\x00\x00" + "\x00\x00\x00\x00")) + self.assertRaises(InvalidBSON, decode, + b("\x1c\x00\x00\x00\x0f\x00" + "\x15\x00\x00\x00\x01\x00" + "\x00\x00\x00\x0c\x00\x00" + "\x00\x02\x00\xff\xff\xff" + "\xff\x00\x00\x00")) + def test_random_data_is_not_bson(self): qcheck.check_unittest(self, qcheck.isnt(is_valid), qcheck.gen_string(qcheck.gen_range(0, 40)))