Fix pickling/unpickling of DBRef PYTHON-231

This commit is contained in:
behackett 2011-04-05 09:46:21 -07:00
parent 9f9b4a6666
commit 1d1d8729fd
2 changed files with 17 additions and 4 deletions

View File

@ -79,7 +79,16 @@ class DBRef(object):
return self.__database
def __getattr__(self, key):
return self.__kwargs[key]
try:
return self.__kwargs[key]
except KeyError:
raise AttributeError(key)
# Have to provide __setstate__ to avoid
# infinite recursion since we override
# __getattr__.
def __setstate__(self, state):
self.__dict__.update(state)
def as_doc(self):
"""Get the SON document representation of this DBRef.

View File

@ -14,6 +14,7 @@
"""Tests for the dbref module."""
import pickle
import unittest
import sys
sys.path[0:0] = [""]
@ -90,7 +91,7 @@ class TestDBRef(unittest.TestCase):
self.assertNotEqual(DBRef("coll", 5, foo="bar"), DBRef("coll", 5))
self.assertNotEqual(DBRef("coll", 5, foo="bar"), DBRef("coll", 5, foo="baz"))
self.assertEqual("bar", DBRef("coll", 5, foo="bar").foo)
self.assertRaises(KeyError, getattr, DBRef("coll", 5, foo="bar"), "bar")
self.assertRaises(AttributeError, getattr, DBRef("coll", 5, foo="bar"), "bar")
def test_deepcopy(self):
a = DBRef('coll', 'asdf', 'db', x=[1])
@ -105,8 +106,11 @@ class TestDBRef(unittest.TestCase):
self.assertEqual(a.x, [1])
self.assertEqual(b.x, [2])
def test_pickling(self):
dbr = DBRef('coll', 5, foo='bar')
pkl = pickle.dumps(dbr)
dbr2 = pickle.loads(pkl)
self.assertEqual(dbr, dbr2)
if __name__ == "__main__":
unittest.main()