From 72a90b75c63c68f9ebe435a142ea9fb3f37e40db Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Mon, 21 Oct 2013 16:19:30 -0400 Subject: [PATCH] Wrap all BSON-decoding errors in InvalidBSON exception. PYTHON-494 --- bson/__init__.py | 29 ++++--- bson/_cbsonmodule.c | 119 +++++++++++++++++---------- test/test_bson.py | 191 +++++++++++++++++++++++++++----------------- 3 files changed, 214 insertions(+), 125 deletions(-) diff --git a/bson/__init__.py b/bson/__init__.py index a6f993db5..e8d66d4a2 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -518,17 +518,24 @@ def decode_all(data, as_class=dict, docs = [] position = 0 end = len(data) - 1 - while position < end: - obj_size = struct.unpack("UUID) { @@ -1675,15 +1668,19 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio /* UUID should always be 16 bytes */ if (!args || length != 16) { Py_DECREF(data); - return NULL; + goto invalid; } kwargs = PyDict_New(); if (!kwargs) { Py_DECREF(data); Py_DECREF(args); - return NULL; + goto invalid; } + /* + * From this point, we hold refs to args, kwargs, and data. + * If anything fails, goto uuiderror to clean them up. + */ if (uuid_subtype == CSHARP_LEGACY) { /* Legacy C# byte order */ if ((PyDict_SetItemString(kwargs, "bytes_le", data)) == -1) @@ -1717,7 +1714,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio Py_DECREF(kwargs); Py_DECREF(data); if (!value) { - return NULL; + goto invalid; } *position += length; @@ -1727,7 +1724,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio Py_DECREF(args); Py_DECREF(kwargs); Py_XDECREF(data); - return NULL; + goto invalid; } #if PY_MAJOR_VERSION >= 3 @@ -1737,7 +1734,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio #endif if (!st) { Py_DECREF(data); - return NULL; + goto invalid; } if ((type_to_create = _get_object(state->Binary, "bson.binary", "Binary"))) { value = PyObject_CallFunctionObjArgs(type_to_create, data, st, NULL); @@ -1746,7 +1743,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio Py_DECREF(st); Py_DECREF(data); if (!value) { - return NULL; + goto invalid; } *position += length; break; @@ -1801,23 +1798,23 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio } if (!naive) { - return NULL; + goto invalid; } replace = PyObject_GetAttrString(naive, "replace"); Py_DECREF(naive); if (!replace) { - return NULL; + goto invalid; } args = PyTuple_New(0); if (!args) { Py_DECREF(replace); - return NULL; + goto invalid; } kwargs = PyDict_New(); if (!kwargs) { Py_DECREF(replace); Py_DECREF(args); - return NULL; + goto invalid; } utc_type = _get_object(state->UTC, "bson.tz_util", "UTC"); if (!utc_type || PyDict_SetItemString(kwargs, "tzinfo", utc_type) == -1) { @@ -1825,7 +1822,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio Py_DECREF(args); Py_DECREF(kwargs); Py_XDECREF(utc_type); - return NULL; + goto invalid; } Py_XDECREF(utc_type); value = PyObject_Call(replace, args, kwargs); @@ -1846,7 +1843,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio } pattern = PyUnicode_DecodeUTF8(buffer + *position, pattern_length, "strict"); if (!pattern) { - return NULL; + goto invalid; } *position += (unsigned)pattern_length + 1; flags_length = strlen(buffer + *position); @@ -1919,7 +1916,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio collection = PyUnicode_DecodeUTF8(buffer + *position, coll_length - 1, "strict"); if (!collection) { - return NULL; + goto invalid; } *position += coll_length; @@ -1933,7 +1930,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio } if (!id) { Py_DECREF(collection); - return NULL; + goto invalid; } *position += 12; if ((dbref_type = _get_object(state->DBRef, "bson.dbref", "DBRef"))) { @@ -1964,7 +1961,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio } code = PyUnicode_DecodeUTF8(buffer + *position, value_length - 1, "strict"); if (!code) { - return NULL; + goto invalid; } *position += value_length; if ((code_type = _get_object(state->Code, "bson.code", "Code"))) { @@ -2006,7 +2003,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio } code = PyUnicode_DecodeUTF8(buffer + *position, code_size - 1, "strict"); if (!code) { - return NULL; + goto invalid; } *position += code_size; @@ -2030,7 +2027,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio tz_aware, uuid_subtype, compile_re); if (!scope) { Py_DECREF(code); - return NULL; + goto invalid; } *position += scope_size; @@ -2055,7 +2052,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio value = PyInt_FromLong(i); #endif if (!value) { - return NULL; + goto invalid; } *position += 4; break; @@ -2085,7 +2082,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio memcpy(&ll, buffer + *position, 8); value = PyLong_FromLongLong(ll); if (!value) { - return NULL; + goto invalid; } *position += 8; break; @@ -2094,7 +2091,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio { PyObject* minkey_type = _get_object(state->MinKey, "bson.min_key", "MinKey"); if (!minkey_type) - return NULL; + goto invalid; value = PyObject_CallFunctionObjArgs(minkey_type, NULL); Py_DECREF(minkey_type); break; @@ -2103,7 +2100,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio { PyObject* maxkey_type = _get_object(state->MaxKey, "bson.max_key", "MaxKey"); if (!maxkey_type) - return NULL; + goto invalid; value = PyObject_CallFunctionObjArgs(maxkey_type, NULL); Py_DECREF(maxkey_type); break; @@ -2116,18 +2113,58 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio "no c decoder for this type yet"); Py_DECREF(InvalidDocument); } - return NULL; + goto invalid; } } - return value; + + if (value) { + return value; + } invalid: - error = _error("InvalidBSON"); - if (error) { - PyErr_SetString(error, - "invalid length or type code"); - Py_DECREF(error); + /* + * Wrap any non-InvalidBSON errors in InvalidBSON. + */ + if (PyErr_Occurred()) { + PyObject *etype, *evalue, *etrace; + PyObject *InvalidBSON; + + /* + * Calling _error clears the error state, so fetch it first. + */ + PyErr_Fetch(&etype, &evalue, &etrace); + InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + if (!PyErr_GivenExceptionMatches(etype, InvalidBSON)) { + /* + * Raise InvalidBSON(str(e)). + */ + Py_DECREF(etype); + etype = InvalidBSON; + + if (evalue) { + PyObject *msg = PyObject_Str(evalue); + Py_DECREF(evalue); + evalue = msg; + } + PyErr_NormalizeException(&etype, &evalue, &etrace); + } else { + /* + * The current exception matches InvalidBSON, so we don't need + * this reference after all. + */ + Py_DECREF(InvalidBSON); + } + } + /* Steals references to args. */ + PyErr_Restore(etype, evalue, etrace); + } else { + PyObject *InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "invalid length or type code"); + Py_DECREF(InvalidBSON); + } } return NULL; } diff --git a/test/test_bson.py b/test/test_bson.py index 0b8ac1b40..bfea866c1 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -16,10 +16,11 @@ """Test the bson module.""" -import unittest import datetime import re import sys +import traceback +import unittest try: import uuid should_test_uuid = True @@ -55,91 +56,109 @@ PY3 = sys.version_info[0] == 3 class TestBSON(unittest.TestCase): + def assertInvalid(self, data, msg=None): + try: + bson.BSON(data).decode() + except InvalidBSON, e: + # Check a message is set. + self.assertTrue(len(e.args) > 0) + if msg: + self.assertEqual(msg, e.args[0]) + def test_basic_validation(self): self.assertRaises(TypeError, is_valid, 100) self.assertRaises(TypeError, is_valid, u"test") self.assertRaises(TypeError, is_valid, 10.4) - self.assertFalse(is_valid(b("test"))) + self.assertInvalid(b("test")) # the simplest valid BSON document self.assertTrue(is_valid(b("\x05\x00\x00\x00\x00"))) self.assertTrue(is_valid(BSON(b("\x05\x00\x00\x00\x00")))) # failure cases - self.assertFalse(is_valid(b("\x04\x00\x00\x00\x00"))) - self.assertFalse(is_valid(b("\x05\x00\x00\x00\x01"))) - self.assertFalse(is_valid(b("\x05\x00\x00\x00"))) - self.assertFalse(is_valid(b("\x05\x00\x00\x00\x00\x00"))) - self.assertFalse(is_valid(b("\x07\x00\x00\x00\x02a\x00\x78\x56\x34\x12"))) - self.assertFalse(is_valid(b("\x09\x00\x00\x00\x10a\x00\x05\x00"))) - self.assertFalse(is_valid(b("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"))) - self.assertFalse(is_valid(b("\x13\x00\x00\x00\x02foo\x00" - "\x04\x00\x00\x00bar\x00\x00"))) - self.assertFalse(is_valid(b("\x18\x00\x00\x00\x03foo\x00\x0f\x00\x00" - "\x00\x10bar\x00\xff\xff\xff\x7f\x00\x00"))) - self.assertFalse(is_valid(b("\x15\x00\x00\x00\x03foo\x00\x0c" - "\x00\x00\x00\x08bar\x00\x01\x00\x00"))) - self.assertFalse(is_valid(b("\x1c\x00\x00\x00\x03foo\x00" - "\x12\x00\x00\x00\x02bar\x00" - "\x05\x00\x00\x00baz\x00\x00\x00"))) - self.assertFalse(is_valid(b("\x10\x00\x00\x00\x02a\x00" - "\x04\x00\x00\x00abc\xff\x00"))) + self.assertInvalid(b("\x04\x00\x00\x00\x00"), + 'invalid message size') + self.assertInvalid(b("\x05\x00\x00\x00\x01"), + 'bad eoo') + self.assertInvalid(b("\x05\x00\x00\x00"), + 'not enough data for a BSON document') + self.assertInvalid(b("\x05\x00\x00\x00\x00\x00"), + 'bad eoo') + self.assertInvalid(b("\x07\x00\x00\x00\x02a\x00\x78\x56\x34\x12"), + 'bad eoo') + self.assertInvalid(b("\x09\x00\x00\x00\x10a\x00\x05\x00"), + 'invalid length or type code') + self.assertInvalid(b("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), + 'invalid message size') + self.assertInvalid(b("\x13\x00\x00\x00\x02foo\x00" + "\x04\x00\x00\x00bar\x00\x00"), + 'objsize too large') + self.assertInvalid(b("\x18\x00\x00\x00\x03foo\x00\x0f\x00\x00" + "\x00\x10bar\x00\xff\xff\xff\x7f\x00\x00"), + 'invalid length or type code') + self.assertInvalid(b("\x15\x00\x00\x00\x03foo\x00\x0c" + "\x00\x00\x00\x08bar\x00\x01\x00\x00"), + 'invalid length or type code') + self.assertInvalid(b("\x1c\x00\x00\x00\x03foo\x00" + "\x12\x00\x00\x00\x02bar\x00" + "\x05\x00\x00\x00baz\x00\x00\x00"), + 'invalid length or type code') + self.assertInvalid(b("\x10\x00\x00\x00\x02a\x00" + "\x04\x00\x00\x00abc\xff\x00"), + 'invalid length or type code') def test_bad_string_lengths(self): - def decode(bs): - bson.BSON(bs).decode() - - self.assertRaises(InvalidBSON, decode, - b("\x0c\x00\x00\x00\x02\x00" - "\x00\x00\x00\x00\x00\x00")) - self.assertRaises(InvalidBSON, decode, - b("\x12\x00\x00\x00\x02\x00" - "\xff\xff\xff\xfffoobar\x00\x00")) - self.assertRaises(InvalidBSON, decode, - b("\x0c\x00\x00\x00\x0e\x00" - "\x00\x00\x00\x00\x00\x00")) - self.assertRaises(InvalidBSON, decode, - b("\x12\x00\x00\x00\x0e\x00" - "\xff\xff\xff\xfffoobar\x00\x00")) - self.assertRaises(InvalidBSON, decode, - b("\x18\x00\x00\x00\x0c\x00" - "\x00\x00\x00\x00\x00RY\xb5j" - "\xfa[\xd8A\xd6X]\x99\x00")) - self.assertRaises(InvalidBSON, decode, - b("\x1e\x00\x00\x00\x0c\x00" - "\xff\xff\xff\xfffoobar\x00" - "RY\xb5j\xfa[\xd8A\xd6X]\x99\x00")) - self.assertRaises(InvalidBSON, decode, - b("\x0c\x00\x00\x00\r\x00" - "\x00\x00\x00\x00\x00\x00")) - self.assertRaises(InvalidBSON, decode, - b("\x0c\x00\x00\x00\r\x00" - "\xff\xff\xff\xff\x00\x00")) - self.assertRaises(InvalidBSON, decode, - b("\x1c\x00\x00\x00\x0f\x00" - "\x15\x00\x00\x00\x00\x00" - "\x00\x00\x00\x0c\x00\x00" - "\x00\x02\x00\x01\x00\x00" - "\x00\x00\x00\x00")) - self.assertRaises(InvalidBSON, decode, - b("\x1c\x00\x00\x00\x0f\x00" - "\x15\x00\x00\x00\xff\xff" - "\xff\xff\x00\x0c\x00\x00" - "\x00\x02\x00\x01\x00\x00" - "\x00\x00\x00\x00")) - self.assertRaises(InvalidBSON, decode, - b("\x1c\x00\x00\x00\x0f\x00" - "\x15\x00\x00\x00\x01\x00" - "\x00\x00\x00\x0c\x00\x00" - "\x00\x02\x00\x00\x00\x00" - "\x00\x00\x00\x00")) - self.assertRaises(InvalidBSON, decode, - b("\x1c\x00\x00\x00\x0f\x00" - "\x15\x00\x00\x00\x01\x00" - "\x00\x00\x00\x0c\x00\x00" - "\x00\x02\x00\xff\xff\xff" - "\xff\x00\x00\x00")) + self.assertInvalid( + b("\x0c\x00\x00\x00\x02\x00" + "\x00\x00\x00\x00\x00\x00")) + self.assertInvalid( + b("\x12\x00\x00\x00\x02\x00" + "\xff\xff\xff\xfffoobar\x00\x00")) + self.assertInvalid( + b("\x0c\x00\x00\x00\x0e\x00" + "\x00\x00\x00\x00\x00\x00")) + self.assertInvalid( + b("\x12\x00\x00\x00\x0e\x00" + "\xff\xff\xff\xfffoobar\x00\x00")) + self.assertInvalid( + b("\x18\x00\x00\x00\x0c\x00" + "\x00\x00\x00\x00\x00RY\xb5j" + "\xfa[\xd8A\xd6X]\x99\x00")) + self.assertInvalid( + b("\x1e\x00\x00\x00\x0c\x00" + "\xff\xff\xff\xfffoobar\x00" + "RY\xb5j\xfa[\xd8A\xd6X]\x99\x00")) + self.assertInvalid( + b("\x0c\x00\x00\x00\r\x00" + "\x00\x00\x00\x00\x00\x00")) + self.assertInvalid( + b("\x0c\x00\x00\x00\r\x00" + "\xff\xff\xff\xff\x00\x00")) + self.assertInvalid( + b("\x1c\x00\x00\x00\x0f\x00" + "\x15\x00\x00\x00\x00\x00" + "\x00\x00\x00\x0c\x00\x00" + "\x00\x02\x00\x01\x00\x00" + "\x00\x00\x00\x00")) + self.assertInvalid( + b("\x1c\x00\x00\x00\x0f\x00" + "\x15\x00\x00\x00\xff\xff" + "\xff\xff\x00\x0c\x00\x00" + "\x00\x02\x00\x01\x00\x00" + "\x00\x00\x00\x00")) + self.assertInvalid( + b("\x1c\x00\x00\x00\x0f\x00" + "\x15\x00\x00\x00\x01\x00" + "\x00\x00\x00\x0c\x00\x00" + "\x00\x02\x00\x00\x00\x00" + "\x00\x00\x00\x00")) + self.assertInvalid( + b("\x1c\x00\x00\x00\x0f\x00" + "\x15\x00\x00\x00\x01\x00" + "\x00\x00\x00\x0c\x00\x00" + "\x00\x02\x00\xff\xff\xff" + "\xff\x00\x00\x00")) def test_random_data_is_not_bson(self): qcheck.check_unittest(self, qcheck.isnt(is_valid), @@ -571,6 +590,32 @@ class TestBSON(unittest.TestCase): self.assertEqual( doc2_with_bson_re, BSON(doc2_bson).decode(compile_re=False)) + def test_exception_wrapping(self): + # No matter what exception is raised while trying to decode BSON, + # the final exception always matches InvalidBSON and the original + # traceback is preserved. + + # Invalid Python regex, though valid PCRE. + # Causes an error in re.compile(). + bad_doc = BSON.encode({'r': Regex(r'[\w-\.]')}) + + try: + decode_all(bad_doc) + except InvalidBSON: + exc_type, exc_value, exc_tb = sys.exc_info() + # Original re error was captured and wrapped in InvalidBSON. + self.assertEqual(exc_value.args[0], 'bad character range') + + # Traceback includes bson module's call into re module. + for filename, lineno, fname, text in traceback.extract_tb(exc_tb): + if filename.endswith('re.py') and fname == 'compile': + # Traceback was correctly preserved. + break + else: + self.fail('Traceback not captured') + else: + self.fail('InvalidBSON not raised') + if __name__ == "__main__": unittest.main()