From 1d1d8729fd014689bbde5e2f0500dd419a971537 Mon Sep 17 00:00:00 2001 From: behackett Date: Tue, 5 Apr 2011 09:46:21 -0700 Subject: [PATCH] Fix pickling/unpickling of DBRef PYTHON-231 --- bson/dbref.py | 11 ++++++++++- test/test_dbref.py | 10 +++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/bson/dbref.py b/bson/dbref.py index 3b7bd0faf..3fb73ac9d 100644 --- a/bson/dbref.py +++ b/bson/dbref.py @@ -79,7 +79,16 @@ class DBRef(object): return self.__database def __getattr__(self, key): - return self.__kwargs[key] + try: + return self.__kwargs[key] + except KeyError: + raise AttributeError(key) + + # Have to provide __setstate__ to avoid + # infinite recursion since we override + # __getattr__. + def __setstate__(self, state): + self.__dict__.update(state) def as_doc(self): """Get the SON document representation of this DBRef. diff --git a/test/test_dbref.py b/test/test_dbref.py index 0c9ff6279..a6aad89a9 100644 --- a/test/test_dbref.py +++ b/test/test_dbref.py @@ -14,6 +14,7 @@ """Tests for the dbref module.""" +import pickle import unittest import sys sys.path[0:0] = [""] @@ -90,7 +91,7 @@ class TestDBRef(unittest.TestCase): self.assertNotEqual(DBRef("coll", 5, foo="bar"), DBRef("coll", 5)) self.assertNotEqual(DBRef("coll", 5, foo="bar"), DBRef("coll", 5, foo="baz")) self.assertEqual("bar", DBRef("coll", 5, foo="bar").foo) - self.assertRaises(KeyError, getattr, DBRef("coll", 5, foo="bar"), "bar") + self.assertRaises(AttributeError, getattr, DBRef("coll", 5, foo="bar"), "bar") def test_deepcopy(self): a = DBRef('coll', 'asdf', 'db', x=[1]) @@ -105,8 +106,11 @@ class TestDBRef(unittest.TestCase): self.assertEqual(a.x, [1]) self.assertEqual(b.x, [2]) - - + def test_pickling(self): + dbr = DBRef('coll', 5, foo='bar') + pkl = pickle.dumps(dbr) + dbr2 = pickle.loads(pkl) + self.assertEqual(dbr, dbr2) if __name__ == "__main__": unittest.main()