From b2d1034989b90cd64b165ea068bdc7e91e3fbb3c Mon Sep 17 00:00:00 2001 From: Bernie Hackett Date: Mon, 25 Nov 2013 18:21:04 -0800 Subject: [PATCH] Use commands for write operations PYTHON-554 --- bson/_cbsonmodule.c | 28 -- bson/_cbsonmodule.h | 26 ++ pymongo/_cmessagemodule.c | 419 ++++++++++++++++++++++++++-- pymongo/collection.py | 109 +++++++- pymongo/helpers.py | 4 + pymongo/master_slave_connection.py | 8 +- pymongo/message.py | 167 ++++++++++- pymongo/mongo_client.py | 10 +- pymongo/mongo_replica_set_client.py | 11 +- test/test_collection.py | 40 ++- 10 files changed, 727 insertions(+), 95 deletions(-) diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index b22359ac3..c8901857c 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -67,34 +67,6 @@ static struct module_state _state; /* Maximum number of regex flags */ #define FLAGS_SIZE 7 -#if defined(WIN32) || defined(_MSC_VER) -/* This macro is basically an implementation of asprintf for win32 - * We get the length of the int as string and malloc a buffer for it, - * returning -1 if that malloc fails. We then actually print to the - * buffer to get the string value as an int. Like asprintf, the result - * must be explicitly free'd when done being used. - */ -#if defined(_MSC_VER) && (_MSC_VER >= 1400) -#define INT2STRING(buffer, i) \ - _snprintf_s((buffer), \ - _scprintf("%d", (i)) + 1, \ - _scprintf("%d", (i)) + 1, \ - "%d", \ - (i)) -#define STRCAT(dest, n, src) strcat_s((dest), (n), (src)) -#else -#define INT2STRING(buffer, i) \ - _snprintf((buffer), \ - _scprintf("%d", (i)) + 1, \ - "%d", \ - (i)) -#define STRCAT(dest, n, src) strcat((dest), (src)) -#endif -#else -#define INT2STRING(buffer, i) snprintf((buffer), sizeof((buffer)), "%d", (i)) -#define STRCAT(dest, n, src) strcat((dest), (src)) -#endif - #define JAVA_LEGACY 5 #define CSHARP_LEGACY 6 #define BSON_MAX_SIZE 2147483647 diff --git a/bson/_cbsonmodule.h b/bson/_cbsonmodule.h index be69eafcb..c4db6d0ae 100644 --- a/bson/_cbsonmodule.h +++ b/bson/_cbsonmodule.h @@ -26,6 +26,32 @@ typedef int Py_ssize_t; #define PY_SSIZE_T_MIN INT_MIN #endif +#if defined(WIN32) || defined(_MSC_VER) +/* + * This macro is basically an implementation of asprintf for win32 + * We print to the provided buffer to get the string value as an int. + */ +#if defined(_MSC_VER) && (_MSC_VER >= 1400) +#define INT2STRING(buffer, i) \ + _snprintf_s((buffer), \ + _scprintf("%d", (i)) + 1, \ + _scprintf("%d", (i)) + 1, \ + "%d", \ + (i)) +#define STRCAT(dest, n, src) strcat_s((dest), (n), (src)) +#else +#define INT2STRING(buffer, i) \ + _snprintf((buffer), \ + _scprintf("%d", (i)) + 1, \ + "%d", \ + (i)) +#define STRCAT(dest, n, src) strcat((dest), (src)) +#endif +#else +#define INT2STRING(buffer, i) snprintf((buffer), sizeof((buffer)), "%d", (i)) +#define STRCAT(dest, n, src) strcat((dest), (src)) +#endif + /* C API functions */ #define _cbson_buffer_write_bytes_INDEX 0 #define _cbson_buffer_write_bytes_RETURN int diff --git a/pymongo/_cmessagemodule.c b/pymongo/_cmessagemodule.c index 5c0f69868..7e60f9701 100644 --- a/pymongo/_cmessagemodule.c +++ b/pymongo/_cmessagemodule.c @@ -540,12 +540,29 @@ static PyObject* _cbson_get_more_message(PyObject* self, PyObject* args) { return result; } +static void +_set_document_too_large(int size, long max) { + PyObject* InvalidDocument = _error("InvalidDocument"); + if (InvalidDocument) { +#if PY_MAJOR_VERSION >= 3 + PyObject* error = PyUnicode_FromFormat(DOC_TOO_LARGE_FMT, size, max); +#else + PyObject* error = PyString_FromFormat(DOC_TOO_LARGE_FMT, size, max); +#endif + if (error) { + PyErr_SetObject(InvalidDocument, error); + Py_DECREF(error); + } + Py_DECREF(InvalidDocument); + } +} + static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { struct module_state *state = GETSTATE(self); /* NOTE just using a random number as the request_id */ int request_id = rand(); - int options = 0, max_size = 0; + int options = 0; int length_location, message_length; int collection_name_length; char* collection_name = NULL; @@ -562,6 +579,7 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { unsigned char safe; unsigned char continue_on_error; unsigned char uuid_subtype; + unsigned char empty = 1; long max_bson_size; long max_message_size; buffer_t buffer; @@ -633,6 +651,7 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { while ((doc = PyIter_Next(iterator)) != NULL) { int before = buffer_get_position(buffer); int cur_size; + empty = 0; if (!write_dict(state->_cbson, buffer, doc, check_keys, uuid_subtype, 1)) { Py_DECREF(doc); goto iterfail; @@ -640,23 +659,8 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { Py_DECREF(doc); cur_size = buffer_get_position(buffer) - before; - max_size = (cur_size > max_size) ? cur_size : max_size; if (cur_size > max_bson_size) { - PyObject* InvalidDocument = _error("InvalidDocument"); - if (InvalidDocument) { -#if PY_MAJOR_VERSION >= 3 - PyObject* error = PyUnicode_FromFormat(DOC_TOO_LARGE_FMT, - cur_size, max_bson_size); -#else - PyObject* error = PyString_FromFormat(DOC_TOO_LARGE_FMT, - cur_size, max_bson_size); -#endif - if (error) { - PyErr_SetObject(InvalidDocument, error); - Py_DECREF(error); - } - Py_DECREF(InvalidDocument); - } + _set_document_too_large(cur_size, max_bson_size); goto iterfail; } @@ -727,8 +731,8 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { /* We're doing unacknowledged writes and * continue_on_error is False. Just return. */ Py_DECREF(etype); - Py_DECREF(evalue); - Py_DECREF(etrace); + Py_XDECREF(evalue); + Py_XDECREF(etrace); Py_DECREF(iterator); buffer_free(buffer); PyMem_Free(collection_name); @@ -763,7 +767,7 @@ static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { goto insertfail; } - if (!max_size) { + if (empty) { PyObject* InvalidOperation = _error("InvalidOperation"); if (InvalidOperation) { PyErr_SetString(InvalidOperation, "cannot do an empty bulk insert"); @@ -824,6 +828,379 @@ insertfail: return NULL; } +static PyObject* +_send_write_command(PyObject* client, buffer_t buffer, + int lst_len_loc, int cmd_len_loc, unsigned char* errors) { + + PyObject* msg; + PyObject* result; + PyObject* ok; + + int request_id = rand(); + int position = buffer_get_position(buffer); + int length = position - lst_len_loc - 1; + memcpy(buffer_get_buffer(buffer) + lst_len_loc, &length, 4); + length = position - cmd_len_loc; + memcpy(buffer_get_buffer(buffer) + cmd_len_loc, &length, 4); + memcpy(buffer_get_buffer(buffer), &position, 4); + memcpy(buffer_get_buffer(buffer) + 4, &request_id, 4); + + /* objectify buffer */ + msg = Py_BuildValue("i" BYTES_FORMAT_STRING, request_id, + buffer_get_buffer(buffer), + buffer_get_position(buffer)); + if (!msg) + return NULL; + + /* Send the current batch */ + result = PyObject_CallMethod(client, "_send_message", + "NOO", msg, Py_True, Py_True); + if (!result) { + PyObject *etype = NULL, *evalue = NULL, *etrace = NULL; + PyObject* OperationFailure; + PyErr_Fetch(&etype, &evalue, &etrace); + OperationFailure = _error("OperationFailure"); + if (OperationFailure) { + if (PyErr_GivenExceptionMatches(etype, OperationFailure)) { + PyObject* details; + Py_DECREF(OperationFailure); + if (!evalue || + !(details = PyObject_GetAttrString(evalue, + "error_document")) || + details == Py_None) { + /* + * Either there is no error_document, or + * something went very wrong... + */ + PyErr_Restore(etype, evalue, etrace); + return NULL; + } + Py_DECREF(etype); + Py_DECREF(evalue); + Py_XDECREF(etrace); + *errors = 1; + return details; + } + Py_DECREF(OperationFailure); + } + /* + * This isn't OperationFailure or we couldn't + * import OperationFailure. Re-raise immediately. + */ + PyErr_Restore(etype, evalue, etrace); + return NULL; + } + if (!(ok = PyDict_GetItemString(result, "ok")) || + !PyObject_IsTrue(ok)) + *errors = 1; + return result; +} + +static buffer_t +_command_buffer_new(char* ns, int ns_len) { + buffer_t buffer; + if (!(buffer = buffer_new())) { + PyErr_NoMemory(); + return NULL; + } + /* Save space for message length and request id */ + if ((buffer_save_space(buffer, 8)) == -1) { + PyErr_NoMemory(); + buffer_free(buffer); + return NULL; + } + if (!buffer_write_bytes(buffer, + "\x00\x00\x00\x00" /* responseTo */ + "\xd4\x07\x00\x00" /* opcode */ + "\x00\x00\x00\x00", /* options */ + 12) || + !buffer_write_bytes(buffer, + ns, ns_len + 1) || /* namespace */ + !buffer_write_bytes(buffer, + "\x00\x00\x00\x00" /* skip */ + "\xFF\xFF\xFF\xFF", /* limit (-1) */ + 8)) { + buffer_free(buffer); + return NULL; + } + return buffer; +} + +static PyObject* +_cbson_do_batched_write_command(PyObject* self, PyObject* args) { + struct module_state *state = GETSTATE(self); + + long max_bson_size; + long max_cmd_size; + long idx_offset = 0; + int idx = 0; + int cmd_len_loc; + int lst_len_loc; + int ns_len; + char *ns = NULL, *cmd = NULL; + PyObject* max_bson_size_obj; + PyObject* command; + PyObject* doc; + PyObject* docs; + PyObject* client; + PyObject* iterator; + PyObject* result; + PyObject* results; + unsigned char check_keys; + unsigned char ordered; + unsigned char uuid_subtype; + unsigned char empty = 1; + unsigned char errors = 0; + buffer_t buffer; + + if (!PyArg_ParseTuple(args, "et#sOObbbO", "utf-8", + &ns, &ns_len, &cmd, &command, &docs, + &check_keys, &ordered, &uuid_subtype, &client)) { + return NULL; + } + + max_bson_size_obj = PyObject_GetAttrString(client, "max_bson_size"); +#if PY_MAJOR_VERSION >= 3 + max_bson_size = PyLong_AsLong(max_bson_size_obj); +#else + max_bson_size = PyInt_AsLong(max_bson_size_obj); +#endif + Py_XDECREF(max_bson_size_obj); + if (max_bson_size == -1) { + PyMem_Free(ns); + return NULL; + } + /* + * Max BSON object size + 16k - 2 bytes for ending NUL bytes + * XXX: This should come from the server - SERVER-10643 + */ + max_cmd_size = max_bson_size + 16382; + + if (!(results = PyList_New(0))) { + PyMem_Free(ns); + return NULL; + } + + if (!(buffer = _command_buffer_new(ns, ns_len))) { + PyMem_Free(ns); + Py_DECREF(results); + return NULL; + } + + PyMem_Free(ns); + + /* Position of command document length */ + cmd_len_loc = buffer_get_position(buffer); + if (!write_dict(state->_cbson, buffer, command, 0, uuid_subtype, 0)) { + goto cmdfail; + } + + /* Write type byte for array */ + *(buffer_get_buffer(buffer) + (buffer_get_position(buffer) - 1)) = 0x4; + + switch (*cmd) { + case 'i': + { + if (!buffer_write_bytes(buffer, "documents\x00", 10)) + goto cmdfail; + break; + } + case 'u': + { + /* MongoDB does key validation for update. */ + check_keys = 0; + if (!buffer_write_bytes(buffer, "updates\x00", 8)) + goto cmdfail; + break; + } + case 'd': + { + /* Never check keys in a delete command. */ + check_keys = 0; + if (!buffer_write_bytes(buffer, "deletes\x00", 8)) + goto cmdfail; + break; + } + default: + { + PyObject* InvalidOperation = _error("InvalidOperation"); + if (InvalidOperation) { +#if PY_MAJOR_VERSION >= 3 + PyObject* error = PyUnicode_FromFormat("Unknown command: %s", + cmd); +#else + PyObject* error = PyString_FromFormat("Unknown command: %s", + cmd); +#endif + if (error) { + PyErr_SetObject(InvalidOperation, error); + Py_DECREF(error); + } + Py_DECREF(InvalidOperation); + } + goto cmdfail; + } + } + + /* Save space for list document */ + lst_len_loc = buffer_save_space(buffer, 4); + if (lst_len_loc == -1) { + PyErr_NoMemory(); + goto cmdfail; + } + + iterator = PyObject_GetIter(docs); + if (iterator == NULL) { + PyObject* InvalidOperation = _error("InvalidOperation"); + if (InvalidOperation) { + PyErr_SetString(InvalidOperation, "input is not iterable"); + Py_DECREF(InvalidOperation); + } + goto cmdfail; + } + while ((doc = PyIter_Next(iterator)) != NULL) { + int sub_doc_begin = buffer_get_position(buffer); + int cur_doc_begin; + int cur_size; + char key[16]; + empty = 0; + INT2STRING(key, idx); + if (!buffer_write_bytes(buffer, "\x03", 1) || + !buffer_write_bytes(buffer, key, (int)strlen(key) + 1)) { + Py_DECREF(doc); + goto cmditerfail; + } + cur_doc_begin = buffer_get_position(buffer); + if (!write_dict(state->_cbson, buffer, doc, + check_keys, uuid_subtype, 1)) { + Py_DECREF(doc); + goto cmditerfail; + } + Py_DECREF(doc); + + /* We have enough data, maybe send this batch. */ + if (buffer_get_position(buffer) > max_cmd_size) { + buffer_t new_buffer; + cur_size = buffer_get_position(buffer) - cur_doc_begin; + + /* This single document is too large for the command. */ + if (!idx) { + if (*cmd == 'i') { /* Insert */ + _set_document_too_large(cur_size, max_bson_size); + } else { /* Update and delete */ + PyObject* InvalidDocument = _error("InvalidDocument"); + if (InvalidDocument) { + /* + * There's nothing intelligent we can say + * about size for update and remove. + */ + PyErr_SetString(InvalidDocument, + "command document too large"); + Py_DECREF(InvalidDocument); + } + } + goto cmditerfail; + } + + if (!(new_buffer = buffer_new())) { + PyErr_NoMemory(); + goto cmditerfail; + } + /* New buffer including the current overflow document */ + if (!buffer_write_bytes(new_buffer, + (const char*)buffer_get_buffer(buffer), lst_len_loc + 5) || + !buffer_write_bytes(new_buffer, "0\x00", 2) || + !buffer_write_bytes(new_buffer, + (const char*)buffer_get_buffer(buffer) + cur_doc_begin, cur_size)) { + buffer_free(new_buffer); + goto cmditerfail; + } + /* + * Roll the existing buffer back to the beginning + * of the last document encoded. + */ + buffer_update_position(buffer, sub_doc_begin); + + if (!buffer_write_bytes(buffer, "\x00\x00", 2)) + goto cmditerfail; + + result = _send_write_command(client, buffer, + lst_len_loc, cmd_len_loc, &errors); + + buffer_free(buffer); + buffer = new_buffer; + + if (!result) + goto cmditerfail; + +#if PY_MAJOR_VERSION >= 3 + result = Py_BuildValue("NN", + PyLong_FromLong(idx_offset), result); +#else + result = Py_BuildValue("NN", + PyInt_FromLong(idx_offset), result); +#endif + if (!result) + goto cmditerfail; + + PyList_Append(results, result); + Py_DECREF(result); + + if (errors && ordered) { + Py_DECREF(iterator); + buffer_free(buffer); + return results; + } + idx_offset += idx; + idx = 0; + } + idx += 1; + } + Py_DECREF(iterator); + + if (PyErr_Occurred()) { + goto cmdfail; + } + + if (empty) { + PyObject* InvalidOperation = _error("InvalidOperation"); + if (InvalidOperation) { + PyErr_SetString(InvalidOperation, "cannot do an empty bulk write"); + Py_DECREF(InvalidOperation); + } + goto cmdfail; + } + + if (!buffer_write_bytes(buffer, "\x00\x00", 2)) + goto cmdfail; + + result = _send_write_command(client, buffer, + lst_len_loc, cmd_len_loc, &errors); + if (!result) + goto cmdfail; + +#if PY_MAJOR_VERSION >= 3 + result = Py_BuildValue("NN", PyLong_FromLong(idx_offset), result); +#else + result = Py_BuildValue("NN", PyInt_FromLong(idx_offset), result); +#endif + if (!result) + goto cmdfail; + + buffer_free(buffer); + + PyList_Append(results, result); + Py_DECREF(result); + return results; + +cmditerfail: + Py_DECREF(iterator); +cmdfail: + Py_DECREF(results); + buffer_free(buffer); + return NULL; +} + static PyMethodDef _CMessageMethods[] = { {"_insert_message", _cbson_insert_message, METH_VARARGS, _cbson_insert_message_doc}, @@ -835,6 +1212,8 @@ static PyMethodDef _CMessageMethods[] = { "create a get more message to be sent to MongoDB"}, {"_do_batched_insert", _cbson_do_batched_insert, METH_VARARGS, "insert a batch of documents, splitting the batch as needed"}, + {"_do_batched_write_command", _cbson_do_batched_write_command, METH_VARARGS, + "execute a batch of insert, update, or delete commands"}, {NULL, NULL, 0, NULL} }; diff --git a/pymongo/collection.py b/pymongo/collection.py index ab87c4682..c05ff0163 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -24,6 +24,7 @@ from pymongo import (common, message) from pymongo.cursor import Cursor from pymongo.errors import InvalidName +from pymongo.helpers import _check_command_response try: @@ -323,10 +324,11 @@ class Collection(common.BaseObject): .. mongodoc:: insert """ + client = self.database.connection # Batch inserts require us to know the connected master's # max_bson_size and max_message_size. We have to be connected # to a master to know that. - self.database.connection._ensure_connected(True) + client._ensure_connected(True) docs = doc_or_docs return_one = False @@ -354,10 +356,41 @@ class Collection(common.BaseObject): yield doc safe, options = self._get_write_mode(safe, **kwargs) - message._do_batched_insert(self.__full_name, gen(), - check_keys, safe, options, - continue_on_error, self.uuid_subtype, - self.database.connection) + + if client.max_wire_version > 1 and safe: + # Insert command + dbname, collname = self.__full_name.split('.', 1) + namespace = '%s.%s' % (dbname, '$cmd') + command = SON([('insert', collname), + ('ordered', not continue_on_error)]) + + if safe: + command['writeConcern'] = options + + results = message._do_batched_write_command( + namespace, 'insert', command, gen(), check_keys, + not continue_on_error, self.uuid_subtype, client) + + errors = [result for result in results if not result[1]['ok']] + if errors: + # If multiple batches had errors + # just raise from the last batch... + offset, error = errors[-1] + if "errDetails" in error: + # ...and the last error in that batch. + error = error["errDetails"][-1] + error["index"] += offset + # We use _check_command_response to figure out the + # error type (OperationFailure, DuplicateKeyError, etc.) + # but we have to add the 'ok' field if we're passing it + # a subdocument from errDetails. + error['ok'] = 0 + _check_command_response(error, None) + else: + # Legacy batched OP_INSERT + message._do_batched_insert(self.__full_name, gen(), check_keys, + safe, options, continue_on_error, + self.uuid_subtype, client) if return_one: return ids[0] @@ -478,10 +511,39 @@ class Collection(common.BaseObject): if first.startswith('$'): check_keys = False - return self.__database.connection._send_message( - message.update(self.__full_name, upsert, multi, - spec, document, safe, options, - check_keys, self.uuid_subtype), safe) + client = self.database.connection + if client.max_wire_version > 1 and safe: + # Update command + dbname, collname = self.__full_name.split('.', 1) + namespace = '%s.%s' % (dbname, '$cmd') + command = SON([('update', collname)]) + + if safe: + command['writeConcern'] = options + + docs = [SON([('q', spec), ('u', document), + ('multi', multi), ('upsert', upsert)])] + + _, result = message._do_batched_write_command( + namespace, 'update', command, docs, + check_keys, True, self.uuid_subtype, client)[0] + if not result['ok']: + _check_command_response(result, None) + + # Add the updatedExisting field for compatibility + if result.get('n') and 'upserted' not in result: + result['updatedExisting'] = True + else: + result['updatedExisting'] = False + + return result + + else: + # Legacy OP_UPDATE + return client._send_message( + message.update(self.__full_name, upsert, multi, + spec, document, safe, options, + check_keys, self.uuid_subtype), safe) def drop(self): """Alias for :meth:`~pymongo.database.Database.drop_collection`. @@ -563,9 +625,32 @@ class Collection(common.BaseObject): spec_or_id = {"_id": spec_or_id} safe, options = self._get_write_mode(safe, **kwargs) - return self.__database.connection._send_message( - message.delete(self.__full_name, spec_or_id, safe, - options, self.uuid_subtype), safe) + + client = self.database.connection + if client.max_wire_version > 1 and safe: + # Delete command + dbname, collname = self.__full_name.split('.', 1) + namespace = '%s.%s' % (dbname, '$cmd') + command = SON([('delete', collname)]) + + if safe: + command['writeConcern'] = options + + docs = [SON([('q', spec_or_id), ('limit', 0)])] + + _, result = message._do_batched_write_command( + namespace, 'delete', command, docs, + False, True, self.uuid_subtype, client)[0] + if not result['ok']: + _check_command_response(result, None) + + return result + + else: + # Legacy OP_DELETE + return client._send_message( + message.delete(self.__full_name, spec_or_id, safe, + options, self.uuid_subtype), safe) def find_one(self, spec_or_id=None, *args, **kwargs): """Get a single document from the database. diff --git a/pymongo/helpers.py b/pymongo/helpers.py index 825c48b19..008e9dfda 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -173,6 +173,10 @@ def _check_command_response(response, reset, msg=None, allowable_errors=None): elif code == 50: raise ExecutionTimeout(errmsg, code, response) + # wtimeout from write commands + if "errInfo" in details and details["errInfo"].get("wtimeout"): + raise WTimeoutError(errmsg, code, response) + msg = msg or "%s" raise OperationFailure(msg % errmsg, code, response) diff --git a/pymongo/master_slave_connection.py b/pymongo/master_slave_connection.py index a76bcfc24..25d7f6598 100644 --- a/pymongo/master_slave_connection.py +++ b/pymongo/master_slave_connection.py @@ -168,7 +168,8 @@ class MasterSlaveConnection(BaseObject): # that killcursor operations can be sent to the same instance on which # the cursor actually resides... def _send_message(self, message, - with_last_error=False, _connection_to_use=None): + with_last_error=False, + command=False, _connection_to_use=None): """Say something to Mongo. Sends a message on the Master connection. This is used for inserts, @@ -183,9 +184,10 @@ class MasterSlaveConnection(BaseObject): - `safe`: perform a getLastError after sending the message """ if _connection_to_use is None or _connection_to_use == -1: - return self.__master._send_message(message, with_last_error) + return self.__master._send_message(message, + with_last_error, command) return self.__slaves[_connection_to_use]._send_message( - message, with_last_error, check_primary=False) + message, with_last_error, command, check_primary=False) # _connection_to_use is a hack that we need to include to make sure # that getmore operations can be sent to the same instance on which diff --git a/pymongo/message.py b/pymongo/message.py index 368cc010c..74b667107 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -27,7 +27,7 @@ import struct import bson from bson.binary import OLD_UUID_SUBTYPE -from bson.py3compat import b +from bson.py3compat import b, StringIO from bson.son import SON try: from pymongo import _cmessage @@ -37,13 +37,22 @@ except ImportError: from pymongo.errors import InvalidDocument, InvalidOperation, OperationFailure -__ZERO = b("\x00\x00\x00\x00") - -EMPTY = b("") - MAX_INT32 = 2147483647 MIN_INT32 = -2147483648 +_EMPTY = b('') +_BSONOBJ = b('\x03') +_ZERO_8 = b('\x00') +_ZERO_16 = b('\x00\x00') +_ZERO_32 = b('\x00\x00\x00\x00') +_ZERO_64 = b('\x00\x00\x00\x00\x00\x00\x00\x00') +_SKIPLIM = b('\x00\x00\x00\x00\xff\xff\xff\xff') +_CMD_MAP = { + 'insert': b('\x04documents\x00\x00\x00\x00\x00'), + 'update': b('\x04updates\x00\x00\x00\x00\x00'), + 'delete': b('\x04deletes\x00\x00\x00\x00\x00'), +} + def __last_error(namespace, args): """Data to send to do a lastError. @@ -62,7 +71,7 @@ def __pack_message(operation, data): request_id = random.randint(MIN_INT32, MAX_INT32) message = struct.pack("= max_cmd_size: + if not idx: + if name == 'insert': + raise InvalidDocument("BSON document too large (%d bytes)" + " - the connected server supports" + " BSON document sizes up to %d" + " bytes." % (len(value), + max_bson_size)) + # There's nothing intelligent we can say + # about size for update and remove + raise InvalidDocument("command document too large") + errors, result = send_message() + results.append((idx_offset, result)) + if errors and ordered: + return results + + # Truncate back to the start of list elements + buf.seek(list_start + 4) + buf.truncate() + idx_offset += idx + idx = 0 + key = b('0') + buf.write(_BSONOBJ) + buf.write(key) + buf.write(_ZERO_8) + buf.write(value) + idx += 1 + + if not has_docs: + raise InvalidOperation("cannot do an empty bulk write") + + _, result = send_message() + results.append((idx_offset, result)) + return results +if _use_c: + _do_batched_write_command = _cmessage._do_batched_write_command diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index d5d6e35e3..e3ddc2911 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -902,7 +902,7 @@ class MongoClient(common.BaseObject): self.__cursor_manager = manager - def __check_response_to_last_error(self, response): + def __check_response_to_last_error(self, response, command): """Check a response to a lastError message for errors. `response` is a byte string representing a response to the message. @@ -915,6 +915,9 @@ class MongoClient(common.BaseObject): assert response["number_returned"] == 1 error = response["data"][0] + if command: + return error + helpers._check_command_response(error, self.disconnect) error_msg = error.get("err", "") @@ -959,7 +962,8 @@ class MongoClient(common.BaseObject): # don't include BSON documents. return message - def _send_message(self, message, with_last_error=False, check_primary=True): + def _send_message(self, message, + with_last_error=False, command=False, check_primary=True): """Say something to Mongo. Raises ConnectionFailure if the message cannot be sent. Raises @@ -992,7 +996,7 @@ class MongoClient(common.BaseObject): if with_last_error: response = self.__receive_message_on_socket(1, request_id, sock_info) - rv = self.__check_response_to_last_error(response) + rv = self.__check_response_to_last_error(response, command) return rv except OperationFailure: diff --git a/pymongo/mongo_replica_set_client.py b/pymongo/mongo_replica_set_client.py index 61e804ac0..419872501 100644 --- a/pymongo/mongo_replica_set_client.py +++ b/pymongo/mongo_replica_set_client.py @@ -1390,7 +1390,7 @@ class MongoReplicaSetClient(common.BaseObject): if member and sock_info: member.pool.maybe_return_socket(sock_info) - def __check_response_to_last_error(self, response): + def __check_response_to_last_error(self, response, command): """Check a response to a lastError message for errors. `response` is a byte string representing a response to the message. @@ -1403,6 +1403,9 @@ class MongoReplicaSetClient(common.BaseObject): assert response["number_returned"] == 1 error = response["data"][0] + if command: + return error + helpers._check_command_response(error, self.disconnect) error_msg = error.get("err", "") @@ -1468,8 +1471,8 @@ class MongoReplicaSetClient(common.BaseObject): # don't include BSON documents. return msg - def _send_message(self, msg, - with_last_error=False, _connection_to_use=None): + def _send_message(self, msg, with_last_error=False, + command=False, _connection_to_use=None): """Say something to Mongo. Raises ConnectionFailure if the message cannot be sent. Raises @@ -1505,7 +1508,7 @@ class MongoReplicaSetClient(common.BaseObject): rv = None if with_last_error: response = self.__recv_msg(1, rqst_id, sock_info) - rv = self.__check_response_to_last_error(response) + rv = self.__check_response_to_last_error(response, command) return rv except OperationFailure: raise diff --git a/test/test_collection.py b/test/test_collection.py index c6c0a87e3..7d79c7055 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -911,7 +911,10 @@ class TestCollection(unittest.TestCase): self.assertTrue(self.db.test.insert({"hello": "world"})) doc = self.db.test.find_one() doc['a.b'] = 'c' - self.assertRaises(InvalidDocument, self.db.test.save, doc) + expected = InvalidDocument + if version.at_least(self.client, (2, 5, 4, -1)): + expected = OperationFailure + self.assertRaises(expected, self.db.test.save, doc) def test_unique_index(self): db = self.db @@ -1020,7 +1023,7 @@ class TestCollection(unittest.TestCase): self.db.test.update({}, {"$thismodifierdoesntexist": 1}) except OperationFailure, exc: if version.at_least(self.db.connection, (1, 3)): - self.assertTrue(exc.code in (10147, 17009)) + self.assertTrue(exc.code in (10147, 16840, 17009)) # Just check that we set the error document. Fields # vary by MongoDB version. self.assertTrue(exc.error_document is not None) @@ -1159,11 +1162,15 @@ class TestCollection(unittest.TestCase): doc = self.db.test.find_one() doc['a.b'] = 'c' + expected = InvalidDocument + if version.at_least(self.client, (2, 5, 4, -1)): + expected = OperationFailure + # Replace - self.assertRaises(InvalidDocument, + self.assertRaises(expected, self.db.test.update, {"hello": "world"}, doc) # Upsert - self.assertRaises(InvalidDocument, + self.assertRaises(expected, self.db.test.update, {"foo": "bar"}, doc, upsert=True) # Check that the last two ops didn't actually modify anything @@ -1185,7 +1192,7 @@ class TestCollection(unittest.TestCase): # doesn't change. If the behavior changes checking the first key for # '$' in update won't be good enough anymore. doc = SON([("hello", "world"), ("$set", {"foo.bar": "bim"})]) - self.assertRaises(InvalidDocument, self.db.test.update, + self.assertRaises(expected, self.db.test.update, {"hello": "world"}, doc, upsert=True) # Replace with empty document @@ -1731,23 +1738,34 @@ class TestCollection(unittest.TestCase): half_size = int(max_size / 2) if version.at_least(self.db.connection, (1, 7, 4)): self.assertEqual(max_size, 16777216) - self.assertRaises(InvalidDocument, self.db.test.insert, + + expected = InvalidDocument + if version.at_least(self.client, (2, 5, 4, -1)): + # Document too large handled by the server + expected = OperationFailure + self.assertRaises(expected, self.db.test.insert, {"foo": "x" * max_size}) - self.assertRaises(InvalidDocument, self.db.test.save, + self.assertRaises(expected, self.db.test.save, {"foo": "x" * max_size}) - self.assertRaises(InvalidDocument, self.db.test.insert, + self.assertRaises(expected, self.db.test.insert, [{"x": 1}, {"foo": "x" * max_size}]) self.db.test.insert([{"foo": "x" * half_size}, {"foo": "x" * half_size}]) self.db.test.insert({"bar": "x"}) + # Use w=0 here to test legacy doc size checking in all server versions self.assertRaises(InvalidDocument, self.db.test.update, - {"bar": "x"}, {"bar": "x" * (max_size - 14)}) + {"bar": "x"}, {"bar": "x" * (max_size - 14)}, w=0) + # This will pass with OP_UPDATE or the update command. self.db.test.update({"bar": "x"}, {"bar": "x" * (max_size - 15)}) def test_insert_large_batch(self): - max_bson_size = self.db.connection.max_bson_size - big_string = 'x' * (max_bson_size - 100) + max_bson_size = self.client.max_bson_size + if version.at_least(self.client, (2, 5, 4, -1)): + # Write commands are limited to 16MB + 16k per batch + big_string = 'x' * int(max_bson_size / 2) + else: + big_string = 'x' * (max_bson_size - 100) self.db.test.drop() self.assertEqual(0, self.db.test.count())