diff --git a/mongo.py b/mongo.py index 8e07eb0e9..1909ae843 100644 --- a/mongo.py +++ b/mongo.py @@ -197,6 +197,31 @@ class Mongo(object): raise TypeError("cannot dereference a %s" % type(dbref)) return self[dbref.collection()].find_one(dbref.id()) + def _fix(self, son): + """Fixes an object coming out of the database. + + Used to do things like auto dereferencing, if the option is enabled. + + Arguments: + - `son`: a SON object coming out of the database + """ + if not self.__auto_dereference: + return son + + def fix_value(value): + if isinstance(value, DBRef): + return self._fix(self.dereference(value)) + elif isinstance(value, (SON, types.DictType)): + return self._fix(value) + elif isinstance(value, types.ListType): + return [fix_value(v) for v in value] + return value + + for (key, value) in son.items(): + son[key] = fix_value(value) + + return son + class Collection(object): """A Mongo collection. """ @@ -644,9 +669,9 @@ class Cursor(object): def next(self): if len(self.__data): - return self.__data.pop(0) + return self.__collection.database()._fix(self.__data.pop(0)) if self._refresh(): - return self.__data.pop(0) + return self.__collection.database()._fix(self.__data.pop(0)) raise StopIteration class TestMongo(unittest.TestCase): @@ -1007,5 +1032,28 @@ class TestMongo(unittest.TestCase): key = db.test.save(obj) self.assertEqual(obj, db.dereference(DBRef("test", key))) + def test_auto_deref(self): + db = Mongo("test", self.host, self.port) + db.test.remove({}) + db.mike.remove({}) + + a = {"hello": u"world"} + key = db.mike.save(a) + dbref = DBRef("mike", key) + + self.assertEqual(db.dereference(dbref), a) + + b = {"mike_obj": dbref} + db.test.save(b) + self.assertEqual(dbref, db.test.find_one()["mike_obj"]) + self.assertEqual(a, db.dereference(db.test.find_one()["mike_obj"])) + + db = Mongo("test", self.host, self.port, {"auto_dereference": False}) + self.assertEqual(dbref, db.test.find_one()["mike_obj"]) + + db = Mongo("test", self.host, self.port, {"auto_dereference": True}) + self.assertNotEqual(dbref, db.test.find_one()["mike_obj"]) + self.assertEqual(a, db.test.find_one()["mike_obj"]) + if __name__ == "__main__": unittest.main()