PYTHON-2915 Fix bug when starting a transaction with a large bulk write (#743)
This commit is contained in:
parent
a80169d1fa
commit
7467aa634d
@ -267,17 +267,18 @@ class _Bulk(object):
|
||||
# sock_info.write_command.
|
||||
sock_info.validate_session(client, session)
|
||||
while run:
|
||||
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
|
||||
('ordered', self.ordered)])
|
||||
if not write_concern.is_server_default:
|
||||
cmd['writeConcern'] = write_concern.document
|
||||
if self.bypass_doc_val:
|
||||
cmd['bypassDocumentValidation'] = True
|
||||
cmd_name = _COMMANDS[run.op_type]
|
||||
bwc = self.bulk_ctx_class(
|
||||
db_name, cmd, sock_info, op_id, listeners, session,
|
||||
db_name, cmd_name, sock_info, op_id, listeners, session,
|
||||
run.op_type, self.collection.codec_options)
|
||||
|
||||
while run.idx_offset < len(run.ops):
|
||||
cmd = SON([(cmd_name, self.collection.name),
|
||||
('ordered', self.ordered)])
|
||||
if not write_concern.is_server_default:
|
||||
cmd['writeConcern'] = write_concern.document
|
||||
if self.bypass_doc_val:
|
||||
cmd['bypassDocumentValidation'] = True
|
||||
if session:
|
||||
# Start a new retryable write unless one was already
|
||||
# started for this command.
|
||||
@ -287,9 +288,10 @@ class _Bulk(object):
|
||||
session._apply_to(cmd, retryable, ReadPreference.PRIMARY,
|
||||
sock_info)
|
||||
sock_info.send_cluster_time(cmd, session, client)
|
||||
sock_info.add_server_api(cmd)
|
||||
ops = islice(run.ops, run.idx_offset, None)
|
||||
# Run as many ops as possible in one command.
|
||||
result, to_send = bwc.execute(ops, client)
|
||||
result, to_send = bwc.execute(cmd, ops, client)
|
||||
|
||||
# Retryable writeConcernErrors halt the execution of this run.
|
||||
wce = result.get('writeConcernError', {})
|
||||
@ -359,17 +361,19 @@ class _Bulk(object):
|
||||
run = self.current_run
|
||||
|
||||
while run:
|
||||
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
|
||||
('ordered', False),
|
||||
('writeConcern', {'w': 0})])
|
||||
cmd_name = _COMMANDS[run.op_type]
|
||||
bwc = self.bulk_ctx_class(
|
||||
db_name, cmd, sock_info, op_id, listeners, None,
|
||||
db_name, cmd_name, sock_info, op_id, listeners, None,
|
||||
run.op_type, self.collection.codec_options)
|
||||
|
||||
while run.idx_offset < len(run.ops):
|
||||
cmd = SON([(cmd_name, self.collection.name),
|
||||
('ordered', False),
|
||||
('writeConcern', {'w': 0})])
|
||||
sock_info.add_server_api(cmd)
|
||||
ops = islice(run.ops, run.idx_offset, None)
|
||||
# Run as many ops as possible.
|
||||
to_send = bwc.execute_unack(ops, client)
|
||||
to_send = bwc.execute_unack(cmd, ops, client)
|
||||
run.idx_offset += len(to_send)
|
||||
self.current_run = run = next(generator, None)
|
||||
|
||||
|
||||
@ -704,50 +704,48 @@ def _get_more(collection_name, num_to_return, cursor_id, ctx=None):
|
||||
class _BulkWriteContext(object):
|
||||
"""A wrapper around SocketInfo for use with write splitting functions."""
|
||||
|
||||
__slots__ = ('db_name', 'command', 'sock_info', 'op_id',
|
||||
__slots__ = ('db_name', 'sock_info', 'op_id',
|
||||
'name', 'field', 'publish', 'start_time', 'listeners',
|
||||
'session', 'compress', 'op_type', 'codec')
|
||||
|
||||
def __init__(self, database_name, command, sock_info, operation_id,
|
||||
def __init__(self, database_name, cmd_name, sock_info, operation_id,
|
||||
listeners, session, op_type, codec):
|
||||
self.db_name = database_name
|
||||
self.command = command
|
||||
self.sock_info = sock_info
|
||||
self.op_id = operation_id
|
||||
self.listeners = listeners
|
||||
self.publish = listeners.enabled_for_commands
|
||||
self.name = next(iter(command))
|
||||
self.name = cmd_name
|
||||
self.field = _FIELD_MAP[self.name]
|
||||
self.start_time = datetime.datetime.now() if self.publish else None
|
||||
self.session = session
|
||||
self.compress = True if sock_info.compression_context else False
|
||||
self.op_type = op_type
|
||||
self.codec = codec
|
||||
sock_info.add_server_api(command)
|
||||
|
||||
def _batch_command(self, docs):
|
||||
def _batch_command(self, cmd, docs):
|
||||
namespace = self.db_name + '.$cmd'
|
||||
request_id, msg, to_send = _do_batched_op_msg(
|
||||
namespace, self.op_type, self.command, docs, self.check_keys,
|
||||
namespace, self.op_type, cmd, docs, self.check_keys,
|
||||
self.codec, self)
|
||||
if not to_send:
|
||||
raise InvalidOperation("cannot do an empty bulk write")
|
||||
return request_id, msg, to_send
|
||||
|
||||
def execute(self, docs, client):
|
||||
request_id, msg, to_send = self._batch_command(docs)
|
||||
result = self.write_command(request_id, msg, to_send)
|
||||
def execute(self, cmd, docs, client):
|
||||
request_id, msg, to_send = self._batch_command(cmd, docs)
|
||||
result = self.write_command(cmd, request_id, msg, to_send)
|
||||
client._process_response(result, self.session)
|
||||
return result, to_send
|
||||
|
||||
def execute_unack(self, docs, client):
|
||||
request_id, msg, to_send = self._batch_command(docs)
|
||||
def execute_unack(self, cmd, docs, client):
|
||||
request_id, msg, to_send = self._batch_command(cmd, docs)
|
||||
# Though this isn't strictly a "legacy" write, the helper
|
||||
# handles publishing commands and sending our message
|
||||
# without receiving a result. Send 0 for max_doc_size
|
||||
# to disable size checking. Size checking is handled while
|
||||
# the documents are encoded to BSON.
|
||||
self.unack_write(request_id, msg, 0, to_send)
|
||||
self.unack_write(cmd, request_id, msg, 0, to_send)
|
||||
return to_send
|
||||
|
||||
@property
|
||||
@ -778,12 +776,12 @@ class _BulkWriteContext(object):
|
||||
"""The maximum size of a BSON command before batch splitting."""
|
||||
return self.max_bson_size
|
||||
|
||||
def unack_write(self, request_id, msg, max_doc_size, docs):
|
||||
def unack_write(self, cmd, request_id, msg, max_doc_size, docs):
|
||||
"""A proxy for SocketInfo.unack_write that handles event publishing.
|
||||
"""
|
||||
if self.publish:
|
||||
duration = datetime.datetime.now() - self.start_time
|
||||
cmd = self._start(request_id, docs)
|
||||
cmd = self._start(cmd, request_id, docs)
|
||||
start = datetime.datetime.now()
|
||||
try:
|
||||
result = self.sock_info.unack_write(msg, max_doc_size)
|
||||
@ -811,12 +809,12 @@ class _BulkWriteContext(object):
|
||||
self.start_time = datetime.datetime.now()
|
||||
return result
|
||||
|
||||
def write_command(self, request_id, msg, docs):
|
||||
def write_command(self, cmd, request_id, msg, docs):
|
||||
"""A proxy for SocketInfo.write_command that handles event publishing.
|
||||
"""
|
||||
if self.publish:
|
||||
duration = datetime.datetime.now() - self.start_time
|
||||
self._start(request_id, docs)
|
||||
self._start(cmd, request_id, docs)
|
||||
start = datetime.datetime.now()
|
||||
try:
|
||||
reply = self.sock_info.write_command(request_id, msg)
|
||||
@ -836,9 +834,8 @@ class _BulkWriteContext(object):
|
||||
self.start_time = datetime.datetime.now()
|
||||
return reply
|
||||
|
||||
def _start(self, request_id, docs):
|
||||
def _start(self, cmd, request_id, docs):
|
||||
"""Publish a CommandStartedEvent."""
|
||||
cmd = self.command.copy()
|
||||
cmd[self.field] = docs
|
||||
self.listeners.publish_command_start(
|
||||
cmd, self.db_name,
|
||||
@ -871,10 +868,10 @@ _MAX_SPLIT_SIZE_ENC = 2097152
|
||||
class _EncryptedBulkWriteContext(_BulkWriteContext):
|
||||
__slots__ = ()
|
||||
|
||||
def _batch_command(self, docs):
|
||||
def _batch_command(self, cmd, docs):
|
||||
namespace = self.db_name + '.$cmd'
|
||||
msg, to_send = _encode_batched_write_command(
|
||||
namespace, self.op_type, self.command, docs, self.check_keys,
|
||||
namespace, self.op_type, cmd, docs, self.check_keys,
|
||||
self.codec, self)
|
||||
if not to_send:
|
||||
raise InvalidOperation("cannot do an empty bulk write")
|
||||
@ -885,17 +882,18 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
|
||||
DEFAULT_RAW_BSON_OPTIONS)
|
||||
return cmd, to_send
|
||||
|
||||
def execute(self, docs, client):
|
||||
cmd, to_send = self._batch_command(docs)
|
||||
def execute(self, cmd, docs, client):
|
||||
batched_cmd, to_send = self._batch_command(cmd, docs)
|
||||
result = self.sock_info.command(
|
||||
self.db_name, cmd, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
self.db_name, batched_cmd,
|
||||
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
session=self.session, client=client)
|
||||
return result, to_send
|
||||
|
||||
def execute_unack(self, docs, client):
|
||||
cmd, to_send = self._batch_command(docs)
|
||||
def execute_unack(self, cmd, docs, client):
|
||||
batched_cmd, to_send = self._batch_command(cmd, docs)
|
||||
self.sock_info.command(
|
||||
self.db_name, cmd, write_concern=WriteConcern(w=0),
|
||||
self.db_name, batched_cmd, write_concern=WriteConcern(w=0),
|
||||
session=self.session, client=client)
|
||||
return to_send
|
||||
|
||||
|
||||
@ -286,6 +286,36 @@ class TestTransactions(TransactionsBase):
|
||||
):
|
||||
op(*args, session=s)
|
||||
|
||||
# Require 4.2+ for large (16MB+) transactions.
|
||||
@client_context.require_version_min(4, 2)
|
||||
@client_context.require_transactions
|
||||
def test_transaction_starts_with_batched_write(self):
|
||||
# Start a transaction with a batch of operations that needs to be
|
||||
# split.
|
||||
listener = OvertCommandListener()
|
||||
client = rs_client(event_listeners=[listener])
|
||||
coll = client[self.db.name].test
|
||||
coll.delete_many({})
|
||||
listener.reset()
|
||||
self.addCleanup(client.close)
|
||||
self.addCleanup(coll.drop)
|
||||
ops = [InsertOne({'a': '1'*(10*1024*1024)}) for _ in range(10)]
|
||||
with client.start_session() as session:
|
||||
with session.start_transaction():
|
||||
coll.bulk_write(ops, session=session)
|
||||
# Assert commands were constructed properly.
|
||||
self.assertEqual(['insert', 'insert', 'insert', 'commitTransaction'],
|
||||
listener.started_command_names())
|
||||
first_cmd = listener.results['started'][0].command
|
||||
self.assertTrue(first_cmd['startTransaction'])
|
||||
lsid = first_cmd['lsid']
|
||||
txn_number = first_cmd['txnNumber']
|
||||
for event in listener.results['started'][1:]:
|
||||
self.assertNotIn('startTransaction', event.command)
|
||||
self.assertEqual(lsid, event.command['lsid'])
|
||||
self.assertEqual(txn_number, event.command['txnNumber'])
|
||||
self.assertEqual(10, coll.count_documents({}))
|
||||
|
||||
|
||||
class PatchSessionTimeout(object):
|
||||
"""Patches the client_session's with_transaction timeout for testing."""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user