Add database support to DBRefs (as )
This commit is contained in:
parent
9bbeea4782
commit
b4cb9be0d3
@ -602,86 +602,16 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
|
||||
}
|
||||
return 1;
|
||||
} else if (PyObject_IsInstance(value, DBRef)) {
|
||||
int start_position,
|
||||
length_location,
|
||||
collection_length,
|
||||
type_pos,
|
||||
length;
|
||||
PyObject* collection_object;
|
||||
PyObject* encoded_collection;
|
||||
PyObject* id_object;
|
||||
char zero = 0;
|
||||
|
||||
PyObject* as_doc = PyObject_CallMethod(value, "as_doc", NULL);
|
||||
if (!as_doc) {
|
||||
return 0;
|
||||
}
|
||||
if (!write_dict(buffer, as_doc, 0)) {
|
||||
Py_DECREF(as_doc);
|
||||
return 0;
|
||||
}
|
||||
Py_DECREF(as_doc);
|
||||
*(buffer->buffer + type_byte) = 0x03;
|
||||
start_position = buffer->position;
|
||||
|
||||
/* save space for length */
|
||||
length_location = buffer_save_bytes(buffer, 4);
|
||||
if (length_location == -1) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
collection_object = PyObject_GetAttrString(value, "collection");
|
||||
if (!collection_object) {
|
||||
return 0;
|
||||
}
|
||||
encoded_collection = PyUnicode_AsUTF8String(collection_object);
|
||||
Py_DECREF(collection_object);
|
||||
if (!encoded_collection) {
|
||||
return 0;
|
||||
}
|
||||
{
|
||||
const char* collection = PyString_AsString(encoded_collection);
|
||||
if (!collection) {
|
||||
Py_DECREF(encoded_collection);
|
||||
return 0;
|
||||
}
|
||||
id_object = PyObject_GetAttrString(value, "id");
|
||||
if (!id_object) {
|
||||
Py_DECREF(encoded_collection);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!buffer_write_bytes(buffer, "\x02$ref\x00", 6)) {
|
||||
Py_DECREF(encoded_collection);
|
||||
Py_DECREF(id_object);
|
||||
return 0;
|
||||
}
|
||||
collection_length = strlen(collection) + 1;
|
||||
if (!buffer_write_bytes(buffer, (const char*)&collection_length, 4)) {
|
||||
Py_DECREF(encoded_collection);
|
||||
Py_DECREF(id_object);
|
||||
return 0;
|
||||
}
|
||||
if (!buffer_write_bytes(buffer, collection, collection_length)) {
|
||||
Py_DECREF(encoded_collection);
|
||||
Py_DECREF(id_object);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
Py_DECREF(encoded_collection);
|
||||
|
||||
type_pos = buffer_save_bytes(buffer, 1);
|
||||
if (type_pos == -1) {
|
||||
Py_DECREF(id_object);
|
||||
return 0;
|
||||
}
|
||||
if (!buffer_write_bytes(buffer, "$id\x00", 4)) {
|
||||
Py_DECREF(id_object);
|
||||
return 0;
|
||||
}
|
||||
if (!write_element_to_buffer(buffer, type_pos, id_object, check_keys)) {
|
||||
Py_DECREF(id_object);
|
||||
return 0;
|
||||
}
|
||||
Py_DECREF(id_object);
|
||||
|
||||
/* write null byte and fill in length */
|
||||
if (!buffer_write_bytes(buffer, &zero, 1)) {
|
||||
return 0;
|
||||
}
|
||||
length = buffer->position - start_position;
|
||||
memcpy(buffer->buffer + length_location, &length, 4);
|
||||
return 1;
|
||||
}
|
||||
else if (PyObject_HasAttrString(value, "pattern") &&
|
||||
@ -1043,33 +973,22 @@ static PyObject* get_value(const char* buffer, int* position, int type) {
|
||||
{
|
||||
int size;
|
||||
memcpy(&size, buffer + *position, 4);
|
||||
if (strcmp(buffer + *position + 5, "$ref") == 0) { /* DBRef */
|
||||
char id_type;
|
||||
PyObject* id;
|
||||
|
||||
int offset = *position + 14;
|
||||
int collection_length = strlen(buffer + offset);
|
||||
PyObject* collection = PyUnicode_DecodeUTF8(buffer + offset, collection_length, "strict");
|
||||
if (!collection) {
|
||||
return NULL;
|
||||
}
|
||||
offset += collection_length + 1;
|
||||
id_type = buffer[offset];
|
||||
offset += 5;
|
||||
id = get_value(buffer, &offset, (int)id_type);
|
||||
if (!id) {
|
||||
Py_DECREF(collection);
|
||||
return NULL;
|
||||
}
|
||||
value = PyObject_CallFunctionObjArgs(DBRef, collection, id, NULL);
|
||||
Py_DECREF(collection);
|
||||
Py_DECREF(id);
|
||||
} else {
|
||||
value = elements_to_dict(buffer + *position + 4, size - 5);
|
||||
if (!value) {
|
||||
return NULL;
|
||||
}
|
||||
value = elements_to_dict(buffer + *position + 4, size - 5);
|
||||
if (!value) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/* Decoding for DBRefs */
|
||||
if (strcmp(buffer + *position + 5, "$ref") == 0) { /* DBRef */
|
||||
PyObject* id = PyDict_GetItemString(value, "$id");
|
||||
PyObject* collection = PyDict_GetItemString(value, "$ref");
|
||||
PyObject* database = PyDict_GetItemString(value, "$db");
|
||||
|
||||
/* This works even if there is no $db since database will be NULL and
|
||||
the call will be as if there were only two arguments specified. */
|
||||
value = PyObject_CallFunctionObjArgs(DBRef, collection, id, database, NULL);
|
||||
}
|
||||
|
||||
*position += size;
|
||||
break;
|
||||
}
|
||||
|
||||
@ -236,7 +236,7 @@ def _get_string(data):
|
||||
def _get_object(data):
|
||||
(object, data) = _bson_to_dict(data)
|
||||
if "$ref" in object:
|
||||
return (DBRef(object["$ref"], object["$id"]), data)
|
||||
return (DBRef(object["$ref"], object["$id"], object.get("$db", None)), data)
|
||||
return (object, data)
|
||||
|
||||
|
||||
@ -456,10 +456,8 @@ def _element_to_bson(key, value, check_keys):
|
||||
flags += "x"
|
||||
return "\x0B" + name + _make_c_string(pattern) + _make_c_string(flags)
|
||||
if isinstance(value, DBRef):
|
||||
return _element_to_bson(key,
|
||||
SON([("$ref", value.collection),
|
||||
("$id", value.id)]),
|
||||
False)
|
||||
return _element_to_bson(key, value.as_doc(), False)
|
||||
|
||||
raise InvalidDocument("cannot convert value of type %s to bson" %
|
||||
type(value))
|
||||
|
||||
|
||||
@ -377,14 +377,20 @@ class Database(object):
|
||||
def dereference(self, dbref):
|
||||
"""Dereference a DBRef, getting the SON object it points to.
|
||||
|
||||
Raises TypeError if dbref is not an instance of DBRef. Returns a SON
|
||||
object or None if the reference does not point to a valid object.
|
||||
Raises TypeError if `dbref` is not an instance of DBRef. Returns a SON
|
||||
object or None if the reference does not point to a valid object. Raises
|
||||
ValueError if `dbref` has a database specified that is different from
|
||||
the current database.
|
||||
|
||||
:Parameters:
|
||||
- `dbref`: the reference
|
||||
"""
|
||||
if not isinstance(dbref, DBRef):
|
||||
raise TypeError("cannot dereference a %s" % type(dbref))
|
||||
if dbref.database is not None and dbref.database != self.__name:
|
||||
raise ValueError("trying to dereference a DBRef that points to "
|
||||
"another database (%r not %r)" % (dbref.database,
|
||||
self.__name))
|
||||
return self[dbref.collection].find_one({"_id": dbref.id})
|
||||
|
||||
def eval(self, code, *args):
|
||||
|
||||
@ -12,35 +12,40 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tools for manipulating DBRefs (references to Mongo objects)."""
|
||||
"""Tools for manipulating DBRefs (references to MongoDB documents)."""
|
||||
|
||||
import types
|
||||
|
||||
from son import SON
|
||||
|
||||
|
||||
class DBRef(object):
|
||||
"""A reference to an object stored in a Mongo database.
|
||||
"""A reference to a document stored in a Mongo database.
|
||||
"""
|
||||
|
||||
def __init__(self, collection, id):
|
||||
def __init__(self, collection, id, database=None):
|
||||
"""Initialize a new DBRef.
|
||||
|
||||
Raises TypeError if collection is not an instance of (str, unicode).
|
||||
Raises TypeError if collection or database is not an instance of
|
||||
(str, unicode). `database` is optional and allows references to
|
||||
documents to work across databases.
|
||||
|
||||
:Parameters:
|
||||
- `collection`: the collection the object is stored in
|
||||
- `id`: the value of the object's _id field
|
||||
- `collection`: name of the collection the document is stored in
|
||||
- `id`: the value of the document's _id field
|
||||
- `database` (optional): name of the database to reference
|
||||
"""
|
||||
if not isinstance(collection, types.StringTypes):
|
||||
raise TypeError("collection must be an instance of (str, unicode)")
|
||||
|
||||
if isinstance(collection, types.StringType):
|
||||
collection = unicode(collection, "utf-8")
|
||||
if not isinstance(database, (types.StringTypes, types.NoneType)):
|
||||
raise TypeError("database must be an instance of (str, unicode)")
|
||||
|
||||
self.__collection = collection
|
||||
self.__id = id
|
||||
self.__database = database
|
||||
|
||||
def collection(self):
|
||||
"""Get this DBRef's collection as unicode.
|
||||
"""Get the name of this DBRef's collection as unicode.
|
||||
"""
|
||||
return self.__collection
|
||||
collection = property(collection)
|
||||
@ -51,14 +56,35 @@ class DBRef(object):
|
||||
return self.__id
|
||||
id = property(id)
|
||||
|
||||
def database(self):
|
||||
"""Get the name of this DBRef's database.
|
||||
|
||||
Returns None if this DBRef doesn't specify a database.
|
||||
"""
|
||||
return self.__database
|
||||
database = property(database)
|
||||
|
||||
def as_doc(self):
|
||||
"""Get the SON document representation of this DBRef.
|
||||
|
||||
Generally not needed by application developers
|
||||
"""
|
||||
doc = SON([("$ref", self.collection),
|
||||
("$id", self.id)])
|
||||
if self.database is not None:
|
||||
doc["$db"] = self.database
|
||||
return doc
|
||||
|
||||
def __repr__(self):
|
||||
return "DBRef(" + repr(self.collection) + ", " + repr(self.id) + ")"
|
||||
if self.database is None:
|
||||
return "DBRef(%r, %r)" % (self.collection, self.id)
|
||||
return "DBRef(%r, %r, %r)" % (self.collection, self.id, self.database)
|
||||
|
||||
def __cmp__(self, other):
|
||||
if isinstance(other, DBRef):
|
||||
return cmp([self.__collection, self.__id],
|
||||
[other.__collection, other.__id])
|
||||
return cmp([self.__database, self.__collection, self.__id],
|
||||
[other.__database, other.__collection, other.__id])
|
||||
return NotImplemented
|
||||
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.__collection, self.__id))
|
||||
return hash((self.__collection, self.__id, self.__database))
|
||||
|
||||
@ -24,17 +24,6 @@ import binascii
|
||||
import base64
|
||||
import types
|
||||
|
||||
try:
|
||||
import xml.etree.ElementTree as ET
|
||||
except ImportError:
|
||||
import elementtree.ElementTree as ET
|
||||
|
||||
from code import Code
|
||||
from binary import Binary
|
||||
from objectid import ObjectId
|
||||
from dbref import DBRef
|
||||
from errors import UnsupportedTag
|
||||
|
||||
|
||||
class SON(dict):
|
||||
"""SON data.
|
||||
@ -223,7 +212,19 @@ class SON(dict):
|
||||
|
||||
def from_xml(cls, xml):
|
||||
"""Create an instance of SON from an xml document.
|
||||
|
||||
This is really only used for testing, and is probably unnecessary.
|
||||
"""
|
||||
try:
|
||||
import xml.etree.ElementTree as ET
|
||||
except ImportError:
|
||||
import elementtree.ElementTree as ET
|
||||
|
||||
from code import Code
|
||||
from binary import Binary
|
||||
from objectid import ObjectId
|
||||
from dbref import DBRef
|
||||
from errors import UnsupportedTag
|
||||
|
||||
def pad(list, index):
|
||||
while index >= len(list):
|
||||
|
||||
@ -117,6 +117,10 @@ class AutoReference(SONManipulator):
|
||||
only be auto-referenced if they have an *_ns* field.
|
||||
|
||||
NOTE: this will behave poorly if you have a circular reference.
|
||||
|
||||
TODO: this only works for documents that are in the same database. To fix
|
||||
this we'll need to add a DatabaseInjector that adds *_db* and then make
|
||||
use of the optional *database* support for DBRefs.
|
||||
"""
|
||||
|
||||
def __init__(self, db):
|
||||
|
||||
@ -166,6 +166,8 @@ class TestBSON(unittest.TestCase):
|
||||
helper({"another binary": Binary("test")})
|
||||
helper(SON([(u'test dst', datetime.datetime(1993, 4, 4, 2))]))
|
||||
helper({"big float": float(10000000000)})
|
||||
helper({"ref": DBRef("coll", 5)})
|
||||
helper({"ref": DBRef("coll", 5, "foo")})
|
||||
|
||||
def from_then_to_dict(dict):
|
||||
return dict == (BSON.from_dict(dict)).to_dict()
|
||||
|
||||
@ -271,6 +271,8 @@ class TestDatabase(unittest.TestCase):
|
||||
obj = {"x": True}
|
||||
key = db.test.save(obj)
|
||||
self.assertEqual(obj, db.dereference(DBRef("test", key)))
|
||||
self.assertEqual(obj, db.dereference(DBRef("test", key, "pymongo_test")))
|
||||
self.assertRaises(ValueError, db.dereference, DBRef("test", key, "foo"))
|
||||
|
||||
self.assertEqual(None, db.dereference(DBRef("test", 4)))
|
||||
obj = {"_id": 4}
|
||||
|
||||
@ -36,9 +36,11 @@ class TestDBRef(unittest.TestCase):
|
||||
self.assertRaises(TypeError, DBRef, 1.5, a)
|
||||
self.assertRaises(TypeError, DBRef, a, a)
|
||||
self.assertRaises(TypeError, DBRef, None, a)
|
||||
self.assertRaises(TypeError, DBRef, "coll", a, 5)
|
||||
self.assert_(DBRef("coll", a))
|
||||
self.assert_(DBRef(u"coll", a))
|
||||
self.assert_(DBRef(u"coll", 5))
|
||||
self.assert_(DBRef(u"coll", 5, "database"))
|
||||
|
||||
def test_read_only(self):
|
||||
a = DBRef("coll", ObjectId())
|
||||
@ -49,25 +51,35 @@ class TestDBRef(unittest.TestCase):
|
||||
def bar():
|
||||
a.id = "aoeu"
|
||||
|
||||
a.collection
|
||||
self.assertEqual("coll", a.collection)
|
||||
a.id
|
||||
self.assertEqual(None, a.database)
|
||||
self.assertRaises(AttributeError, foo)
|
||||
self.assertRaises(AttributeError, bar)
|
||||
|
||||
def test_repr(self):
|
||||
self.assertEqual(repr(DBRef("coll", ObjectId("1234567890abcdef12345678"))),
|
||||
"DBRef(u'coll', ObjectId('1234567890abcdef12345678'))")
|
||||
"DBRef('coll', ObjectId('1234567890abcdef12345678'))")
|
||||
self.assertEqual(repr(DBRef(u"coll", ObjectId("1234567890abcdef12345678"))),
|
||||
"DBRef(u'coll', ObjectId('1234567890abcdef12345678'))")
|
||||
self.assertEqual(repr(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo")),
|
||||
"DBRef('coll', ObjectId('1234567890abcdef12345678'), 'foo')")
|
||||
|
||||
def test_cmp(self):
|
||||
self.assertEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
|
||||
DBRef(u"coll", ObjectId("1234567890abcdef12345678")))
|
||||
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
|
||||
DBRef(u"coll", ObjectId("1234567890abcdef12345678"), "foo"))
|
||||
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
|
||||
DBRef("col", ObjectId("1234567890abcdef12345678")))
|
||||
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
|
||||
DBRef("coll", ObjectId("123456789011")))
|
||||
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")), 4)
|
||||
self.assertEqual(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo"),
|
||||
DBRef(u"coll", ObjectId("1234567890abcdef12345678"), "foo"))
|
||||
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo"),
|
||||
DBRef(u"coll", ObjectId("1234567890abcdef12345678"), "bar"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -21,7 +21,6 @@ sys.path[0:0] = [""]
|
||||
|
||||
import qcheck
|
||||
from pymongo.objectid import ObjectId
|
||||
from pymongo.dbref import DBRef
|
||||
from pymongo.son import SON
|
||||
from pymongo.son_manipulator import SONManipulator, ObjectIdInjector
|
||||
from pymongo.son_manipulator import NamespaceInjector, ObjectIdShuffler
|
||||
|
||||
Loading…
Reference in New Issue
Block a user