Revert "Capture all BSON decode errors and wrap with InvalidBSON. PYTHON-494"

This reverts commit ba66a2dde7.
This commit is contained in:
A. Jesse Jiryu Davis 2013-08-17 22:35:17 -04:00
parent 9d9ac1c1da
commit d133396162
3 changed files with 61 additions and 107 deletions

View File

@ -484,6 +484,7 @@ if _use_c:
_dict_to_bson = _cbson._dict_to_bson
def decode_all(data, as_class=dict,
tz_aware=True, uuid_subtype=OLD_UUID_SUBTYPE):
"""Decode BSON data to multiple documents.
@ -503,25 +504,17 @@ def decode_all(data, as_class=dict,
docs = []
position = 0
end = len(data) - 1
try:
while position < end:
obj_size = struct.unpack("<i", data[position:position + 4])[0]
if len(data) - position < obj_size:
raise InvalidBSON("objsize too large")
if data[position + obj_size - 1:position + obj_size] != ZERO:
raise InvalidBSON("bad eoo")
elements = data[position + 4:position + obj_size - 1]
position += obj_size
docs.append(_elements_to_dict(elements, as_class,
tz_aware, uuid_subtype))
return docs
except InvalidBSON:
raise
except Exception:
# Change exception type to InvalidBSON but preserve traceback.
exc_type, exc_value, exc_tb = sys.exc_info()
raise InvalidBSON, str(exc_value), exc_tb
while position < end:
obj_size = struct.unpack("<i", data[position:position + 4])[0]
if len(data) - position < obj_size:
raise InvalidBSON("objsize too large")
if data[position + obj_size - 1:position + obj_size] != ZERO:
raise InvalidBSON("bad eoo")
elements = data[position + 4:position + obj_size - 1]
position += obj_size
docs.append(_elements_to_dict(elements, as_class,
tz_aware, uuid_subtype))
return docs
if _use_c:
decode_all = _cbson.decode_all

View File

@ -1336,7 +1336,8 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
unsigned char tz_aware, unsigned char uuid_subtype) {
struct module_state *state = GETSTATE(self);
PyObject* value = NULL;
PyObject* value;
PyObject* error;
switch (type) {
case 1:
{
@ -1346,6 +1347,9 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
}
memcpy(&d, buffer + *position, 8);
value = PyFloat_FromDouble(d);
if (!value) {
return NULL;
}
*position += 8;
break;
}
@ -1358,6 +1362,9 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
}
*position += 4;
value = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict");
if (!value) {
return NULL;
}
*position += value_length + 1;
break;
}
@ -1371,10 +1378,10 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
}
value = elements_to_dict(self, buffer + *position + 4,
size - 5, as_class, tz_aware, uuid_subtype);
if (!value) {
goto invalid;
return NULL;
}
/* Decoding for DBRefs */
collection = PyDict_GetItemString(value, "$ref");
if (collection) { /* DBRef */
@ -1410,6 +1417,9 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
Py_DECREF(id);
Py_DECREF(collection);
Py_DECREF(database);
if (!value) {
return NULL;
}
}
*position += size;
@ -1429,7 +1439,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
value = PyList_New(0);
if (!value) {
goto invalid;
return NULL;
}
while (*position < end) {
PyObject* to_append;
@ -1446,7 +1456,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
max - (int)key_size, as_class, tz_aware, uuid_subtype);
if (!to_append) {
Py_DECREF(value);
goto invalid;
return NULL;
}
PyList_Append(value, to_append);
Py_DECREF(to_append);
@ -1485,20 +1495,20 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
}
#endif
if (!data) {
goto invalid;
return NULL;
}
if ((subtype == 3 || subtype == 4) && state->UUID) { // Encode as UUID, not Binary
PyObject* kwargs;
PyObject* args = PyTuple_New(0);
if (!args) {
Py_DECREF(data);
goto invalid;
return NULL;
}
kwargs = PyDict_New();
if (!kwargs) {
Py_DECREF(data);
Py_DECREF(args);
goto invalid;
return NULL;
}
assert(length == 16); // UUID should always be 16 bytes
@ -1532,6 +1542,10 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
Py_DECREF(args);
Py_DECREF(kwargs);
Py_DECREF(data);
if (!value) {
return NULL;
}
*position += length + 5;
break;
@ -1554,6 +1568,9 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
value = PyObject_CallFunctionObjArgs(state->Binary, data, st, NULL);
Py_DECREF(st);
Py_DECREF(data);
if (!value) {
return NULL;
}
*position += length + 5;
break;
}
@ -1574,6 +1591,9 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
#else
value = PyObject_CallFunction(state->ObjectId, "s#", buffer + *position, 12);
#endif
if (!value) {
return NULL;
}
*position += 12;
break;
}
@ -1600,29 +1620,29 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
}
if (!naive) {
goto invalid;
return NULL;
}
replace = PyObject_GetAttrString(naive, "replace");
Py_DECREF(naive);
if (!replace) {
goto invalid;
return NULL;
}
args = PyTuple_New(0);
if (!args) {
Py_DECREF(replace);
goto invalid;
return NULL;
}
kwargs = PyDict_New();
if (!kwargs) {
Py_DECREF(replace);
Py_DECREF(args);
goto invalid;
return NULL;
}
if (PyDict_SetItemString(kwargs, "tzinfo", state->UTC) == -1) {
Py_DECREF(replace);
Py_DECREF(args);
Py_DECREF(kwargs);
goto invalid;
return NULL;
}
value = PyObject_Call(replace, args, kwargs);
Py_DECREF(replace);
@ -1641,7 +1661,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
}
pattern = PyUnicode_DecodeUTF8(buffer + *position, pattern_length, "strict");
if (!pattern) {
goto invalid;
return NULL;
}
*position += (int)pattern_length + 1;
if ((flags_length = strlen(buffer + *position)) > BSON_MAX_SIZE) {
@ -1687,14 +1707,14 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
collection = PyUnicode_DecodeUTF8(buffer + *position,
coll_length, "strict");
if (!collection) {
goto invalid;
return NULL;
}
*position += (int)coll_length + 1;
id = PyObject_CallFunction(state->ObjectId, "s#", buffer + *position, 12);
if (!id) {
Py_DECREF(collection);
goto invalid;
return NULL;
}
*position += 12;
value = PyObject_CallFunctionObjArgs(state->DBRef, collection, id, NULL);
@ -1712,7 +1732,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
*position += 4;
code = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict");
if (!code) {
goto invalid;
return NULL;
}
*position += value_length + 1;
value = PyObject_CallFunctionObjArgs(state->Code, code, NULL, NULL);
@ -1733,7 +1753,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
}
code = PyUnicode_DecodeUTF8(buffer + *position, code_length, "strict");
if (!code) {
goto invalid;
return NULL;
}
*position += (int)code_length + 1;
@ -1742,7 +1762,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
(PyObject*)&PyDict_Type, tz_aware, uuid_subtype);
if (!scope) {
Py_DECREF(code);
goto invalid;
return NULL;
}
*position += scope_size;
@ -1764,7 +1784,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
value = PyInt_FromLong(i);
#endif
if (!value) {
goto invalid;
return NULL;
}
*position += 4;
break;
@ -1779,7 +1799,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
memcpy(&time, buffer + *position + 4, 4);
value = PyObject_CallFunction(state->Timestamp, "II", time, inc);
if (!value) {
goto invalid;
return NULL;
}
*position += 8;
break;
@ -1793,7 +1813,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
memcpy(&ll, buffer + *position, 8);
value = PyLong_FromLongLong(ll);
if (!value) {
goto invalid;
return NULL;
}
*position += 8;
break;
@ -1819,44 +1839,14 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
return NULL;
}
}
if (value) {
return value;
}
return value;
invalid:
/* Wrap any non-InvalidBSON errors in InvalidBSON. */
if (PyErr_Occurred()) {
/* Calling _error clears the error state, so fetch it first. */
PyObject *etype, *evalue, *etrace, *InvalidBSON;
PyErr_Fetch(&etype, &evalue, &etrace);
InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) {
if (!PyErr_GivenExceptionMatches(etype, InvalidBSON)) {
/* Raise InvalidBSON(str(e)). */
PyObject *msg = NULL;
Py_DECREF(etype);
etype = InvalidBSON;
if (evalue) {
msg = PyObject_Str(evalue);
Py_DECREF(evalue);
evalue = msg;
}
PyErr_NormalizeException(&etype, &evalue, &etrace);
}
}
/* Steals references to args. */
PyErr_Restore(etype, evalue, etrace);
Py_XDECREF(InvalidBSON);
return NULL;
} else {
PyObject *InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) {
PyErr_SetNone(InvalidBSON);
Py_DECREF(InvalidBSON);
}
error = _error("InvalidBSON");
if (error) {
PyErr_SetNone(error);
Py_DECREF(error);
}
return NULL;
}

View File

@ -16,11 +16,10 @@
"""Test the bson module."""
import unittest
import datetime
import re
import sys
import traceback
import unittest
try:
import uuid
should_test_uuid = True
@ -42,8 +41,7 @@ from bson.py3compat import b
from bson.son import SON
from bson.timestamp import Timestamp
from bson.errors import (InvalidDocument,
InvalidStringData,
InvalidBSON)
InvalidStringData)
from bson.max_key import MaxKey
from bson.min_key import MinKey
from bson.tz_util import (FixedOffset,
@ -450,32 +448,5 @@ class TestBSON(unittest.TestCase):
d = OrderedDict([("one", 1), ("two", 2), ("three", 3), ("four", 4)])
self.assertEqual(d, BSON.encode(d).decode(as_class=OrderedDict))
def test_exception_wrapping(self):
# No matter what exception is raised while trying to decode BSON,
# the final exception always matches InvalidBSON and the original
# traceback is preserved.
# Invalid Python regex, though valid PCRE: {'r': /[\w-\.]/}
# Will cause an error in re.compile().
bad_doc = b('"\x00\x00\x00\x07_id\x00R\x013\xd4S1\xe3\xd3\xd6Sgs'
'\x0br\x00[\\w-\\.]\x00\x00\x00')
try:
decode_all(bad_doc)
except InvalidBSON:
exc_type, exc_value, exc_tb = sys.exc_info()
# Original re error was captured and wrapped in InvalidBSON.
self.assertEqual(exc_value.args[0], 'bad character range')
# Traceback includes bson module's call into re module.
for filename, lineno, fname, text in traceback.extract_tb(exc_tb):
if filename.endswith('re.py') and fname == 'compile':
# Traceback was correctly preserved.
break
else:
self.fail('Traceback not captured')
else:
self.fail('InvalidBSON not raised')
if __name__ == "__main__":
unittest.main()