From 0fb2fcfac0b49411cb7749bbf74afccf71fc53d8 Mon Sep 17 00:00:00 2001 From: Bernie Hackett Date: Wed, 24 Sep 2014 09:50:31 -0700 Subject: [PATCH] PYTHON-700 - Support subclassing of son manipulators --- pymongo/database.py | 6 ++++-- test/test_database.py | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/pymongo/database.py b/pymongo/database.py index d0316bbf1..ea74d1396 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -29,6 +29,7 @@ from pymongo.errors import (CollectionInvalid, from pymongo.read_preferences import (modes, secondary_ok_commands, ReadPreference) +from pymongo.son_manipulator import SONManipulator def _check_name(name): @@ -94,9 +95,10 @@ class Database(common.BaseObject): :Parameters: - `manipulator`: the manipulator to add """ + base = SONManipulator() def method_overwritten(instance, method): - return getattr(instance, method) != \ - getattr(super(instance.__class__, instance), method) + return (getattr( + instance, method).im_func != getattr(base, method).im_func) if manipulator.will_copy(): if method_overwritten(manipulator, "transform_incoming"): diff --git a/test/test_database.py b/test/test_database.py index aa3e00dce..6362a5b32 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -707,6 +707,30 @@ class TestDatabase(unittest.TestCase): out = db.test.find_one() self.assertEqual('value', out.get('value')) + def test_son_manipulator_inheritance(self): + class Thing(object): + def __init__(self, value): + self.value = value + + class ThingTransformer(SONManipulator): + def transform_incoming(self, thing, collection): + return {'value': thing.value} + + def transform_outgoing(self, son, collection): + return Thing(son['value']) + + class Child(ThingTransformer): + pass + + db = self.client.foo + db.add_son_manipulator(Child()) + t = Thing('value') + + db.test.remove() + db.test.insert([t]) + out = db.test.find_one() + self.assertTrue(isinstance(out, Thing)) + self.assertEqual('value', out.value) if __name__ == "__main__": unittest.main()