From b4cb9be0d32e464db4843c0b54e081a3f2e77bdd Mon Sep 17 00:00:00 2001 From: Mike Dirolf Date: Tue, 10 Nov 2009 14:54:39 -0500 Subject: [PATCH] Add database support to DBRefs (as $db) --- pymongo/_cbsonmodule.c | 129 +++++++---------------------------- pymongo/bson.py | 8 +-- pymongo/database.py | 10 ++- pymongo/dbref.py | 56 +++++++++++---- pymongo/son.py | 23 ++++--- pymongo/son_manipulator.py | 4 ++ test/test_bson.py | 2 + test/test_database.py | 2 + test/test_dbref.py | 16 ++++- test/test_son_manipulator.py | 1 - 10 files changed, 110 insertions(+), 141 deletions(-) diff --git a/pymongo/_cbsonmodule.c b/pymongo/_cbsonmodule.c index d170e44e3..fa8cd28fb 100644 --- a/pymongo/_cbsonmodule.c +++ b/pymongo/_cbsonmodule.c @@ -602,86 +602,16 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject* } return 1; } else if (PyObject_IsInstance(value, DBRef)) { - int start_position, - length_location, - collection_length, - type_pos, - length; - PyObject* collection_object; - PyObject* encoded_collection; - PyObject* id_object; - char zero = 0; - + PyObject* as_doc = PyObject_CallMethod(value, "as_doc", NULL); + if (!as_doc) { + return 0; + } + if (!write_dict(buffer, as_doc, 0)) { + Py_DECREF(as_doc); + return 0; + } + Py_DECREF(as_doc); *(buffer->buffer + type_byte) = 0x03; - start_position = buffer->position; - - /* save space for length */ - length_location = buffer_save_bytes(buffer, 4); - if (length_location == -1) { - return 0; - } - - collection_object = PyObject_GetAttrString(value, "collection"); - if (!collection_object) { - return 0; - } - encoded_collection = PyUnicode_AsUTF8String(collection_object); - Py_DECREF(collection_object); - if (!encoded_collection) { - return 0; - } - { - const char* collection = PyString_AsString(encoded_collection); - if (!collection) { - Py_DECREF(encoded_collection); - return 0; - } - id_object = PyObject_GetAttrString(value, "id"); - if (!id_object) { - Py_DECREF(encoded_collection); - return 0; - } - - 
if (!buffer_write_bytes(buffer, "\x02$ref\x00", 6)) { - Py_DECREF(encoded_collection); - Py_DECREF(id_object); - return 0; - } - collection_length = strlen(collection) + 1; - if (!buffer_write_bytes(buffer, (const char*)&collection_length, 4)) { - Py_DECREF(encoded_collection); - Py_DECREF(id_object); - return 0; - } - if (!buffer_write_bytes(buffer, collection, collection_length)) { - Py_DECREF(encoded_collection); - Py_DECREF(id_object); - return 0; - } - } - Py_DECREF(encoded_collection); - - type_pos = buffer_save_bytes(buffer, 1); - if (type_pos == -1) { - Py_DECREF(id_object); - return 0; - } - if (!buffer_write_bytes(buffer, "$id\x00", 4)) { - Py_DECREF(id_object); - return 0; - } - if (!write_element_to_buffer(buffer, type_pos, id_object, check_keys)) { - Py_DECREF(id_object); - return 0; - } - Py_DECREF(id_object); - - /* write null byte and fill in length */ - if (!buffer_write_bytes(buffer, &zero, 1)) { - return 0; - } - length = buffer->position - start_position; - memcpy(buffer->buffer + length_location, &length, 4); return 1; } else if (PyObject_HasAttrString(value, "pattern") && @@ -1043,33 +973,22 @@ static PyObject* get_value(const char* buffer, int* position, int type) { { int size; memcpy(&size, buffer + *position, 4); - if (strcmp(buffer + *position + 5, "$ref") == 0) { /* DBRef */ - char id_type; - PyObject* id; - - int offset = *position + 14; - int collection_length = strlen(buffer + offset); - PyObject* collection = PyUnicode_DecodeUTF8(buffer + offset, collection_length, "strict"); - if (!collection) { - return NULL; - } - offset += collection_length + 1; - id_type = buffer[offset]; - offset += 5; - id = get_value(buffer, &offset, (int)id_type); - if (!id) { - Py_DECREF(collection); - return NULL; - } - value = PyObject_CallFunctionObjArgs(DBRef, collection, id, NULL); - Py_DECREF(collection); - Py_DECREF(id); - } else { - value = elements_to_dict(buffer + *position + 4, size - 5); - if (!value) { - return NULL; - } + value = 
elements_to_dict(buffer + *position + 4, size - 5); + if (!value) { + return NULL; } + + /* Decoding for DBRefs */ + if (strcmp(buffer + *position + 5, "$ref") == 0) { /* DBRef */ + PyObject* id = PyDict_GetItemString(value, "$id"); + PyObject* collection = PyDict_GetItemString(value, "$ref"); + PyObject* database = PyDict_GetItemString(value, "$db"); + + /* This works even if there is no $db since database will be NULL and + the call will be as if there were only two arguments specified. */ + value = PyObject_CallFunctionObjArgs(DBRef, collection, id, database, NULL); + } + *position += size; break; } diff --git a/pymongo/bson.py b/pymongo/bson.py index 898d12143..07c8b1d1a 100644 --- a/pymongo/bson.py +++ b/pymongo/bson.py @@ -236,7 +236,7 @@ def _get_string(data): def _get_object(data): (object, data) = _bson_to_dict(data) if "$ref" in object: - return (DBRef(object["$ref"], object["$id"]), data) + return (DBRef(object["$ref"], object["$id"], object.get("$db", None)), data) return (object, data) @@ -456,10 +456,8 @@ def _element_to_bson(key, value, check_keys): flags += "x" return "\x0B" + name + _make_c_string(pattern) + _make_c_string(flags) if isinstance(value, DBRef): - return _element_to_bson(key, - SON([("$ref", value.collection), - ("$id", value.id)]), - False) + return _element_to_bson(key, value.as_doc(), False) + raise InvalidDocument("cannot convert value of type %s to bson" % type(value)) diff --git a/pymongo/database.py b/pymongo/database.py index 9ddb984bf..6baf8ec6d 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -377,14 +377,20 @@ class Database(object): def dereference(self, dbref): """Dereference a DBRef, getting the SON object it points to. - Raises TypeError if dbref is not an instance of DBRef. Returns a SON - object or None if the reference does not point to a valid object. + Raises TypeError if `dbref` is not an instance of DBRef. Returns a SON + object or None if the reference does not point to a valid object. 
Raises + ValueError if `dbref` has a database specified that is different from + the current database. :Parameters: - `dbref`: the reference """ if not isinstance(dbref, DBRef): raise TypeError("cannot dereference a %s" % type(dbref)) + if dbref.database is not None and dbref.database != self.__name: + raise ValueError("trying to dereference a DBRef that points to " + "another database (%r not %r)" % (dbref.database, + self.__name)) return self[dbref.collection].find_one({"_id": dbref.id}) def eval(self, code, *args): diff --git a/pymongo/dbref.py b/pymongo/dbref.py index 55175a989..3ea48503e 100644 --- a/pymongo/dbref.py +++ b/pymongo/dbref.py @@ -12,35 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tools for manipulating DBRefs (references to Mongo objects).""" +"""Tools for manipulating DBRefs (references to MongoDB documents).""" import types +from son import SON + class DBRef(object): - """A reference to an object stored in a Mongo database. + """A reference to a document stored in a Mongo database. """ - def __init__(self, collection, id): + def __init__(self, collection, id, database=None): """Initialize a new DBRef. - Raises TypeError if collection is not an instance of (str, unicode). + Raises TypeError if collection or database is not an instance of + (str, unicode). `database` is optional and allows references to + documents to work across databases. 
:Parameters: - - `collection`: the collection the object is stored in - - `id`: the value of the object's _id field + - `collection`: name of the collection the document is stored in + - `id`: the value of the document's _id field + - `database` (optional): name of the database to reference """ if not isinstance(collection, types.StringTypes): raise TypeError("collection must be an instance of (str, unicode)") - - if isinstance(collection, types.StringType): - collection = unicode(collection, "utf-8") + if not isinstance(database, (types.StringTypes, types.NoneType)): + raise TypeError("database must be an instance of (str, unicode)") self.__collection = collection self.__id = id + self.__database = database def collection(self): - """Get this DBRef's collection as unicode. + """Get the name of this DBRef's collection as unicode. """ return self.__collection collection = property(collection) @@ -51,14 +56,35 @@ class DBRef(object): return self.__id id = property(id) + def database(self): + """Get the name of this DBRef's database. + + Returns None if this DBRef doesn't specify a database. + """ + return self.__database + database = property(database) + + def as_doc(self): + """Get the SON document representation of this DBRef. 
+ + Generally not needed by application developers + """ + doc = SON([("$ref", self.collection), + ("$id", self.id)]) + if self.database is not None: + doc["$db"] = self.database + return doc + def __repr__(self): - return "DBRef(" + repr(self.collection) + ", " + repr(self.id) + ")" + if self.database is None: + return "DBRef(%r, %r)" % (self.collection, self.id) + return "DBRef(%r, %r, %r)" % (self.collection, self.id, self.database) def __cmp__(self, other): if isinstance(other, DBRef): - return cmp([self.__collection, self.__id], - [other.__collection, other.__id]) + return cmp([self.__database, self.__collection, self.__id], + [other.__database, other.__collection, other.__id]) return NotImplemented - + def __hash__(self): - return hash((self.__collection, self.__id)) + return hash((self.__collection, self.__id, self.__database)) diff --git a/pymongo/son.py b/pymongo/son.py index 41ba2e84d..3a6dd5d75 100644 --- a/pymongo/son.py +++ b/pymongo/son.py @@ -24,17 +24,6 @@ import binascii import base64 import types -try: - import xml.etree.ElementTree as ET -except ImportError: - import elementtree.ElementTree as ET - -from code import Code -from binary import Binary -from objectid import ObjectId -from dbref import DBRef -from errors import UnsupportedTag - class SON(dict): """SON data. @@ -223,7 +212,19 @@ class SON(dict): def from_xml(cls, xml): """Create an instance of SON from an xml document. + + This is really only used for testing, and is probably unnecessary. 
""" + try: + import xml.etree.ElementTree as ET + except ImportError: + import elementtree.ElementTree as ET + + from code import Code + from binary import Binary + from objectid import ObjectId + from dbref import DBRef + from errors import UnsupportedTag def pad(list, index): while index >= len(list): diff --git a/pymongo/son_manipulator.py b/pymongo/son_manipulator.py index 07f6a13c5..0b8c31bd4 100644 --- a/pymongo/son_manipulator.py +++ b/pymongo/son_manipulator.py @@ -117,6 +117,10 @@ class AutoReference(SONManipulator): only be auto-referenced if they have an *_ns* field. NOTE: this will behave poorly if you have a circular reference. + + TODO: this only works for documents that are in the same database. To fix + this we'll need to add a DatabaseInjector that adds *_db* and then make + use of the optional *database* support for DBRefs. """ def __init__(self, db): diff --git a/test/test_bson.py b/test/test_bson.py index 2adf55ec1..dc759eaed 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -166,6 +166,8 @@ class TestBSON(unittest.TestCase): helper({"another binary": Binary("test")}) helper(SON([(u'test dst', datetime.datetime(1993, 4, 4, 2))])) helper({"big float": float(10000000000)}) + helper({"ref": DBRef("coll", 5)}) + helper({"ref": DBRef("coll", 5, "foo")}) def from_then_to_dict(dict): return dict == (BSON.from_dict(dict)).to_dict() diff --git a/test/test_database.py b/test/test_database.py index 0aa087a06..be76591a1 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -271,6 +271,8 @@ class TestDatabase(unittest.TestCase): obj = {"x": True} key = db.test.save(obj) self.assertEqual(obj, db.dereference(DBRef("test", key))) + self.assertEqual(obj, db.dereference(DBRef("test", key, "pymongo_test"))) + self.assertRaises(ValueError, db.dereference, DBRef("test", key, "foo")) self.assertEqual(None, db.dereference(DBRef("test", 4))) obj = {"_id": 4} diff --git a/test/test_dbref.py b/test/test_dbref.py index f85c49892..12f478099 100644 --- 
a/test/test_dbref.py +++ b/test/test_dbref.py @@ -36,9 +36,11 @@ class TestDBRef(unittest.TestCase): self.assertRaises(TypeError, DBRef, 1.5, a) self.assertRaises(TypeError, DBRef, a, a) self.assertRaises(TypeError, DBRef, None, a) + self.assertRaises(TypeError, DBRef, "coll", a, 5) self.assert_(DBRef("coll", a)) self.assert_(DBRef(u"coll", a)) self.assert_(DBRef(u"coll", 5)) + self.assert_(DBRef(u"coll", 5, "database")) def test_read_only(self): a = DBRef("coll", ObjectId()) @@ -49,25 +51,35 @@ class TestDBRef(unittest.TestCase): def bar(): a.id = "aoeu" - a.collection + self.assertEqual("coll", a.collection) a.id + self.assertEqual(None, a.database) self.assertRaises(AttributeError, foo) self.assertRaises(AttributeError, bar) def test_repr(self): self.assertEqual(repr(DBRef("coll", ObjectId("1234567890abcdef12345678"))), - "DBRef(u'coll', ObjectId('1234567890abcdef12345678'))") + "DBRef('coll', ObjectId('1234567890abcdef12345678'))") self.assertEqual(repr(DBRef(u"coll", ObjectId("1234567890abcdef12345678"))), "DBRef(u'coll', ObjectId('1234567890abcdef12345678'))") + self.assertEqual(repr(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo")), + "DBRef('coll', ObjectId('1234567890abcdef12345678'), 'foo')") def test_cmp(self): self.assertEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")), DBRef(u"coll", ObjectId("1234567890abcdef12345678"))) + self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")), + DBRef(u"coll", ObjectId("1234567890abcdef12345678"), "foo")) self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")), DBRef("col", ObjectId("1234567890abcdef12345678"))) self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")), DBRef("coll", ObjectId("123456789011"))) self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")), 4) + self.assertEqual(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo"), + DBRef(u"coll", ObjectId("1234567890abcdef12345678"), "foo")) + 
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo"), + DBRef(u"coll", ObjectId("1234567890abcdef12345678"), "bar")) + if __name__ == "__main__": unittest.main() diff --git a/test/test_son_manipulator.py b/test/test_son_manipulator.py index 5d36f0b8d..25e4addd8 100644 --- a/test/test_son_manipulator.py +++ b/test/test_son_manipulator.py @@ -21,7 +21,6 @@ sys.path[0:0] = [""] import qcheck from pymongo.objectid import ObjectId -from pymongo.dbref import DBRef from pymongo.son import SON from pymongo.son_manipulator import SONManipulator, ObjectIdInjector from pymongo.son_manipulator import NamespaceInjector, ObjectIdShuffler