diff --git a/bson/son.py b/bson/son.py index efbac6b01..55bc8a59c 100644 --- a/bson/son.py +++ b/bson/son.py @@ -72,6 +72,11 @@ class SON(dict): self.update(data) self.update(kwargs) + def __new__(cls, *args, **kwargs): + instance = super(SON, cls).__new__(cls, *args, **kwargs) + instance.__keys = [] + return instance + def __repr__(self): result = [] for key in self.__keys: @@ -214,3 +219,4 @@ class SON(dict): for k, v in self.iteritems(): out[k] = copy.deepcopy(v, memo) return out + diff --git a/test/test_dbref.py b/test/test_dbref.py index 249a41290..203c059d2 100644 --- a/test/test_dbref.py +++ b/test/test_dbref.py @@ -125,9 +125,10 @@ class TestDBRef(unittest.TestCase): def test_pickling(self): dbr = DBRef('coll', 5, foo='bar') - pkl = pickle.dumps(dbr) - dbr2 = pickle.loads(pkl) - self.assertEqual(dbr, dbr2) + for protocol in [0, 1, 2, -1]: + pkl = pickle.dumps(dbr, protocol=protocol) + dbr2 = pickle.loads(pkl) + self.assertEqual(dbr, dbr2) def test_dbref_hash(self): dbref_1a = DBRef('collection', 'id', 'database') diff --git a/test/test_objectid.py b/test/test_objectid.py index 8c48c36e2..7cccc8000 100644 --- a/test/test_objectid.py +++ b/test/test_objectid.py @@ -133,7 +133,9 @@ class TestObjectId(unittest.TestCase): def test_pickling(self): orig = ObjectId() - self.assertEqual(orig, pickle.loads(pickle.dumps(orig))) + for protocol in [0, 1, 2, -1]: + pkl = pickle.dumps(orig, protocol=protocol) + self.assertEqual(orig, pickle.loads(pkl)) def test_pickle_backwards_compatability(self): diff --git a/test/test_son.py b/test/test_son.py index 5a7924c7c..4d3ffaa5e 100644 --- a/test/test_son.py +++ b/test/test_son.py @@ -16,6 +16,7 @@ import unittest import sys +import pickle sys.path[0:0] = [""] from bson.son import SON @@ -53,6 +54,27 @@ class TestSON(unittest.TestCase): self.assertEqual(dict, c.to_dict()["blah"][0].__class__) self.assertEqual(dict, d.to_dict()["blah"]["foo"].__class__) + def test_pickle(self): + + simple_son = SON([]) + complex_son = SON([('son', simple_son), ('list', [simple_son, simple_son])]) + + for protocol in [0, 1, 2, -1]: + pickled = pickle.loads(pickle.dumps(complex_son, protocol=protocol)) + self.assertEquals(pickled['son'], pickled['list'][0]) + self.assertEquals(pickled['son'], pickled['list'][1]) + + def test_pickle_backwards_compatability(self): + + # This string was generated by pickling a SON object in pymongo + # version 2.1.1 + pickled_with_2_1_1 = ( + "ccopy_reg\n_reconstructor\np0\n(cbson.son\nSON\np1\n" + "c__builtin__\ndict\np2\n(dp3\ntp4\nRp5\n(dp6\n" + "S'_SON__keys'\np7\n(lp8\nsb." + ) + son_2_1_1 = pickle.loads(pickled_with_2_1_1) + self.assertEqual(son_2_1_1, SON([])) if __name__ == "__main__": unittest.main() diff --git a/test/test_timestamp.py b/test/test_timestamp.py index b7002d9a8..81e8a7140 100644 --- a/test/test_timestamp.py +++ b/test/test_timestamp.py @@ -49,8 +49,10 @@ class TestTimestamp(unittest.TestCase): dc = copy.deepcopy(d) self.assertEqual(dc, t.as_datetime()) - dp = pickle.loads(pickle.dumps(d)) - self.assertEqual(dp, t.as_datetime()) + for protocol in [0, 1, 2, -1]: + pkl = pickle.dumps(d, protocol=protocol) + dp = pickle.loads(pkl) + self.assertEqual(dp, t.as_datetime()) def test_exceptions(self): self.assertRaises(TypeError, Timestamp)