Add database support to DBRefs (as )

This commit is contained in:
Mike Dirolf 2009-11-10 14:54:39 -05:00
parent 9bbeea4782
commit b4cb9be0d3
10 changed files with 110 additions and 141 deletions

View File

@ -602,86 +602,16 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
}
return 1;
} else if (PyObject_IsInstance(value, DBRef)) {
int start_position,
length_location,
collection_length,
type_pos,
length;
PyObject* collection_object;
PyObject* encoded_collection;
PyObject* id_object;
char zero = 0;
PyObject* as_doc = PyObject_CallMethod(value, "as_doc", NULL);
if (!as_doc) {
return 0;
}
if (!write_dict(buffer, as_doc, 0)) {
Py_DECREF(as_doc);
return 0;
}
Py_DECREF(as_doc);
*(buffer->buffer + type_byte) = 0x03;
start_position = buffer->position;
/* save space for length */
length_location = buffer_save_bytes(buffer, 4);
if (length_location == -1) {
return 0;
}
collection_object = PyObject_GetAttrString(value, "collection");
if (!collection_object) {
return 0;
}
encoded_collection = PyUnicode_AsUTF8String(collection_object);
Py_DECREF(collection_object);
if (!encoded_collection) {
return 0;
}
{
const char* collection = PyString_AsString(encoded_collection);
if (!collection) {
Py_DECREF(encoded_collection);
return 0;
}
id_object = PyObject_GetAttrString(value, "id");
if (!id_object) {
Py_DECREF(encoded_collection);
return 0;
}
if (!buffer_write_bytes(buffer, "\x02$ref\x00", 6)) {
Py_DECREF(encoded_collection);
Py_DECREF(id_object);
return 0;
}
collection_length = strlen(collection) + 1;
if (!buffer_write_bytes(buffer, (const char*)&collection_length, 4)) {
Py_DECREF(encoded_collection);
Py_DECREF(id_object);
return 0;
}
if (!buffer_write_bytes(buffer, collection, collection_length)) {
Py_DECREF(encoded_collection);
Py_DECREF(id_object);
return 0;
}
}
Py_DECREF(encoded_collection);
type_pos = buffer_save_bytes(buffer, 1);
if (type_pos == -1) {
Py_DECREF(id_object);
return 0;
}
if (!buffer_write_bytes(buffer, "$id\x00", 4)) {
Py_DECREF(id_object);
return 0;
}
if (!write_element_to_buffer(buffer, type_pos, id_object, check_keys)) {
Py_DECREF(id_object);
return 0;
}
Py_DECREF(id_object);
/* write null byte and fill in length */
if (!buffer_write_bytes(buffer, &zero, 1)) {
return 0;
}
length = buffer->position - start_position;
memcpy(buffer->buffer + length_location, &length, 4);
return 1;
}
else if (PyObject_HasAttrString(value, "pattern") &&
@ -1043,33 +973,22 @@ static PyObject* get_value(const char* buffer, int* position, int type) {
{
int size;
memcpy(&size, buffer + *position, 4);
if (strcmp(buffer + *position + 5, "$ref") == 0) { /* DBRef */
char id_type;
PyObject* id;
int offset = *position + 14;
int collection_length = strlen(buffer + offset);
PyObject* collection = PyUnicode_DecodeUTF8(buffer + offset, collection_length, "strict");
if (!collection) {
return NULL;
}
offset += collection_length + 1;
id_type = buffer[offset];
offset += 5;
id = get_value(buffer, &offset, (int)id_type);
if (!id) {
Py_DECREF(collection);
return NULL;
}
value = PyObject_CallFunctionObjArgs(DBRef, collection, id, NULL);
Py_DECREF(collection);
Py_DECREF(id);
} else {
value = elements_to_dict(buffer + *position + 4, size - 5);
if (!value) {
return NULL;
}
value = elements_to_dict(buffer + *position + 4, size - 5);
if (!value) {
return NULL;
}
/* Decoding for DBRefs */
if (strcmp(buffer + *position + 5, "$ref") == 0) { /* DBRef */
PyObject* id = PyDict_GetItemString(value, "$id");
PyObject* collection = PyDict_GetItemString(value, "$ref");
PyObject* database = PyDict_GetItemString(value, "$db");
/* This works even if there is no $db since database will be NULL and
the call will be as if there were only two arguments specified. */
value = PyObject_CallFunctionObjArgs(DBRef, collection, id, database, NULL);
}
*position += size;
break;
}

View File

@ -236,7 +236,7 @@ def _get_string(data):
def _get_object(data):
(object, data) = _bson_to_dict(data)
if "$ref" in object:
return (DBRef(object["$ref"], object["$id"]), data)
return (DBRef(object["$ref"], object["$id"], object.get("$db", None)), data)
return (object, data)
@ -456,10 +456,8 @@ def _element_to_bson(key, value, check_keys):
flags += "x"
return "\x0B" + name + _make_c_string(pattern) + _make_c_string(flags)
if isinstance(value, DBRef):
return _element_to_bson(key,
SON([("$ref", value.collection),
("$id", value.id)]),
False)
return _element_to_bson(key, value.as_doc(), False)
raise InvalidDocument("cannot convert value of type %s to bson" %
type(value))

View File

@ -377,14 +377,20 @@ class Database(object):
def dereference(self, dbref):
"""Dereference a DBRef, getting the SON object it points to.
Raises TypeError if dbref is not an instance of DBRef. Returns a SON
object or None if the reference does not point to a valid object.
Raises TypeError if `dbref` is not an instance of DBRef. Returns a SON
object or None if the reference does not point to a valid object. Raises
ValueError if `dbref` has a database specified that is different from
the current database.
:Parameters:
- `dbref`: the reference
"""
if not isinstance(dbref, DBRef):
raise TypeError("cannot dereference a %s" % type(dbref))
if dbref.database is not None and dbref.database != self.__name:
raise ValueError("trying to dereference a DBRef that points to "
"another database (%r not %r)" % (dbref.database,
self.__name))
return self[dbref.collection].find_one({"_id": dbref.id})
def eval(self, code, *args):

View File

@ -12,35 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for manipulating DBRefs (references to Mongo objects)."""
"""Tools for manipulating DBRefs (references to MongoDB documents)."""
import types
from son import SON
class DBRef(object):
"""A reference to an object stored in a Mongo database.
"""A reference to a document stored in a Mongo database.
"""
def __init__(self, collection, id):
def __init__(self, collection, id, database=None):
"""Initialize a new DBRef.
Raises TypeError if collection is not an instance of (str, unicode).
Raises TypeError if collection or database is not an instance of
(str, unicode). `database` is optional and allows references to
documents to work across databases.
:Parameters:
- `collection`: the collection the object is stored in
- `id`: the value of the object's _id field
- `collection`: name of the collection the document is stored in
- `id`: the value of the document's _id field
- `database` (optional): name of the database to reference
"""
if not isinstance(collection, types.StringTypes):
raise TypeError("collection must be an instance of (str, unicode)")
if isinstance(collection, types.StringType):
collection = unicode(collection, "utf-8")
if not isinstance(database, (types.StringTypes, types.NoneType)):
raise TypeError("database must be an instance of (str, unicode)")
self.__collection = collection
self.__id = id
self.__database = database
def collection(self):
"""Get this DBRef's collection as unicode.
"""Get the name of this DBRef's collection as unicode.
"""
return self.__collection
collection = property(collection)
@ -51,14 +56,35 @@ class DBRef(object):
return self.__id
id = property(id)
def database(self):
"""Get the name of this DBRef's database.
Returns None if this DBRef doesn't specify a database.
"""
return self.__database
database = property(database)
def as_doc(self):
"""Get the SON document representation of this DBRef.
Generally not needed by application developers
"""
doc = SON([("$ref", self.collection),
("$id", self.id)])
if self.database is not None:
doc["$db"] = self.database
return doc
def __repr__(self):
return "DBRef(" + repr(self.collection) + ", " + repr(self.id) + ")"
if self.database is None:
return "DBRef(%r, %r)" % (self.collection, self.id)
return "DBRef(%r, %r, %r)" % (self.collection, self.id, self.database)
def __cmp__(self, other):
if isinstance(other, DBRef):
return cmp([self.__collection, self.__id],
[other.__collection, other.__id])
return cmp([self.__database, self.__collection, self.__id],
[other.__database, other.__collection, other.__id])
return NotImplemented
def __hash__(self):
return hash((self.__collection, self.__id))
return hash((self.__collection, self.__id, self.__database))

View File

@ -24,17 +24,6 @@ import binascii
import base64
import types
try:
import xml.etree.ElementTree as ET
except ImportError:
import elementtree.ElementTree as ET
from code import Code
from binary import Binary
from objectid import ObjectId
from dbref import DBRef
from errors import UnsupportedTag
class SON(dict):
"""SON data.
@ -223,7 +212,19 @@ class SON(dict):
def from_xml(cls, xml):
"""Create an instance of SON from an xml document.
This is really only used for testing, and is probably unnecessary.
"""
try:
import xml.etree.ElementTree as ET
except ImportError:
import elementtree.ElementTree as ET
from code import Code
from binary import Binary
from objectid import ObjectId
from dbref import DBRef
from errors import UnsupportedTag
def pad(list, index):
while index >= len(list):

View File

@ -117,6 +117,10 @@ class AutoReference(SONManipulator):
only be auto-referenced if they have an *_ns* field.
NOTE: this will behave poorly if you have a circular reference.
TODO: this only works for documents that are in the same database. To fix
this we'll need to add a DatabaseInjector that adds *_db* and then make
use of the optional *database* support for DBRefs.
"""
def __init__(self, db):

View File

@ -166,6 +166,8 @@ class TestBSON(unittest.TestCase):
helper({"another binary": Binary("test")})
helper(SON([(u'test dst', datetime.datetime(1993, 4, 4, 2))]))
helper({"big float": float(10000000000)})
helper({"ref": DBRef("coll", 5)})
helper({"ref": DBRef("coll", 5, "foo")})
def from_then_to_dict(dict):
return dict == (BSON.from_dict(dict)).to_dict()

View File

@ -271,6 +271,8 @@ class TestDatabase(unittest.TestCase):
obj = {"x": True}
key = db.test.save(obj)
self.assertEqual(obj, db.dereference(DBRef("test", key)))
self.assertEqual(obj, db.dereference(DBRef("test", key, "pymongo_test")))
self.assertRaises(ValueError, db.dereference, DBRef("test", key, "foo"))
self.assertEqual(None, db.dereference(DBRef("test", 4)))
obj = {"_id": 4}

View File

@ -36,9 +36,11 @@ class TestDBRef(unittest.TestCase):
self.assertRaises(TypeError, DBRef, 1.5, a)
self.assertRaises(TypeError, DBRef, a, a)
self.assertRaises(TypeError, DBRef, None, a)
self.assertRaises(TypeError, DBRef, "coll", a, 5)
self.assert_(DBRef("coll", a))
self.assert_(DBRef(u"coll", a))
self.assert_(DBRef(u"coll", 5))
self.assert_(DBRef(u"coll", 5, "database"))
def test_read_only(self):
a = DBRef("coll", ObjectId())
@ -49,25 +51,35 @@ class TestDBRef(unittest.TestCase):
def bar():
a.id = "aoeu"
a.collection
self.assertEqual("coll", a.collection)
a.id
self.assertEqual(None, a.database)
self.assertRaises(AttributeError, foo)
self.assertRaises(AttributeError, bar)
def test_repr(self):
self.assertEqual(repr(DBRef("coll", ObjectId("1234567890abcdef12345678"))),
"DBRef(u'coll', ObjectId('1234567890abcdef12345678'))")
"DBRef('coll', ObjectId('1234567890abcdef12345678'))")
self.assertEqual(repr(DBRef(u"coll", ObjectId("1234567890abcdef12345678"))),
"DBRef(u'coll', ObjectId('1234567890abcdef12345678'))")
self.assertEqual(repr(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo")),
"DBRef('coll', ObjectId('1234567890abcdef12345678'), 'foo')")
def test_cmp(self):
self.assertEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
DBRef(u"coll", ObjectId("1234567890abcdef12345678")))
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
DBRef(u"coll", ObjectId("1234567890abcdef12345678"), "foo"))
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
DBRef("col", ObjectId("1234567890abcdef12345678")))
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
DBRef("coll", ObjectId("123456789011")))
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")), 4)
self.assertEqual(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo"),
DBRef(u"coll", ObjectId("1234567890abcdef12345678"), "foo"))
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo"),
DBRef(u"coll", ObjectId("1234567890abcdef12345678"), "bar"))
if __name__ == "__main__":
unittest.main()

View File

@ -21,7 +21,6 @@ sys.path[0:0] = [""]
import qcheck
from pymongo.objectid import ObjectId
from pymongo.dbref import DBRef
from pymongo.son import SON
from pymongo.son_manipulator import SONManipulator, ObjectIdInjector
from pymongo.son_manipulator import NamespaceInjector, ObjectIdShuffler