diff --git a/son_manipulator.py b/son_manipulator.py new file mode 100644 index 000000000..f5f3aa6c1 --- /dev/null +++ b/son_manipulator.py @@ -0,0 +1,58 @@ +"""Manipulators that can edit SON objects as the enter and exit a database. + +New manipulators should be defined as subclasses of SONManipulator and can be +installed on a database by calling `Database.add_son_manipulator`.""" + +from objectid import ObjectId + +class SONManipulator(object): + """A base son manipulator. + + This manipulator just saves and restores objects without changing them. + """ + def __init__(self, database): + """Instantiate the manager. + + Arguments: + - `database`: a Mongo Database + """ + self.__database = database + + def transform_incoming(self, son, collection): + """Manipulate an incoming son object. + + Arguments: + - `son`: the son object to be inserted into the database + - `collection`: the collection the object is being inserted into + """ + return son + + def transform_outgoing(self, son, collection): + """Manipulate an outgoing son object. + + Arguments: + - `son`: the son object being retrieved from the database + - `collection`: the collection this object was stored in + """ + return son + +class ObjectIdInjector(SONManipulator): + """A son manipulator that adds the _id field if it is missing. + """ + def transform_incoming(self, son, collection): + """Add an _id field if it is missing. + """ + if "_id" in son: + assert isinstance(son["_id"], ObjectId), "'_id' must be an ObjectId" + else: + son["_id"] = ObjectId() + return son + +class NamespaceInjector(SONManipulator): + """A son manipulator that adds the _ns field. + """ + def transform_incoming(self, son, collection): + """Add the _ns field to the incoming object + """ + son["_ns"] = collection.name() + return son diff --git a/test/test_son_manipulator.py b/test/test_son_manipulator.py new file mode 100644 index 000000000..524ff2557 --- /dev/null +++ b/test/test_son_manipulator.py @@ -0,0 +1,64 @@ +"""Tests for SONManipulators. +""" + +import unittest + +import qcheck +from objectid import ObjectId +from son_manipulator import SONManipulator, ObjectIdInjector, NamespaceInjector +from database import Database +from test_connection import get_connection + +class TestSONManipulator(unittest.TestCase): + def setUp(self): + self.db = Database(get_connection(), "test") + + def test_basic(self): + manip = SONManipulator(self.db) + collection = self.db.test + + def incoming_is_identity(son): + return son == manip.transform_incoming(son, collection) + qcheck.check_unittest(self, incoming_is_identity, + qcheck.gen_mongo_dict(3)) + + def outgoing_is_identity(son): + return son == manip.transform_outgoing(son, collection) + qcheck.check_unittest(self, outgoing_is_identity, + qcheck.gen_mongo_dict(3)) + + def test_id_injection(self): + manip = ObjectIdInjector(self.db) + collection = self.db.test + + def incoming_adds_id(son): + son = manip.transform_incoming(son, collection) + assert "_id" in son + assert isinstance(son["_id"], ObjectId) + return True + qcheck.check_unittest(self, incoming_adds_id, + qcheck.gen_mongo_dict(3)) + + def outgoing_is_identity(son): + return son == manip.transform_outgoing(son, collection) + qcheck.check_unittest(self, outgoing_is_identity, + qcheck.gen_mongo_dict(3)) + + def test_ns_injection(self): + manip = NamespaceInjector(self.db) + collection = self.db.test + + def incoming_adds_ns(son): + son = manip.transform_incoming(son, collection) + assert "_ns" in son + return son["_ns"] == collection.name() + qcheck.check_unittest(self, incoming_adds_ns, + qcheck.gen_mongo_dict(3)) + + def outgoing_is_identity(son): + return son == manip.transform_outgoing(son, collection) + qcheck.check_unittest(self, outgoing_is_identity, + qcheck.gen_mongo_dict(3)) + +if __name__ == "__main__": + unittest.main()