handle dbrefs in c bson module as well
This commit is contained in:
parent
8578b20e64
commit
157e032892
@ -363,6 +363,15 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
|
||||
*(buffer->buffer + type_byte) = 0x07;
|
||||
return 1;
|
||||
} else if (PyObject_IsInstance(value, DBRef)) {
|
||||
*(buffer->buffer + type_byte) = 0x03;
|
||||
int start_position = buffer->position;
|
||||
|
||||
// save space for length
|
||||
int length_location = buffer_save_bytes(buffer, 4);
|
||||
if (length_location == -1) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
PyObject* collection_object = PyObject_GetAttrString(value, "collection");
|
||||
if (!collection_object) {
|
||||
return 0;
|
||||
@ -382,36 +391,47 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
|
||||
Py_DECREF(encoded_collection);
|
||||
return 0;
|
||||
}
|
||||
PyObject* id_str = PyObject_Str(id_object);
|
||||
Py_DECREF(id_object);
|
||||
if (!id_str) {
|
||||
|
||||
if (!buffer_write_bytes(buffer, "\x02$ref\x00", 6)) {
|
||||
Py_DECREF(encoded_collection);
|
||||
return 0;
|
||||
}
|
||||
const char* id = PyString_AsString(id_str);
|
||||
if (!id) {
|
||||
Py_DECREF(encoded_collection);
|
||||
Py_DECREF(id_str);
|
||||
Py_DECREF(id_object);
|
||||
return 0;
|
||||
}
|
||||
int collection_length = strlen(collection) + 1;
|
||||
if (!buffer_write_bytes(buffer, (const char*)&collection_length, 4)) {
|
||||
Py_DECREF(encoded_collection);
|
||||
Py_DECREF(id_str);
|
||||
Py_DECREF(id_object);
|
||||
return 0;
|
||||
}
|
||||
if (!buffer_write_bytes(buffer, collection, collection_length)) {
|
||||
Py_DECREF(encoded_collection);
|
||||
Py_DECREF(id_str);
|
||||
Py_DECREF(id_object);
|
||||
return 0;
|
||||
}
|
||||
Py_DECREF(encoded_collection);
|
||||
if (!write_shuffled_oid(buffer, id)) {
|
||||
Py_DECREF(id_str);
|
||||
|
||||
int type_pos = buffer_save_bytes(buffer, 1);
|
||||
if (type_pos == -1) {
|
||||
Py_DECREF(id_object);
|
||||
return 0;
|
||||
}
|
||||
Py_DECREF(id_str);
|
||||
*(buffer->buffer + type_byte) = 0x0C;
|
||||
if (!buffer_write_bytes(buffer, "$id\x00", 4)) {
|
||||
Py_DECREF(id_object);
|
||||
return 0;
|
||||
}
|
||||
if (!write_element_to_buffer(buffer, type_pos, id_object)) {
|
||||
Py_DECREF(id_object);
|
||||
return 0;
|
||||
}
|
||||
Py_DECREF(id_object);
|
||||
|
||||
// write null byte and fill in length
|
||||
char zero = 0;
|
||||
if (!buffer_write_bytes(buffer, &zero, 1)) {
|
||||
return 0;
|
||||
}
|
||||
int length = buffer->position - start_position;
|
||||
memcpy(buffer->buffer + length_location, &length, 4);
|
||||
return 1;
|
||||
}
|
||||
else if (PyObject_HasAttrString(value, "pattern") &&
|
||||
@ -626,9 +646,29 @@ static PyObject* get_value(const char* buffer, int* position, int type) {
|
||||
{
|
||||
int size;
|
||||
memcpy(&size, buffer + *position, 4);
|
||||
value = elements_to_dict(buffer + *position + 4, size - 5);
|
||||
if (!value) {
|
||||
return NULL;
|
||||
if (strcmp(buffer + *position + 5, "$ref") == 0) { // DBRef
|
||||
int offset = *position + 14;
|
||||
int collection_length = strlen(buffer + offset);
|
||||
PyObject* collection = PyUnicode_DecodeUTF8(buffer + offset, collection_length, "strict");
|
||||
if (!collection) {
|
||||
return NULL;
|
||||
}
|
||||
offset += collection_length + 1;
|
||||
char id_type = buffer[offset];
|
||||
offset += 5;
|
||||
PyObject* id = get_value(buffer, &offset, (int)id_type);
|
||||
if (!id) {
|
||||
Py_DECREF(collection);
|
||||
return NULL;
|
||||
}
|
||||
value = PyObject_CallFunctionObjArgs(DBRef, collection, id, NULL);
|
||||
Py_DECREF(collection);
|
||||
Py_DECREF(id);
|
||||
} else {
|
||||
value = elements_to_dict(buffer + *position + 4, size - 5);
|
||||
if (!value) {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
*position += size;
|
||||
break;
|
||||
|
||||
@ -26,7 +26,7 @@ from pymongo.dbref import DBRef
|
||||
from pymongo.son import SON
|
||||
|
||||
gen_target = 100
|
||||
reduction_attempts = 100
|
||||
reduction_attempts = 10
|
||||
examples = 5
|
||||
|
||||
def lift(value):
|
||||
|
||||
@ -122,7 +122,7 @@ class TestBSON(unittest.TestCase):
|
||||
def from_then_to_dict(dict):
|
||||
return dict == (BSON.from_dict(dict)).to_dict()
|
||||
|
||||
qcheck.check_unittest(self, from_then_to_dict, qcheck.gen_mongo_dict(3, False))
|
||||
qcheck.check_unittest(self, from_then_to_dict, qcheck.gen_mongo_dict(3))
|
||||
|
||||
def test_data_files(self):
|
||||
# TODO don't hardcode this, actually clone the repo
|
||||
|
||||
Loading…
Reference in New Issue
Block a user