PYTHON-1044 - Fix up unknown BSON type handing

This commit is contained in:
Bernie Hackett 2016-01-29 13:40:03 -08:00
parent 0a5ef8de6e
commit 8b7a13629b
3 changed files with 146 additions and 50 deletions

View File

@ -93,7 +93,15 @@ BSONMAX = b("\x7F") # Max key
_CODEC_OPTIONS_TYPE_ERROR = TypeError(
"codec_options must be an instance of bson.codec_options.CodecOptions")
def _get_int(data, position, as_class=None,
def _raise_unknown_type(element_type, element_name):
"""Unknown type helper."""
raise InvalidBSON("Detected unknown BSON type %r for fieldname %r. Are "
"you using the latest driver version?" % (
element_type, element_name))
def _get_int(data, position, name, as_class=None,
tz_aware=False, uuid_subtype=OLD_UUID_SUBTYPE,
compile_re=True, unsigned=False):
format = unsigned and "I" or "i"
@ -137,13 +145,15 @@ def _make_c_string(string, check_null=False):
"UTF-8: %r" % string)
def _get_number(data, position, as_class, tz_aware, uuid_subtype, compile_re):
def _get_number(
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
num = struct.unpack("<d", data[position:position + 8])[0]
position += 8
return num, position
def _get_string(data, position, as_class, tz_aware, uuid_subtype, compile_re):
def _get_string(
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
length = struct.unpack("<i", data[position:position + 4])[0]
if length <= 0 or (len(data) - position - 4) < length:
raise InvalidBSON("invalid string length")
@ -153,7 +163,8 @@ def _get_string(data, position, as_class, tz_aware, uuid_subtype, compile_re):
return _get_c_string(data, position, length - 1)
def _get_object(data, position, as_class, tz_aware, uuid_subtype, compile_re):
def _get_object(
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
obj_size = struct.unpack("<i", data[position:position + 4])[0]
if data[position + obj_size - 1:position + obj_size] != ZERO:
raise InvalidBSON("bad eoo")
@ -168,26 +179,43 @@ def _get_object(data, position, as_class, tz_aware, uuid_subtype, compile_re):
return object, position
def _get_array(data, position, as_class, tz_aware, uuid_subtype, compile_re):
obj, position = _get_object(data, position,
as_class, tz_aware, uuid_subtype, compile_re)
def _get_array(
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
size = struct.unpack("<i", data[position:position + 4])[0]
end = position + size - 1
if data[end:end + 1] != ZERO:
raise InvalidBSON("bad eoo")
position += 4
end -= 1
result = []
i = 0
while True:
# Avoid doing global and attibute lookups in the loop.
append = result.append
index = data.index
getter = _element_getter
while position < end:
element_type = data[position:position + 1]
# Just skip the keys.
position = index(ZERO, position) + 1
try:
result.append(obj[str(i)])
i += 1
value, position = getter[element_type](
data, position, name,
as_class, tz_aware, uuid_subtype, compile_re)
except KeyError:
break
return result, position
_raise_unknown_type(element_type, name)
append(value)
return result, position + 1
def _get_binary(data, position, as_class, tz_aware, uuid_subtype, compile_re):
length, position = _get_int(data, position)
def _get_binary(
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
length, position = _get_int(data, position, name)
subtype = ord(data[position:position + 1])
position += 1
if subtype == 2:
length2, position = _get_int(data, position)
length2, position = _get_int(data, position, name)
if length2 != length - 4:
raise InvalidBSON("invalid binary (st 2) - lengths don't match!")
length = length2
@ -213,20 +241,22 @@ def _get_binary(data, position, as_class, tz_aware, uuid_subtype, compile_re):
return value, position
def _get_oid(data, position, as_class=None,
def _get_oid(data, position, name, as_class=None,
tz_aware=False, uuid_subtype=OLD_UUID_SUBTYPE, compile_re=True):
value = ObjectId(data[position:position + 12])
position += 12
return value, position
def _get_boolean(data, position, as_class, tz_aware, uuid_subtype, compile_re):
def _get_boolean(
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
value = data[position:position + 1] == ONE
position += 1
return value, position
def _get_date(data, position, as_class, tz_aware, uuid_subtype, compile_re):
def _get_date(
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
millis = struct.unpack("<q", data[position:position + 8])[0]
diff = millis % 1000
seconds = (millis - diff) / 1000
@ -238,27 +268,30 @@ def _get_date(data, position, as_class, tz_aware, uuid_subtype, compile_re):
return dt.replace(microsecond=diff * 1000), position
def _get_code(data, position, as_class, tz_aware, uuid_subtype, compile_re):
code, position = _get_string(data, position,
def _get_code(
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
code, position = _get_string(data, position, name,
as_class, tz_aware, uuid_subtype, compile_re)
return Code(code), position
def _get_code_w_scope(
data, position, as_class, tz_aware, uuid_subtype, compile_re):
_, position = _get_int(data, position)
code, position = _get_string(data, position,
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
_, position = _get_int(data, position, name)
code, position = _get_string(data, position, name,
as_class, tz_aware, uuid_subtype, compile_re)
scope, position = _get_object(data, position,
scope, position = _get_object(data, position, name,
as_class, tz_aware, uuid_subtype, compile_re)
return Code(code, scope), position
def _get_null(data, position, as_class, tz_aware, uuid_subtype, compile_re):
def _get_null(
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
return None, position
def _get_regex(data, position, as_class, tz_aware, uuid_subtype, compile_re):
def _get_regex(
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
pattern, position = _get_c_string(data, position)
bson_flags, position = _get_c_string(data, position)
bson_re = Regex(pattern, bson_flags)
@ -268,21 +301,23 @@ def _get_regex(data, position, as_class, tz_aware, uuid_subtype, compile_re):
return bson_re, position
def _get_ref(data, position, as_class, tz_aware, uuid_subtype, compile_re):
collection, position = _get_string(data, position, as_class, tz_aware,
uuid_subtype, compile_re)
oid, position = _get_oid(data, position)
def _get_ref(
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
collection, position = _get_string(
data, position, name, as_class, tz_aware, uuid_subtype, compile_re)
oid, position = _get_oid(data, position, name)
return DBRef(collection, oid), position
def _get_timestamp(
data, position, as_class, tz_aware, uuid_subtype, compile_re):
inc, position = _get_int(data, position, unsigned=True)
timestamp, position = _get_int(data, position, unsigned=True)
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
inc, position = _get_int(data, position, name, unsigned=True)
timestamp, position = _get_int(data, position, name, unsigned=True)
return Timestamp(timestamp, inc), position
def _get_long(data, position, as_class, tz_aware, uuid_subtype, compile_re):
def _get_long(
data, position, name, as_class, tz_aware, uuid_subtype, compile_re):
# Have to cast to long; on 32-bit unpack may return an int.
# 2to3 will change long to int. That's fine since long doesn't
# exist in python3.
@ -310,8 +345,8 @@ _element_getter = {
BSONINT: _get_int, # number_int
BSONTIM: _get_timestamp,
BSONLON: _get_long, # Same as _get_int after 2to3 runs.
BSONMIN: lambda u, v, w, x, y, z: (MinKey(), v),
BSONMAX: lambda u, v, w, x, y, z: (MaxKey(), v)}
BSONMIN: lambda t, u, v, w, x, y, z: (MinKey(), u),
BSONMAX: lambda t, u, v, w, x, y, z: (MaxKey(), u)}
def _element_to_dict(
@ -319,8 +354,12 @@ def _element_to_dict(
element_type = data[position:position + 1]
position += 1
element_name, position = _get_c_string(data, position)
value, position = _element_getter[element_type](
data, position, as_class, tz_aware, uuid_subtype, compile_re)
try:
func = _element_getter[element_type]
except KeyError:
_raise_unknown_type(element_type, element_name)
value, position = func(data, position, element_name,
as_class, tz_aware, uuid_subtype, compile_re)
return element_name, value, position

View File

@ -1427,10 +1427,10 @@ static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* args) {
return result;
}
static PyObject* get_value(PyObject* self, const char* buffer, unsigned* position,
unsigned char type, unsigned max, PyObject* as_class,
unsigned char tz_aware, unsigned char uuid_subtype,
unsigned char compile_re) {
static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
unsigned* position, unsigned char type, unsigned max,
PyObject* as_class, unsigned char tz_aware,
unsigned char uuid_subtype, unsigned char compile_re) {
struct module_state *state = GETSTATE(self);
PyObject* value = NULL;
@ -1574,7 +1574,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
Py_DECREF(value);
goto invalid;
}
to_append = get_value(self, buffer, position, bson_type,
to_append = get_value(self, name, buffer, position, bson_type,
max - (unsigned)key_size,
as_class, tz_aware, uuid_subtype,
compile_re);
@ -2078,11 +2078,50 @@ static PyObject* get_value(PyObject* self, const char* buffer, unsigned* positio
}
default:
{
PyObject* InvalidDocument = _error("InvalidDocument");
if (InvalidDocument) {
PyErr_SetString(InvalidDocument,
"no c decoder for this type yet");
Py_DECREF(InvalidDocument);
PyObject* InvalidBSON = _error("InvalidBSON");
if (InvalidBSON) {
#if PY_MAJOR_VERSION >= 3
PyObject* type_obj = PyBytes_FromFormat("%c", type);
#else
PyObject* type_obj = PyString_FromFormat("%c", type);
#endif
if (type_obj) {
PyObject* type_repr = PyObject_Repr(type_obj);
Py_DECREF(type_obj);
if (type_repr) {
PyObject* errmsg = NULL;
#if PY_MAJOR_VERSION >= 3
PyObject* left = PyUnicode_FromString(
"Detected unknown BSON type ");
if (left) {
PyObject* lmsg = PyUnicode_Concat(left, type_repr);
Py_DECREF(left);
if (lmsg) {
errmsg = PyUnicode_FromFormat(
"%U for fieldname '%U'. Are you using the "
"latest driver version?", lmsg, name);
Py_DECREF(lmsg);
}
}
#else
PyObject* name_repr = PyObject_Repr(name);
if (name_repr) {
errmsg = PyString_FromFormat(
"Detected unknown BSON type %s for fieldname %s."
" Are you using the latest driver version?",
PyString_AS_STRING(type_repr),
PyString_AS_STRING(name_repr));
Py_DECREF(name_repr);
}
#endif
Py_DECREF(type_repr);
if (errmsg) {
PyErr_SetObject(InvalidBSON, errmsg);
Py_DECREF(errmsg);
}
}
}
Py_DECREF(InvalidBSON);
}
goto invalid;
}
@ -2173,7 +2212,7 @@ static PyObject* _elements_to_dict(PyObject* self, const char* string,
return NULL;
}
position += (unsigned)name_length + 1;
value = get_value(self, string, &position, type,
value = get_value(self, name, string, &position, type,
max - position, as_class, tz_aware, uuid_subtype,
compile_re);
if (!value) {

View File

@ -358,6 +358,24 @@ class TestBSON(unittest.TestCase):
qcheck.check_unittest(self, encode_then_decode_backport_precedence,
qcheck.gen_mongo_dict(3))
def test_unknown_type(self):
# Repr value differs with major python version
part = "type %r for fieldname %r" % (b('\x13'), u"foo")
docs = [
b('\x0e\x00\x00\x00\x13foo\x00\x01\x00\x00\x00\x00'),
b('\x16\x00\x00\x00\x04foo\x00\x0c\x00\x00\x00\x130'
'\x00\x01\x00\x00\x00\x00\x00'),
b(' \x00\x00\x00\x04bar\x00\x16\x00\x00\x00\x030\x00\x0e\x00\x00'
'\x00\x13foo\x00\x01\x00\x00\x00\x00\x00\x00')]
for bs in docs:
try:
bson.BSON(bs).decode()
except Exception, exc:
self.assertTrue(isinstance(exc, InvalidBSON))
self.assertTrue(part in str(exc))
else:
self.fail("Failed to raise an exception.")
def test_dbpointer(self):
# *Note* - DBPointer and DBRef are *not* the same thing. DBPointer
# is a deprecated BSON type. DBRef is a convention that does not