diff --git a/bson/__init__.py b/bson/__init__.py index fe6bc4f74..919772761 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -150,7 +150,7 @@ def _get_object(data, position, as_class, tz_aware, uuid_subtype): object = _elements_to_dict(encoded, as_class, tz_aware, uuid_subtype) position += obj_size if "$ref" in object: - return (DBRef(object.pop("$ref"), object.pop("$id"), + return (DBRef(object.pop("$ref"), object.pop("$id", None), object.pop("$db", None), object), position) return object, position diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index 0e50f1d33..19a56b093 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -1202,8 +1202,14 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position, Py_INCREF(collection); PyDict_DelItemString(value, "$ref"); - Py_INCREF(id); - PyDict_DelItemString(value, "$id"); + + if (id == NULL) { + id = Py_None; + Py_INCREF(id); + } else { + Py_INCREF(id); + PyDict_DelItemString(value, "$id"); + } if (database == NULL) { database = Py_None; diff --git a/test/test_collection.py b/test/test_collection.py index b54e57453..0a4da3446 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -30,6 +30,7 @@ sys.path[0:0] = [""] from bson.binary import Binary, UUIDLegacy, OLD_UUID_SUBTYPE, UUID_SUBTYPE from bson.code import Code +from bson.dbref import DBRef from bson.objectid import ObjectId from bson.py3compat import b from bson.son import SON @@ -1675,6 +1676,31 @@ class TestCollection(unittest.TestCase): self.assertRaises(InvalidDocument, c.save, {"x": c}) warnings.simplefilter("default") + def test_bad_dbref(self): + c = self.db.test + c.drop() + + # Incomplete DBRefs. + self.assertRaises( + InvalidDocument, + c.insert, {'ref': {'$ref': 'collection'}}) + + self.assertRaises( + InvalidDocument, + c.insert, {'ref': {'$id': ObjectId()}}) + + ref_only = {'ref': {'$ref': 'collection'}} + id_only = {'ref': {'$id': ObjectId()}} + + # Force insert of ref without $id. + c.insert(ref_only, check_keys=False) + self.assertEqual(DBRef('collection', id=None), c.find_one()['ref']) + c.drop() + + # DBRef without $ref is decoded as normal subdocument. + c.insert(id_only, check_keys=False) + self.assertEqual(id_only, c.find_one()) + def test_as_class(self): c = self.db.test c.drop()