Capture all BSON decode errors and wrap with InvalidBSON. PYTHON-494
This commit is contained in:
parent
fecbbee710
commit
ba66a2dde7
@ -484,7 +484,6 @@ if _use_c:
|
||||
_dict_to_bson = _cbson._dict_to_bson
|
||||
|
||||
|
||||
|
||||
def decode_all(data, as_class=dict,
|
||||
tz_aware=True, uuid_subtype=OLD_UUID_SUBTYPE):
|
||||
"""Decode BSON data to multiple documents.
|
||||
@ -504,17 +503,24 @@ def decode_all(data, as_class=dict,
|
||||
docs = []
|
||||
position = 0
|
||||
end = len(data) - 1
|
||||
while position < end:
|
||||
obj_size = struct.unpack("<i", data[position:position + 4])[0]
|
||||
if len(data) - position < obj_size:
|
||||
raise InvalidBSON("objsize too large")
|
||||
if data[position + obj_size - 1:position + obj_size] != ZERO:
|
||||
raise InvalidBSON("bad eoo")
|
||||
elements = data[position + 4:position + obj_size - 1]
|
||||
position += obj_size
|
||||
docs.append(_elements_to_dict(elements, as_class,
|
||||
tz_aware, uuid_subtype))
|
||||
return docs
|
||||
try:
|
||||
while position < end:
|
||||
obj_size = struct.unpack("<i", data[position:position + 4])[0]
|
||||
if len(data) - position < obj_size:
|
||||
raise InvalidBSON("objsize too large")
|
||||
if data[position + obj_size - 1:position + obj_size] != ZERO:
|
||||
raise InvalidBSON("bad eoo")
|
||||
elements = data[position + 4:position + obj_size - 1]
|
||||
position += obj_size
|
||||
docs.append(_elements_to_dict(elements, as_class,
|
||||
tz_aware, uuid_subtype))
|
||||
return docs
|
||||
except InvalidBSON:
|
||||
raise
|
||||
except Exception:
|
||||
# Change exception type to InvalidBSON but preserve traceback.
|
||||
exc_type, exc_value, exc_tb = sys.exc_info()
|
||||
raise InvalidBSON, InvalidBSON(str(exc_value)), exc_tb
|
||||
if _use_c:
|
||||
decode_all = _cbson.decode_all
|
||||
|
||||
|
||||
@ -1347,8 +1347,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
unsigned char tz_aware, unsigned char uuid_subtype) {
|
||||
struct module_state *state = GETSTATE(self);
|
||||
|
||||
PyObject* value;
|
||||
PyObject* error;
|
||||
PyObject* value = NULL;
|
||||
switch (type) {
|
||||
case 1:
|
||||
{
|
||||
@ -1358,9 +1357,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
}
|
||||
memcpy(&d, buffer + *position, 8);
|
||||
value = PyFloat_FromDouble(d);
|
||||
if (!value) {
|
||||
return NULL;
|
||||
}
|
||||
*position += 8;
|
||||
break;
|
||||
}
|
||||
@ -1373,9 +1369,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
}
|
||||
*position += 4;
|
||||
value = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict");
|
||||
if (!value) {
|
||||
return NULL;
|
||||
}
|
||||
*position += value_length + 1;
|
||||
break;
|
||||
}
|
||||
@ -1389,10 +1382,10 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
}
|
||||
value = elements_to_dict(self, buffer + *position + 4,
|
||||
size - 5, as_class, tz_aware, uuid_subtype);
|
||||
if (!value) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (!value) {
|
||||
goto invalid;
|
||||
}
|
||||
/* Decoding for DBRefs */
|
||||
collection = PyDict_GetItemString(value, "$ref");
|
||||
if (collection) { /* DBRef */
|
||||
@ -1428,9 +1421,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
Py_DECREF(id);
|
||||
Py_DECREF(collection);
|
||||
Py_DECREF(database);
|
||||
if (!value) {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
*position += size;
|
||||
@ -1450,7 +1440,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
|
||||
value = PyList_New(0);
|
||||
if (!value) {
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
while (*position < end) {
|
||||
PyObject* to_append;
|
||||
@ -1467,7 +1457,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
max - (int)key_size, as_class, tz_aware, uuid_subtype);
|
||||
if (!to_append) {
|
||||
Py_DECREF(value);
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
PyList_Append(value, to_append);
|
||||
Py_DECREF(to_append);
|
||||
@ -1506,20 +1496,20 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
}
|
||||
#endif
|
||||
if (!data) {
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
if ((subtype == 3 || subtype == 4) && state->UUID) { // Encode as UUID, not Binary
|
||||
PyObject* kwargs;
|
||||
PyObject* args = PyTuple_New(0);
|
||||
if (!args) {
|
||||
Py_DECREF(data);
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
kwargs = PyDict_New();
|
||||
if (!kwargs) {
|
||||
Py_DECREF(data);
|
||||
Py_DECREF(args);
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
|
||||
assert(length == 16); // UUID should always be 16 bytes
|
||||
@ -1553,10 +1543,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
Py_DECREF(args);
|
||||
Py_DECREF(kwargs);
|
||||
Py_DECREF(data);
|
||||
if (!value) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
*position += length + 5;
|
||||
break;
|
||||
|
||||
@ -1579,9 +1565,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
value = PyObject_CallFunctionObjArgs(state->Binary, data, st, NULL);
|
||||
Py_DECREF(st);
|
||||
Py_DECREF(data);
|
||||
if (!value) {
|
||||
return NULL;
|
||||
}
|
||||
*position += length + 5;
|
||||
break;
|
||||
}
|
||||
@ -1602,9 +1585,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
#else
|
||||
value = PyObject_CallFunction(state->ObjectId, "s#", buffer + *position, 12);
|
||||
#endif
|
||||
if (!value) {
|
||||
return NULL;
|
||||
}
|
||||
*position += 12;
|
||||
break;
|
||||
}
|
||||
@ -1631,29 +1611,29 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
}
|
||||
|
||||
if (!naive) {
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
replace = PyObject_GetAttrString(naive, "replace");
|
||||
Py_DECREF(naive);
|
||||
if (!replace) {
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
args = PyTuple_New(0);
|
||||
if (!args) {
|
||||
Py_DECREF(replace);
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
kwargs = PyDict_New();
|
||||
if (!kwargs) {
|
||||
Py_DECREF(replace);
|
||||
Py_DECREF(args);
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
if (PyDict_SetItemString(kwargs, "tzinfo", state->UTC) == -1) {
|
||||
Py_DECREF(replace);
|
||||
Py_DECREF(args);
|
||||
Py_DECREF(kwargs);
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
value = PyObject_Call(replace, args, kwargs);
|
||||
Py_DECREF(replace);
|
||||
@ -1672,7 +1652,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
}
|
||||
pattern = PyUnicode_DecodeUTF8(buffer + *position, pattern_length, "strict");
|
||||
if (!pattern) {
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
*position += (int)pattern_length + 1;
|
||||
if ((flags_length = strlen(buffer + *position)) > BSON_MAX_SIZE) {
|
||||
@ -1718,14 +1698,14 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
collection = PyUnicode_DecodeUTF8(buffer + *position,
|
||||
coll_length, "strict");
|
||||
if (!collection) {
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
*position += (int)coll_length + 1;
|
||||
|
||||
id = PyObject_CallFunction(state->ObjectId, "s#", buffer + *position, 12);
|
||||
if (!id) {
|
||||
Py_DECREF(collection);
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
*position += 12;
|
||||
value = PyObject_CallFunctionObjArgs(state->DBRef, collection, id, NULL);
|
||||
@ -1743,7 +1723,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
*position += 4;
|
||||
code = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict");
|
||||
if (!code) {
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
*position += value_length + 1;
|
||||
value = PyObject_CallFunctionObjArgs(state->Code, code, NULL, NULL);
|
||||
@ -1764,7 +1744,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
}
|
||||
code = PyUnicode_DecodeUTF8(buffer + *position, code_length, "strict");
|
||||
if (!code) {
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
*position += (int)code_length + 1;
|
||||
|
||||
@ -1773,7 +1753,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
(PyObject*)&PyDict_Type, tz_aware, uuid_subtype);
|
||||
if (!scope) {
|
||||
Py_DECREF(code);
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
*position += scope_size;
|
||||
|
||||
@ -1795,7 +1775,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
value = PyInt_FromLong(i);
|
||||
#endif
|
||||
if (!value) {
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
*position += 4;
|
||||
break;
|
||||
@ -1810,7 +1790,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
memcpy(&time, buffer + *position + 4, 4);
|
||||
value = PyObject_CallFunction(state->Timestamp, "II", time, inc);
|
||||
if (!value) {
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
*position += 8;
|
||||
break;
|
||||
@ -1824,7 +1804,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
memcpy(&ll, buffer + *position, 8);
|
||||
value = PyLong_FromLongLong(ll);
|
||||
if (!value) {
|
||||
return NULL;
|
||||
goto invalid;
|
||||
}
|
||||
*position += 8;
|
||||
break;
|
||||
@ -1850,14 +1830,45 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
return value;
|
||||
|
||||
if (value) {
|
||||
return value;
|
||||
}
|
||||
|
||||
invalid:
|
||||
|
||||
error = _error("InvalidBSON");
|
||||
if (error) {
|
||||
PyErr_SetNone(error);
|
||||
Py_DECREF(error);
|
||||
/* Wrap any non-InvalidBSON errors in InvalidBSON. */
|
||||
if (PyErr_Occurred()) {
|
||||
/* Calling _error clears the error state, so fetch it first. */
|
||||
PyObject *etype, *evalue, *etrace, *InvalidBSON;
|
||||
PyErr_Fetch(&etype, &evalue, &etrace);
|
||||
InvalidBSON = _error("InvalidBSON");
|
||||
if (InvalidBSON) {
|
||||
if (!PyErr_GivenExceptionMatches(etype, InvalidBSON)) {
|
||||
/* Raise InvalidBSON(str(e)). */
|
||||
PyObject *msg = NULL;
|
||||
Py_DECREF(etype);
|
||||
etype = InvalidBSON;
|
||||
|
||||
if (evalue) {
|
||||
msg = PyObject_Str(evalue);
|
||||
Py_DECREF(evalue);
|
||||
evalue = msg;
|
||||
}
|
||||
PyErr_NormalizeException(&etype, &evalue, &etrace);
|
||||
Py_XDECREF(msg);
|
||||
}
|
||||
}
|
||||
/* Steals references to args. */
|
||||
PyErr_Restore(etype, evalue, etrace);
|
||||
Py_XDECREF(InvalidBSON);
|
||||
return NULL;
|
||||
} else {
|
||||
PyObject *InvalidBSON = _error("InvalidBSON");
|
||||
if (InvalidBSON) {
|
||||
PyErr_SetNone(InvalidBSON);
|
||||
Py_DECREF(InvalidBSON);
|
||||
}
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@ -16,10 +16,11 @@
|
||||
|
||||
"""Test the bson module."""
|
||||
|
||||
import unittest
|
||||
import datetime
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import unittest
|
||||
try:
|
||||
import uuid
|
||||
should_test_uuid = True
|
||||
@ -41,7 +42,8 @@ from bson.py3compat import b
|
||||
from bson.son import SON
|
||||
from bson.timestamp import Timestamp
|
||||
from bson.errors import (InvalidDocument,
|
||||
InvalidStringData)
|
||||
InvalidStringData,
|
||||
InvalidBSON)
|
||||
from bson.max_key import MaxKey
|
||||
from bson.min_key import MinKey
|
||||
from bson.tz_util import (FixedOffset,
|
||||
@ -448,5 +450,32 @@ class TestBSON(unittest.TestCase):
|
||||
d = OrderedDict([("one", 1), ("two", 2), ("three", 3), ("four", 4)])
|
||||
self.assertEqual(d, BSON.encode(d).decode(as_class=OrderedDict))
|
||||
|
||||
def test_exception_wrapping(self):
|
||||
# No matter what exception is raised while trying to decode BSON,
|
||||
# the final exception always matches InvalidBSON and the original
|
||||
# traceback is preserved.
|
||||
|
||||
# Invalid Python regex, though valid PCRE: {'r': /[\w-\.]/}
|
||||
# Will cause an error in re.compile().
|
||||
bad_doc = b('"\x00\x00\x00\x07_id\x00R\x013\xd4S1\xe3\xd3\xd6Sgs'
|
||||
'\x0br\x00[\\w-\\.]\x00\x00\x00')
|
||||
|
||||
try:
|
||||
decode_all(bad_doc)
|
||||
except InvalidBSON:
|
||||
exc_type, exc_value, exc_tb = sys.exc_info()
|
||||
# Original re error was captured and wrapped in InvalidBSON.
|
||||
self.assertEqual(exc_value.args[0], 'bad character range')
|
||||
|
||||
# Traceback includes bson module's call into re module.
|
||||
for filename, lineno, fname, text in traceback.extract_tb(exc_tb):
|
||||
if filename.endswith('re.py') and fname == 'compile':
|
||||
# Traceback was correctly preserved.
|
||||
break
|
||||
else:
|
||||
self.fail('Traceback not captured')
|
||||
else:
|
||||
self.fail('InvalidBSON not raised')
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user