From 157e032892dcd0834aed7847aebffc35da46eff3 Mon Sep 17 00:00:00 2001 From: Mike Dirolf Date: Wed, 18 Mar 2009 11:53:03 -0400 Subject: [PATCH] handle dbrefs in c bson module as well --- pymongo/_cbsonmodule.c | 76 ++++++++++++++++++++++++++++++++---------- test/qcheck.py | 2 +- test/test_bson.py | 2 +- 3 files changed, 60 insertions(+), 20 deletions(-) diff --git a/pymongo/_cbsonmodule.c b/pymongo/_cbsonmodule.c index 8485e1fdc..8e9ff760c 100644 --- a/pymongo/_cbsonmodule.c +++ b/pymongo/_cbsonmodule.c @@ -363,6 +363,15 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject* *(buffer->buffer + type_byte) = 0x07; return 1; } else if (PyObject_IsInstance(value, DBRef)) { + *(buffer->buffer + type_byte) = 0x03; + int start_position = buffer->position; + + // save space for length + int length_location = buffer_save_bytes(buffer, 4); + if (length_location == -1) { + return 0; + } + PyObject* collection_object = PyObject_GetAttrString(value, "collection"); if (!collection_object) { return 0; @@ -382,36 +391,47 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject* Py_DECREF(encoded_collection); return 0; } - PyObject* id_str = PyObject_Str(id_object); - Py_DECREF(id_object); - if (!id_str) { + + if (!buffer_write_bytes(buffer, "\x02$ref\x00", 6)) { Py_DECREF(encoded_collection); - return 0; - } - const char* id = PyString_AsString(id_str); - if (!id) { - Py_DECREF(encoded_collection); - Py_DECREF(id_str); + Py_DECREF(id_object); return 0; } int collection_length = strlen(collection) + 1; if (!buffer_write_bytes(buffer, (const char*)&collection_length, 4)) { Py_DECREF(encoded_collection); - Py_DECREF(id_str); + Py_DECREF(id_object); return 0; } if (!buffer_write_bytes(buffer, collection, collection_length)) { Py_DECREF(encoded_collection); - Py_DECREF(id_str); + Py_DECREF(id_object); return 0; } Py_DECREF(encoded_collection); - if (!write_shuffled_oid(buffer, id)) { - Py_DECREF(id_str); + + int type_pos = buffer_save_bytes(buffer, 1); + if (type_pos == -1) { + Py_DECREF(id_object); return 0; } - Py_DECREF(id_str); - *(buffer->buffer + type_byte) = 0x0C; + if (!buffer_write_bytes(buffer, "$id\x00", 4)) { + Py_DECREF(id_object); + return 0; + } + if (!write_element_to_buffer(buffer, type_pos, id_object)) { + Py_DECREF(id_object); + return 0; + } + Py_DECREF(id_object); + + // write null byte and fill in length + char zero = 0; + if (!buffer_write_bytes(buffer, &zero, 1)) { + return 0; + } + int length = buffer->position - start_position; + memcpy(buffer->buffer + length_location, &length, 4); return 1; } else if (PyObject_HasAttrString(value, "pattern") && @@ -626,9 +646,29 @@ static PyObject* get_value(const char* buffer, int* position, int type) { { int size; memcpy(&size, buffer + *position, 4); - value = elements_to_dict(buffer + *position + 4, size - 5); - if (!value) { - return NULL; + if (strcmp(buffer + *position + 5, "$ref") == 0) { // DBRef + int offset = *position + 14; + int collection_length = strlen(buffer + offset); + PyObject* collection = PyUnicode_DecodeUTF8(buffer + offset, collection_length, "strict"); + if (!collection) { + return NULL; + } + offset += collection_length + 1; + char id_type = buffer[offset]; + offset += 5; + PyObject* id = get_value(buffer, &offset, (int)id_type); + if (!id) { + Py_DECREF(collection); + return NULL; + } + value = PyObject_CallFunctionObjArgs(DBRef, collection, id, NULL); + Py_DECREF(collection); + Py_DECREF(id); + } else { + value = elements_to_dict(buffer + *position + 4, size - 5); + if (!value) { + return NULL; + } } *position += size; break; diff --git a/test/qcheck.py b/test/qcheck.py index e46e48511..8dd8cc6a2 100644 --- a/test/qcheck.py +++ b/test/qcheck.py @@ -26,7 +26,7 @@ from pymongo.dbref import DBRef from pymongo.son import SON gen_target = 100 -reduction_attempts = 100 +reduction_attempts = 10 examples = 5 def lift(value): diff --git a/test/test_bson.py b/test/test_bson.py index fc72a3128..8bacffe95 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -122,7 +122,7 @@ class TestBSON(unittest.TestCase): def from_then_to_dict(dict): return dict == (BSON.from_dict(dict)).to_dict() - qcheck.check_unittest(self, from_then_to_dict, qcheck.gen_mongo_dict(3, False)) + qcheck.check_unittest(self, from_then_to_dict, qcheck.gen_mongo_dict(3)) def test_data_files(self): # TODO don't hardcode this, actually clone the repo