From 828aad3aad35f60aea5f5a34050a808aadacc2be Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 8 Jan 2013 09:53:39 +0000 Subject: [PATCH] Ensure deepcopy clones using the correct type - PYTHON-459 --- bson/__init__.py | 5 +---- bson/son.py | 12 +++++++++++- pymongo/cursor.py | 2 +- test/test_cursor.py | 7 +++++++ test/test_son.py | 41 +++++++++++++++++++++++++++++++++++++++-- 5 files changed, 59 insertions(+), 8 deletions(-) diff --git a/bson/__init__.py b/bson/__init__.py index 02c8452bb..fe6bc4f74 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -32,7 +32,7 @@ from bson.max_key import MaxKey from bson.min_key import MinKey from bson.objectid import ObjectId from bson.py3compat import b, binary_type -from bson.son import SON +from bson.son import SON, RE_TYPE from bson.timestamp import Timestamp from bson.tz_util import utc @@ -52,9 +52,6 @@ except ImportError: PY3 = sys.version_info[0] == 3 -# This sort of sucks, but seems to be as good as it gets... -RE_TYPE = type(re.compile("")) - MAX_INT32 = 2147483647 MIN_INT32 = -2147483648 MAX_INT64 = 9223372036854775807 diff --git a/bson/son.py b/bson/son.py index cdf2b74c5..539d3752d 100644 --- a/bson/son.py +++ b/bson/son.py @@ -19,6 +19,10 @@ of keys is important. A SON object can be used just like a normal Python dictionary.""" import copy +import re + +# This sort of sucks, but seems to be as good as it gets... +RE_TYPE = type(re.compile("")) class SON(dict): @@ -227,6 +231,12 @@ class SON(dict): def __deepcopy__(self, memo): out = SON() + val_id = id(self) + if val_id in memo: + return memo.get(val_id) + memo[val_id] = out for k, v in self.iteritems(): - out[k] = copy.deepcopy(v, memo) + if not isinstance(v, RE_TYPE): + v = copy.deepcopy(v, memo) + out[k] = v return out diff --git a/pymongo/cursor.py b/pymongo/cursor.py index 421144efb..8616cea1c 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -854,7 +854,7 @@ class Cursor(object): return memo.get(val_id) memo[val_id] = y for key, value in x.iteritems(): - if isinstance(value, dict): + if isinstance(value, dict) and not isinstance(value, SON): value = self.__deepcopy(value, memo) elif not isinstance(value, RE_TYPE): value = copy.deepcopy(value, memo) diff --git a/test/test_cursor.py b/test/test_cursor.py index 64811657a..f4ab9fe7d 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -24,6 +24,7 @@ sys.path[0:0] = [""] from nose.plugins.skip import SkipTest from bson.code import Code +from bson.son import SON from pymongo import (ASCENDING, DESCENDING) from pymongo.database import Database @@ -522,6 +523,12 @@ class TestCursor(unittest.TestCase): id(cursor2._Cursor__spec)) self.assertEqual(len(cursor2._Cursor__spec), 2) + # Ensure hints are cloned as the correct type + cursor = self.db.test.find().hint([('z', 1), ("a", 1)]) + cursor2 = copy.deepcopy(cursor) + self.assertTrue(isinstance(cursor2._Cursor__hint, SON)) + self.assertEqual(cursor._Cursor__hint, cursor2._Cursor__hint) + def test_add_remove_option(self): cursor = self.db.test.find() self.assertEqual(0, cursor._Cursor__query_options()) diff --git a/test/test_son.py b/test/test_son.py index b846a790f..cb2596da9 100644 --- a/test/test_son.py +++ b/test/test_son.py @@ -14,9 +14,11 @@ """Tests for the son module.""" -import unittest -import sys +import copy import pickle +import re +import sys +import unittest sys.path[0:0] = [""] from nose.plugins.skip import SkipTest @@ -111,5 +113,40 @@ class TestSON(unittest.TestCase): son_2_1_1 = pickle.loads(pickled_with_2_1_1) self.assertEqual(son_2_1_1, SON([])) + def test_copying(self): + simple_son = SON([]) + complex_son = SON([('son', simple_son), + ('list', [simple_son, simple_son])]) + regex_son = SON([("x", re.compile("^hello.*"))]) + reflexive_son = SON([('son', simple_son)]) + reflexive_son["reflexive"] = reflexive_son + + simple_son1 = copy.copy(simple_son) + self.assertEqual(simple_son, simple_son1) + + complex_son1 = copy.copy(complex_son) + self.assertEqual(complex_son, complex_son1) + + regex_son1 = copy.copy(regex_son) + self.assertEqual(regex_son, regex_son1) + + reflexive_son1 = copy.copy(reflexive_son) + self.assertEqual(reflexive_son, reflexive_son1) + + # Test deepcopying + simple_son1 = copy.deepcopy(simple_son) + self.assertEqual(simple_son, simple_son1) + + regex_son1 = copy.deepcopy(regex_son) + self.assertEqual(regex_son, regex_son1) + + complex_son1 = copy.deepcopy(complex_son) + self.assertEqual(complex_son, complex_son1) + + reflexive_son1 = copy.deepcopy(reflexive_son) + self.assertEqual(reflexive_son.keys(), reflexive_son1.keys()) + self.assertEqual(id(reflexive_son1), id(reflexive_son1["reflexive"])) + + if __name__ == "__main__": unittest.main()