diff --git a/bson/buffer.c b/bson/buffer.c
index e3a60a1ac..19d528cee 100644
--- a/bson/buffer.c
+++ b/bson/buffer.c
@@ -140,3 +140,7 @@ int buffer_get_position(buffer_t buffer) {
 char* buffer_get_buffer(buffer_t buffer) {
     return buffer->buffer;
 }
+
+void buffer_update_position(buffer_t buffer, buffer_position new_position) {
+    buffer->position = new_position;
+}
diff --git a/bson/buffer.h b/bson/buffer.h
index c47d655c9..23a46d95c 100644
--- a/bson/buffer.h
+++ b/bson/buffer.h
@@ -51,5 +51,6 @@ int buffer_write_at_position(buffer_t buffer, buffer_position position, const ch
  * since they break the abstraction. */
 buffer_position buffer_get_position(buffer_t buffer);
 char* buffer_get_buffer(buffer_t buffer);
+void buffer_update_position(buffer_t buffer, buffer_position new_position);
 
 #endif
diff --git a/pymongo/_cmessagemodule.c b/pymongo/_cmessagemodule.c
index ea0d565bd..a7aff06a3 100644
--- a/pymongo/_cmessagemodule.c
+++ b/pymongo/_cmessagemodule.c
@@ -131,10 +131,36 @@ static int add_last_error(PyObject* self, buffer_t buffer,
     return 1;
 }
 
+static int init_insert_buffer(buffer_t buffer, int request_id, int options,
+                              const char* coll_name, int coll_name_len) {
+    /* Save space for message length */
+    int length_location = buffer_save_space(buffer, 4);
+    if (length_location == -1) {
+        PyErr_NoMemory();
+        return length_location;
+    }
+    if (!buffer_write_bytes(buffer, (const char*)&request_id, 4) ||
+        !buffer_write_bytes(buffer,
+                            "\x00\x00\x00\x00"
+                            "\xd2\x07\x00\x00",
+                            8) ||
+        !buffer_write_bytes(buffer, (const char*)&options, 4) ||
+        !buffer_write_bytes(buffer,
+                            coll_name,
+                            coll_name_len + 1)) {
+        return -1;
+    }
+    return length_location;
+}
+
 static PyObject* _cbson_insert_message(PyObject* self, PyObject* args) {
-    /* NOTE just using a random number as the request_id */
+    /* Note: As of PyMongo 2.6, this function is no longer used.  It
+     * is being kept (with tests) for backwards compatibility with 3rd
+     * party libraries that may currently be using it, but will likely
+     * be removed in a future release. */
     struct module_state *state = GETSTATE(self);
 
+    /* NOTE just using a random number as the request_id */
     int request_id = rand();
     char* collection_name = NULL;
     int collection_name_length;
@@ -172,23 +198,13 @@ static PyObject* _cbson_insert_message(PyObject* self, PyObject* args) {
         return NULL;
     }
 
-    // save space for message length
-    length_location = buffer_save_space(buffer, 4);
+    length_location = init_insert_buffer(buffer,
+                                         request_id,
+                                         options,
+                                         collection_name,
+                                         collection_name_length);
     if (length_location == -1) {
         PyMem_Free(collection_name);
-        PyErr_NoMemory();
-        return NULL;
-    }
-    if (!buffer_write_bytes(buffer, (const char*)&request_id, 4) ||
-        !buffer_write_bytes(buffer,
-                            "\x00\x00\x00\x00"
-                            "\xd2\x07\x00\x00",
-                            8) ||
-        !buffer_write_bytes(buffer, (const char*)&options, 4) ||
-        !buffer_write_bytes(buffer,
-                            collection_name,
-                            collection_name_length + 1)) {
-        PyMem_Free(collection_name);
         buffer_free(buffer);
         return NULL;
     }
@@ -259,6 +275,14 @@ static PyObject* _cbson_insert_message(PyObject* self, PyObject* args) {
     return result;
 }
 
+PyDoc_STRVAR(_cbson_insert_message_doc,
+"Create an insert message to be sent to MongoDB\n\
+\n\
+Note: As of PyMongo 2.6, this function is no longer used.  It\n\
+is being kept (with tests) for backwards compatibility with 3rd\n\
+party libraries that may currently be using it, but will likely\n\
+be removed in a future release.");
+
 static PyObject* _cbson_update_message(PyObject* self, PyObject* args) {
     /* NOTE just using a random number as the request_id */
     struct module_state *state = GETSTATE(self);
@@ -512,15 +536,320 @@ static PyObject* _cbson_get_more_message(PyObject* self, PyObject* args) {
     return result;
 }
 
+/* Read the integer attribute `name` of `obj` as a C long.
+ *
+ * Returns -1 with a Python exception set if the attribute is missing
+ * or cannot be converted. The temporary attribute reference is
+ * released before returning. */
+static long get_attr_as_long(PyObject* obj, const char* name) {
+    long value;
+    PyObject* attr = PyObject_GetAttrString(obj, name);
+    if (!attr) {
+        return -1;
+    }
+#if PY_MAJOR_VERSION >= 3
+    value = PyLong_AsLong(attr);
+#else
+    value = PyInt_AsLong(attr);
+#endif
+    Py_DECREF(attr);
+    return value;
+}
+
+/* Encode `docs` into one or more insert messages and send each through
+ * client._send_message(), starting a new message whenever the current
+ * one grows past client.max_message_size.
+ *
+ * Python arguments: (collection_name, docs, check_keys, safe,
+ * last_error_args, continue_on_error, uuid_subtype, client).
+ *
+ * Returns None on success, or NULL with a Python exception set. With
+ * continue_on_error an OperationFailure raised for an intermediate
+ * batch is stored and re-raised after the final batch has been sent. */
+static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) {
+    struct module_state *state = GETSTATE(self);
+
+    /* NOTE just using a random number as the request_id */
+    int request_id = rand();
+    int options = 0, max_size = 0;
+    int length_location, message_length;
+    int collection_name_length;
+    char* collection_name = NULL;
+    PyObject* docs;
+    PyObject* doc;
+    PyObject* iterator;
+    PyObject* client;
+    PyObject* last_error_args;
+    PyObject* result;
+    PyObject* send_result;
+    unsigned char check_keys;
+    unsigned char safe;
+    unsigned char continue_on_error;
+    unsigned char uuid_subtype;
+    long max_bson_size;
+    long max_message_size;
+    buffer_t buffer;
+    PyObject *exc_type = NULL, *exc_value = NULL, *exc_trace = NULL;
+
+    if (!PyArg_ParseTuple(args, "et#ObbObbO",
+                          "utf-8",
+                          &collection_name,
+                          &collection_name_length,
+                          &docs, &check_keys, &safe,
+                          &last_error_args,
+                          &continue_on_error,
+                          &uuid_subtype, &client)) {
+        return NULL;
+    }
+    if (continue_on_error) {
+        options += 1;
+    }
+
+    /* -1 is only an error when an exception is actually pending. */
+    max_bson_size = get_attr_as_long(client, "max_bson_size");
+    if (max_bson_size == -1 && PyErr_Occurred()) {
+        PyMem_Free(collection_name);
+        return NULL;
+    }
+    max_message_size = get_attr_as_long(client, "max_message_size");
+    if (max_message_size == -1 && PyErr_Occurred()) {
+        PyMem_Free(collection_name);
+        return NULL;
+    }
+
+    buffer = buffer_new();
+    if (!buffer) {
+        PyErr_NoMemory();
+        PyMem_Free(collection_name);
+        return NULL;
+    }
+
+    length_location = init_insert_buffer(buffer,
+                                         request_id,
+                                         options,
+                                         collection_name,
+                                         collection_name_length);
+    if (length_location == -1) {
+        goto insertfail;
+    }
+
+    iterator = PyObject_GetIter(docs);
+    if (iterator == NULL) {
+        PyObject* InvalidOperation = _error("InvalidOperation");
+        if (InvalidOperation) {
+            PyErr_SetString(InvalidOperation, "input is not iterable");
+            Py_DECREF(InvalidOperation);
+        }
+        goto insertfail;
+    }
+    while ((doc = PyIter_Next(iterator)) != NULL) {
+        int before = buffer_get_position(buffer);
+        int cur_size;
+        if (!write_dict(state->_cbson, buffer, doc, check_keys, uuid_subtype, 1)) {
+            Py_DECREF(doc);
+            goto iterfail;
+        }
+        Py_DECREF(doc);
+
+        cur_size = buffer_get_position(buffer) - before;
+        max_size = (cur_size > max_size) ? cur_size : max_size;
+        if (cur_size > max_bson_size) {
+            PyObject* InvalidDocument = _error("InvalidDocument");
+            if (InvalidDocument) {
+                const char* msg = ("BSON document too large (%ld bytes)"
+                                   " - the connected server supports"
+                                   " BSON document sizes up to %ld bytes.");
+                /* %ld consumes a long, so cur_size must be widened. */
+#if PY_MAJOR_VERSION >= 3
+                PyObject* error = PyUnicode_FromFormat(msg,
+                                                       (long)cur_size, max_bson_size);
+#else
+                PyObject* error = PyString_FromFormat(msg,
+                                                      (long)cur_size, max_bson_size);
+#endif
+                if (error) {
+                    PyErr_SetObject(InvalidDocument, error);
+                    Py_DECREF(error);
+                }
+                Py_DECREF(InvalidDocument);
+            }
+            goto iterfail;
+        }
+
+        /* We have enough data, send this batch. */
+        if (buffer_get_position(buffer) > max_message_size) {
+            int new_request_id = rand();
+            int message_start;
+            PyObject* send_gle = Py_False;
+            buffer_t new_buffer = buffer_new();
+            if (!new_buffer) {
+                PyErr_NoMemory();
+                goto iterfail;
+            }
+            message_start = init_insert_buffer(new_buffer,
+                                               new_request_id,
+                                               options,
+                                               collection_name,
+                                               collection_name_length);
+            if (message_start == -1) {
+                buffer_free(new_buffer);
+                goto iterfail;
+            }
+
+            /* Copy the overflow encoded document into the new buffer. */
+            if (!buffer_write_bytes(new_buffer,
+                (const char*)buffer_get_buffer(buffer) + before, cur_size)) {
+                buffer_free(new_buffer);
+                goto iterfail;
+            }
+
+            /* Roll back to the beginning of this document. */
+            buffer_update_position(buffer, before);
+            message_length = buffer_get_position(buffer) - length_location;
+            memcpy(buffer_get_buffer(buffer) + length_location, &message_length, 4);
+
+            /* If we are doing unacknowledged writes *and* continue_on_error
+             * is True it's pointless (and slower) to send GLE. */
+            if (safe || !continue_on_error) {
+                send_gle = Py_True;
+                if (!add_last_error(self, buffer, request_id, collection_name,
+                                    collection_name_length, last_error_args)) {
+                    buffer_free(new_buffer);
+                    goto iterfail;
+                }
+            }
+            /* Objectify buffer */
+            result = Py_BuildValue("i" BYTES_FORMAT_STRING, request_id,
+                                   buffer_get_buffer(buffer),
+                                   buffer_get_position(buffer));
+            buffer_free(buffer);
+            buffer = new_buffer;
+            request_id = new_request_id;
+            length_location = message_start;
+
+            /* "N" hands our reference to `result` over to the call. */
+            send_result = PyObject_CallMethod(client, "_send_message", "NO",
+                                              result, send_gle);
+            if (send_result) {
+                /* The return value is unused; drop it so it is not leaked. */
+                Py_DECREF(send_result);
+            } else {
+                PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
+                PyObject* OperationFailure;
+                PyErr_Fetch(&etype, &evalue, &etrace);
+                OperationFailure = _error("OperationFailure");
+                if (OperationFailure) {
+                    if (PyErr_GivenExceptionMatches(etype, OperationFailure)) {
+                        if (!safe || continue_on_error) {
+                            Py_DECREF(OperationFailure);
+                            if (!safe) {
+                                /* We're doing unacknowledged writes and
+                                 * continue_on_error is False. Just return. */
+                                /* PyErr_Fetch may leave value/traceback NULL. */
+                                Py_XDECREF(etype);
+                                Py_XDECREF(evalue);
+                                Py_XDECREF(etrace);
+                                Py_DECREF(iterator);
+                                buffer_free(buffer);
+                                PyMem_Free(collection_name);
+                                Py_RETURN_NONE;
+                            }
+                            /* continue_on_error is True, store the error
+                             * details to re-raise after the final batch */
+                            Py_XDECREF(exc_type);
+                            Py_XDECREF(exc_value);
+                            Py_XDECREF(exc_trace);
+                            exc_type = etype;
+                            exc_value = evalue;
+                            exc_trace = etrace;
+                            continue;
+                        }
+                    }
+                    Py_DECREF(OperationFailure);
+                }
+                /* This isn't OperationFailure, we couldn't
+                 * import OperationFailure, or we are doing
+                 * acknowledged writes. Re-raise immediately. */
+                PyErr_Restore(etype, evalue, etrace);
+                goto iterfail;
+            }
+        }
+    }
+    Py_DECREF(iterator);
+
+    if (PyErr_Occurred()) {
+        goto insertfail;
+    }
+
+    if (!max_size) {
+        PyObject* InvalidOperation = _error("InvalidOperation");
+        if (InvalidOperation) {
+            PyErr_SetString(InvalidOperation, "cannot do an empty bulk insert");
+            Py_DECREF(InvalidOperation);
+        }
+        goto insertfail;
+    }
+
+    message_length = buffer_get_position(buffer) - length_location;
+    memcpy(buffer_get_buffer(buffer) + length_location, &message_length, 4);
+
+    if (safe) {
+        if (!add_last_error(self, buffer, request_id, collection_name,
+                            collection_name_length, last_error_args)) {
+            goto insertfail;
+        }
+    }
+
+    PyMem_Free(collection_name);
+
+    /* objectify buffer */
+    result = Py_BuildValue("i" BYTES_FORMAT_STRING, request_id,
+                           buffer_get_buffer(buffer),
+                           buffer_get_position(buffer));
+    buffer_free(buffer);
+
+    /* Send the last (or only) batch */
+    send_result = PyObject_CallMethod(client, "_send_message", "NN",
+                                      result, PyBool_FromLong((long)safe));
+    if (!send_result) {
+        Py_XDECREF(exc_type);
+        Py_XDECREF(exc_value);
+        Py_XDECREF(exc_trace);
+        return NULL;
+    }
+    Py_DECREF(send_result);
+
+    if (exc_type) {
+        /* Re-raise any previously stored exception
+         * due to continue_on_error being True */
+        PyErr_Restore(exc_type, exc_value, exc_trace);
+        return NULL;
+    }
+
+    Py_RETURN_NONE;
+
+iterfail:
+    Py_DECREF(iterator);
+insertfail:
+    Py_XDECREF(exc_type);
+    Py_XDECREF(exc_value);
+    Py_XDECREF(exc_trace);
+    buffer_free(buffer);
+    PyMem_Free(collection_name);
+    return NULL;
+}
+
 static PyMethodDef _CMessageMethods[] = {
     {"_insert_message", _cbson_insert_message, METH_VARARGS,
-     "create an insert message to be sent to MongoDB"},
+     _cbson_insert_message_doc},
     {"_update_message", _cbson_update_message, METH_VARARGS,
      "create an update message to be sent to MongoDB"},
     {"_query_message", _cbson_query_message, METH_VARARGS,
      "create a query message to be sent to MongoDB"},
     {"_get_more_message", _cbson_get_more_message, METH_VARARGS,
      "create a get more message to be sent to MongoDB"},
+    {"_do_batched_insert", _cbson_do_batched_insert, METH_VARARGS,
+     "insert a batch of documents, splitting the batch as needed"},
     {NULL, NULL, 0, NULL}
 };
 
diff --git a/pymongo/collection.py b/pymongo/collection.py index dad8e780e..fbb007a69 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -351,10 +351,10 @@ class Collection(common.BaseObject): docs = [self.__database._fix_incoming(doc, self) for doc in docs] safe, options = self._get_write_mode(safe, **kwargs) - self.__database.connection._send_message( - message.insert(self.__full_name, docs, - check_keys, safe, options, - continue_on_error, self.__uuid_subtype), safe) + message._do_batched_insert(self.__full_name, docs, + check_keys, safe, options, + continue_on_error, self.__uuid_subtype, + self.database.connection) ids = [doc.get("_id", None) for doc in docs] if return_one: diff --git a/pymongo/message.py b/pymongo/message.py index 81a2517fc..25be90d27 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -34,7 +34,7 @@ try: _use_c = True except ImportError: _use_c = False -from pymongo.errors import InvalidOperation +from pymongo.errors import InvalidDocument, InvalidOperation, OperationFailure __ZERO = b("\x00\x00\x00\x00") @@ -70,6 +70,12 @@ def __pack_message(operation, data): def insert(collection_name, docs, check_keys, safe, last_error_args, continue_on_error,
uuid_subtype): """Get an **insert** message. + + .. note:: As of PyMongo 2.6, this function is no longer used. It + is being kept (with tests) for backwards compatibility with 3rd + party libraries that may currently be using it, but will likely + be removed in a future release. + """ options = 0 if continue_on_error: @@ -181,3 +187,68 @@ def kill_cursors(cursor_ids): for cursor_id in cursor_ids: data += struct.pack(" client.max_bson_size: + raise InvalidDocument("BSON document too large (%d bytes)" + " - the connected server supports" + " BSON document sizes up to %d" + " bytes." % + (encoded_length, client.max_bson_size)) + message_length += encoded_length + if message_length < client.max_message_size: + data.append(encoded) + continue + + # We have enough data, send this message. + send_safe = safe or not continue_on_error + try: + client._send_message(_insert_message(EMPTY.join(data), + send_safe), send_safe) + # Exception type could be OperationFailure or a subtype + # (e.g. DuplicateKeyError) + except OperationFailure, exc: + # Like it says, continue on error... + if continue_on_error: + # Store exception details to re-raise after the final batch. + last_error = exc + # With unacknowledged writes just return at the first error. + elif not safe: + return + # With acknowledged writes raise immediately. 
+ else: + raise + message_length = len(begin) + encoded_length + data = [begin, encoded] + + client._send_message(_insert_message(EMPTY.join(data), safe), safe) + + # Re-raise any exception stored due to continue_on_error + if last_error is not None: + raise last_error +if _use_c: + _do_batched_insert = _cmessage._do_batched_insert diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 7b229b3ff..35452493d 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -580,6 +580,15 @@ class MongoClient(common.BaseObject): """ return self.__max_bson_size + @property + def max_message_size(self): + """Return the maximum message size the connected server + accepts in bytes. + + .. versionadded:: 2.6 + """ + return self.__max_message_size + def __simple_command(self, sock_info, dbname, spec): """Send a command to the server. """ @@ -622,6 +631,10 @@ class MongoClient(common.BaseObject): if "maxBsonObjectSize" in response: self.__max_bson_size = response["maxBsonObjectSize"] + if "maxMessageSizeBytes" in response: + self.__max_message_size = response["maxMessageSizeBytes"] + else: + self.__max_message_size = 2 * self.max_bson_size # Replica Set? if not self.__direct: diff --git a/pymongo/mongo_replica_set_client.py b/pymongo/mongo_replica_set_client.py index fd4b5e064..ba10170e1 100644 --- a/pymongo/mongo_replica_set_client.py +++ b/pymongo/mongo_replica_set_client.py @@ -413,6 +413,8 @@ class Member(object): self.tags = ismaster_response.get('tags', {}) self.max_bson_size = ismaster_response.get( 'maxBsonObjectSize', MAX_BSON_SIZE) + self.max_message_size = ismaster_response.get( + 'maxMessageSizeBytes', 2 * self.max_bson_size) def clone_with(self, ismaster_response, ping_time_sample): """Get a clone updated with ismaster response and a single ping time. 
@@ -985,6 +987,16 @@ class MongoReplicaSetClient(common.BaseObject):
             return rs_state.primary_member.max_bson_size
         return 0
 
+    @property
+    def max_message_size(self):
+        """Returns the maximum message size the connected primary
+        accepts in bytes. Returns 0 if no primary is available.
+        """
+        rs_state = self.__rs_state
+        if rs_state.primary_member:
+            return rs_state.primary_member.max_message_size
+        return 0
+
     @property
     def auto_start_request(self):
         """Is auto_start_request enabled?
diff --git a/test/test_collection.py b/test/test_collection.py
index e245c1e25..c9cb2fc9e 100644
--- a/test/test_collection.py
+++ b/test/test_collection.py
@@ -36,6 +36,7 @@ from bson.py3compat import b
 from bson.son import SON
 from pymongo import (ASCENDING, DESCENDING, GEO2D,
                      GEOHAYSTACK, GEOSPHERE, HASHED)
+from pymongo import message as message_module
 from pymongo.collection import Collection
 from pymongo.cursor import Cursor
 from pymongo.son_manipulator import SONManipulator
@@ -1674,6 +1675,87 @@ class TestCollection(unittest.TestCase):
                           {"bar": "x"}, {"bar": "x" * (max_size - 14)})
         self.db.test.update({"bar": "x"}, {"bar": "x" * (max_size - 15)})
 
+    # Four documents of nearly max_bson_size each, so one insert() call needs
+    # more than one batch; checked for each w / continue_on_error combination.
+    def test_insert_large_batch(self):
+        max_bson_size = self.db.connection.max_bson_size
+        big_string = 'x' * (max_bson_size - 100)
+        self.db.test.drop()
+        self.assertEqual(0, self.db.test.count())
+
+        # Batch insert that requires 2 batches
+        batch = [{'x': big_string}, {'x': big_string},
+                 {'x': big_string}, {'x': big_string}]
+        self.assertTrue(self.db.test.insert(batch, w=1))
+        self.assertEqual(4, self.db.test.count())
+
+        batch[1]['_id'] = batch[0]['_id']
+
+        # Test that inserts fail after first error, acknowledged.
+        self.db.test.drop()
+        self.assertRaises(DuplicateKeyError, self.db.test.insert, batch, w=1)
+        self.assertEqual(1, self.db.test.count())
+
+        # Test that inserts fail after first error, unacknowledged.
+        self.db.test.drop()
+        self.assertTrue(self.db.test.insert(batch, w=0))
+        self.assertEqual(1, self.db.test.count())
+
+        # 2 batches, 2 errors, acknowledged, continue on error
+        self.db.test.drop()
+        batch[3]['_id'] = batch[2]['_id']
+        try:
+            self.db.test.insert(batch, continue_on_error=True, w=1)
+        except OperationFailure, e:
+            # Make sure we report the last error, not the first.
+            self.assertTrue(str(batch[2]['_id']) in str(e))
+        else:
+            self.fail('OperationFailure not raised.')
+        # Only the first and third documents should be inserted.
+        self.assertEqual(2, self.db.test.count())
+
+        # 2 batches, 2 errors, unacknowledged, continue on error
+        self.db.test.drop()
+        self.assertTrue(self.db.test.insert(batch, continue_on_error=True, w=0))
+        # Only the first and third documents should be inserted.
+        self.assertEqual(2, self.db.test.count())
+
+    # Starting in PyMongo 2.6 we no longer use message.insert for inserts, but
+    # message.insert is part of the public API. Do minimal testing here; there
+    # isn't really a better place.
+    def test_insert_message_creation(self):
+        send = self.db.connection._send_message
+        name = "%s.%s" % (self.db.name, "test")
+
+        def do_insert(args):
+            send(message_module.insert(*args), args[3])
+
+        self.db.drop_collection("test")
+        self.db.test.insert({'_id': 0}, w=1)
+        self.assertEqual(1, self.db.test.count())
+
+        simple_args = (name, [{'_id': 0}], True, False, {}, False, 3)
+        gle_args = (name, [{'_id': 0}], True, True, {'w': 1}, False, 3)
+        coe_args = (name, [{'_id': 0}, {'_id': 1}],
+                    True, True, {'w': 1}, True, 3)
+
+        self.assertEqual(None, do_insert(simple_args))
+        self.assertEqual(1, self.db.test.count())
+        self.assertRaises(DuplicateKeyError, do_insert, gle_args)
+        self.assertEqual(1, self.db.test.count())
+        self.assertRaises(DuplicateKeyError, do_insert, coe_args)
+        self.assertEqual(2, self.db.test.count())
+
+        if have_uuid:
+            doc = {'_id': 2, 'uuid': uuid.uuid4()}
+            uuid_sub_args = (name, [doc],
+                             True, True, {'w': 1}, True, 6)
+            do_insert(uuid_sub_args)
+            coll = self.db.test
+            self.assertNotEqual(doc, coll.find_one({'_id': 2}))
+            coll.uuid_subtype = 6
+            self.assertEqual(doc, coll.find_one({'_id': 2}))
+
     def test_map_reduce(self):
         if not version.at_least(self.db.connection, (1, 1, 1)):
             raise SkipTest("mapReduce command requires MongoDB >= 1.1.1")