diff --git a/pymongo/bulk.py b/pymongo/bulk.py index 0f5730928..829b482c9 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -267,17 +267,18 @@ class _Bulk(object): # sock_info.write_command. sock_info.validate_session(client, session) while run: - cmd = SON([(_COMMANDS[run.op_type], self.collection.name), - ('ordered', self.ordered)]) - if not write_concern.is_server_default: - cmd['writeConcern'] = write_concern.document - if self.bypass_doc_val: - cmd['bypassDocumentValidation'] = True + cmd_name = _COMMANDS[run.op_type] bwc = self.bulk_ctx_class( - db_name, cmd, sock_info, op_id, listeners, session, + db_name, cmd_name, sock_info, op_id, listeners, session, run.op_type, self.collection.codec_options) while run.idx_offset < len(run.ops): + cmd = SON([(cmd_name, self.collection.name), + ('ordered', self.ordered)]) + if not write_concern.is_server_default: + cmd['writeConcern'] = write_concern.document + if self.bypass_doc_val: + cmd['bypassDocumentValidation'] = True if session: # Start a new retryable write unless one was already # started for this command. @@ -287,9 +288,10 @@ class _Bulk(object): session._apply_to(cmd, retryable, ReadPreference.PRIMARY, sock_info) sock_info.send_cluster_time(cmd, session, client) + sock_info.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible in one command. - result, to_send = bwc.execute(ops, client) + result, to_send = bwc.execute(cmd, ops, client) # Retryable writeConcernErrors halt the execution of this run. wce = result.get('writeConcernError', {}) @@ -359,17 +361,19 @@ class _Bulk(object): run = self.current_run while run: - cmd = SON([(_COMMANDS[run.op_type], self.collection.name), - ('ordered', False), - ('writeConcern', {'w': 0})]) + cmd_name = _COMMANDS[run.op_type] bwc = self.bulk_ctx_class( - db_name, cmd, sock_info, op_id, listeners, None, + db_name, cmd_name, sock_info, op_id, listeners, None, run.op_type, self.collection.codec_options) while run.idx_offset < len(run.ops): + cmd = SON([(cmd_name, self.collection.name), + ('ordered', False), + ('writeConcern', {'w': 0})]) + sock_info.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. - to_send = bwc.execute_unack(ops, client) + to_send = bwc.execute_unack(cmd, ops, client) run.idx_offset += len(to_send) self.current_run = run = next(generator, None) diff --git a/pymongo/message.py b/pymongo/message.py index 8a496d5b1..86a83f152 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -704,50 +704,48 @@ def _get_more(collection_name, num_to_return, cursor_id, ctx=None): class _BulkWriteContext(object): """A wrapper around SocketInfo for use with write splitting functions.""" - __slots__ = ('db_name', 'command', 'sock_info', 'op_id', + __slots__ = ('db_name', 'sock_info', 'op_id', 'name', 'field', 'publish', 'start_time', 'listeners', 'session', 'compress', 'op_type', 'codec') - def __init__(self, database_name, command, sock_info, operation_id, + def __init__(self, database_name, cmd_name, sock_info, operation_id, listeners, session, op_type, codec): self.db_name = database_name - self.command = command self.sock_info = sock_info self.op_id = operation_id self.listeners = listeners self.publish = listeners.enabled_for_commands - self.name = next(iter(command)) + self.name = cmd_name self.field = _FIELD_MAP[self.name] self.start_time = datetime.datetime.now() if self.publish else None self.session = session self.compress = True if sock_info.compression_context else False self.op_type = op_type self.codec = codec - sock_info.add_server_api(command) - def _batch_command(self, docs): + def _batch_command(self, cmd, docs): namespace = self.db_name + '.$cmd' request_id, msg, to_send = _do_batched_op_msg( - namespace, self.op_type, self.command, docs, self.check_keys, + namespace, self.op_type, cmd, docs, self.check_keys, self.codec, self) if not to_send: raise InvalidOperation("cannot do an empty bulk write") return request_id, msg, to_send - def execute(self, docs, client): - request_id, msg, to_send = self._batch_command(docs) - result = self.write_command(request_id, msg, to_send) + def execute(self, cmd, docs, client): + request_id, msg, to_send = self._batch_command(cmd, docs) + result = self.write_command(cmd, request_id, msg, to_send) client._process_response(result, self.session) return result, to_send - def execute_unack(self, docs, client): - request_id, msg, to_send = self._batch_command(docs) + def execute_unack(self, cmd, docs, client): + request_id, msg, to_send = self._batch_command(cmd, docs) # Though this isn't strictly a "legacy" write, the helper # handles publishing commands and sending our message # without receiving a result. Send 0 for max_doc_size # to disable size checking. Size checking is handled while # the documents are encoded to BSON. - self.unack_write(request_id, msg, 0, to_send) + self.unack_write(cmd, request_id, msg, 0, to_send) return to_send @property @@ -778,12 +776,12 @@ class _BulkWriteContext(object): """The maximum size of a BSON command before batch splitting.""" return self.max_bson_size - def unack_write(self, request_id, msg, max_doc_size, docs): + def unack_write(self, cmd, request_id, msg, max_doc_size, docs): """A proxy for SocketInfo.unack_write that handles event publishing. """ if self.publish: duration = datetime.datetime.now() - self.start_time - cmd = self._start(request_id, docs) + cmd = self._start(cmd, request_id, docs) start = datetime.datetime.now() try: result = self.sock_info.unack_write(msg, max_doc_size) @@ -811,12 +809,12 @@ class _BulkWriteContext(object): self.start_time = datetime.datetime.now() return result - def write_command(self, request_id, msg, docs): + def write_command(self, cmd, request_id, msg, docs): """A proxy for SocketInfo.write_command that handles event publishing. """ if self.publish: duration = datetime.datetime.now() - self.start_time - self._start(request_id, docs) + self._start(cmd, request_id, docs) start = datetime.datetime.now() try: reply = self.sock_info.write_command(request_id, msg) @@ -836,9 +834,8 @@ class _BulkWriteContext(object): self.start_time = datetime.datetime.now() return reply - def _start(self, request_id, docs): + def _start(self, cmd, request_id, docs): """Publish a CommandStartedEvent.""" - cmd = self.command.copy() cmd[self.field] = docs self.listeners.publish_command_start( cmd, self.db_name, @@ -871,10 +868,10 @@ _MAX_SPLIT_SIZE_ENC = 2097152 class _EncryptedBulkWriteContext(_BulkWriteContext): __slots__ = () - def _batch_command(self, docs): + def _batch_command(self, cmd, docs): namespace = self.db_name + '.$cmd' msg, to_send = _encode_batched_write_command( - namespace, self.op_type, self.command, docs, self.check_keys, + namespace, self.op_type, cmd, docs, self.check_keys, self.codec, self) if not to_send: raise InvalidOperation("cannot do an empty bulk write") @@ -885,17 +882,18 @@ class _EncryptedBulkWriteContext(_BulkWriteContext): DEFAULT_RAW_BSON_OPTIONS) return cmd, to_send - def execute(self, docs, client): - cmd, to_send = self._batch_command(docs) + def execute(self, cmd, docs, client): + batched_cmd, to_send = self._batch_command(cmd, docs) result = self.sock_info.command( - self.db_name, cmd, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, + self.db_name, batched_cmd, + codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, session=self.session, client=client) return result, to_send - def execute_unack(self, docs, client): - cmd, to_send = self._batch_command(docs) + def execute_unack(self, cmd, docs, client): + batched_cmd, to_send = self._batch_command(cmd, docs) self.sock_info.command( - self.db_name, cmd, write_concern=WriteConcern(w=0), + self.db_name, batched_cmd, write_concern=WriteConcern(w=0), session=self.session, client=client) return to_send diff --git a/test/test_transactions.py b/test/test_transactions.py index 33a9186e8..de3066303 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -286,6 +286,36 @@ class TestTransactions(TransactionsBase): ): op(*args, session=s) + # Require 4.2+ for large (16MB+) transactions. + @client_context.require_version_min(4, 2) + @client_context.require_transactions + def test_transaction_starts_with_batched_write(self): + # Start a transaction with a batch of operations that needs to be + # split. + listener = OvertCommandListener() + client = rs_client(event_listeners=[listener]) + coll = client[self.db.name].test + coll.delete_many({}) + listener.reset() + self.addCleanup(client.close) + self.addCleanup(coll.drop) + ops = [InsertOne({'a': '1'*(10*1024*1024)}) for _ in range(10)] + with client.start_session() as session: + with session.start_transaction(): + coll.bulk_write(ops, session=session) + # Assert commands were constructed properly. + self.assertEqual(['insert', 'insert', 'insert', 'commitTransaction'], + listener.started_command_names()) + first_cmd = listener.results['started'][0].command + self.assertTrue(first_cmd['startTransaction']) + lsid = first_cmd['lsid'] + txn_number = first_cmd['txnNumber'] + for event in listener.results['started'][1:]: + self.assertNotIn('startTransaction', event.command) + self.assertEqual(lsid, event.command['lsid']) + self.assertEqual(txn_number, event.command['txnNumber']) + self.assertEqual(10, coll.count_documents({})) + class PatchSessionTimeout(object): """Patches the client_session's with_transaction timeout for testing."""