PYTHON-2915 Fix bug when starting a transaction with a large bulk write (#743)

This commit is contained in:
Shane Harvey 2021-09-24 15:47:37 -07:00 committed by GitHub
parent a80169d1fa
commit 7467aa634d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 40 deletions

View File

@ -267,17 +267,18 @@ class _Bulk(object):
# sock_info.write_command.
sock_info.validate_session(client, session)
while run:
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
('ordered', self.ordered)])
if not write_concern.is_server_default:
cmd['writeConcern'] = write_concern.document
if self.bypass_doc_val:
cmd['bypassDocumentValidation'] = True
cmd_name = _COMMANDS[run.op_type]
bwc = self.bulk_ctx_class(
db_name, cmd, sock_info, op_id, listeners, session,
db_name, cmd_name, sock_info, op_id, listeners, session,
run.op_type, self.collection.codec_options)
while run.idx_offset < len(run.ops):
cmd = SON([(cmd_name, self.collection.name),
('ordered', self.ordered)])
if not write_concern.is_server_default:
cmd['writeConcern'] = write_concern.document
if self.bypass_doc_val:
cmd['bypassDocumentValidation'] = True
if session:
# Start a new retryable write unless one was already
# started for this command.
@ -287,9 +288,10 @@ class _Bulk(object):
session._apply_to(cmd, retryable, ReadPreference.PRIMARY,
sock_info)
sock_info.send_cluster_time(cmd, session, client)
sock_info.add_server_api(cmd)
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible in one command.
result, to_send = bwc.execute(ops, client)
result, to_send = bwc.execute(cmd, ops, client)
# Retryable writeConcernErrors halt the execution of this run.
wce = result.get('writeConcernError', {})
@ -359,17 +361,19 @@ class _Bulk(object):
run = self.current_run
while run:
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
('ordered', False),
('writeConcern', {'w': 0})])
cmd_name = _COMMANDS[run.op_type]
bwc = self.bulk_ctx_class(
db_name, cmd, sock_info, op_id, listeners, None,
db_name, cmd_name, sock_info, op_id, listeners, None,
run.op_type, self.collection.codec_options)
while run.idx_offset < len(run.ops):
cmd = SON([(cmd_name, self.collection.name),
('ordered', False),
('writeConcern', {'w': 0})])
sock_info.add_server_api(cmd)
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible.
to_send = bwc.execute_unack(ops, client)
to_send = bwc.execute_unack(cmd, ops, client)
run.idx_offset += len(to_send)
self.current_run = run = next(generator, None)

View File

@ -704,50 +704,48 @@ def _get_more(collection_name, num_to_return, cursor_id, ctx=None):
class _BulkWriteContext(object):
"""A wrapper around SocketInfo for use with write splitting functions."""
__slots__ = ('db_name', 'command', 'sock_info', 'op_id',
__slots__ = ('db_name', 'sock_info', 'op_id',
'name', 'field', 'publish', 'start_time', 'listeners',
'session', 'compress', 'op_type', 'codec')
def __init__(self, database_name, cmd_name, sock_info, operation_id,
             listeners, session, op_type, codec):
    """Set up the per-run state shared by every batch of a bulk write.

    The command document is deliberately not stored on the context:
    callers build a fresh one for each batch and hand it to
    ``execute`` / ``execute_unack``.

    :Parameters:
      - `database_name`: name of the database the command runs against.
      - `cmd_name`: name of the write command; also the key used to look
        up the documents field in ``_FIELD_MAP``.
      - `sock_info`: the connection wrapper the batches are sent on.
      - `operation_id`: id published with command monitoring events.
      - `listeners`: event listeners; ``enabled_for_commands`` decides
        whether command events are published.
      - `session`: the session in use, or None.
      - `op_type`: the operation type passed to the batch encoder.
      - `codec`: codec options passed to the batch encoder.
    """
    self.db_name = database_name
    self.sock_info = sock_info
    self.op_id = operation_id
    self.listeners = listeners
    self.publish = listeners.enabled_for_commands
    self.name = cmd_name
    self.field = _FIELD_MAP[self.name]
    # A start time is only needed when command events are published.
    self.start_time = datetime.datetime.now() if self.publish else None
    self.session = session
    self.compress = bool(sock_info.compression_context)
    self.op_type = op_type
    self.codec = codec
def _batch_command(self, cmd, docs):
    """Encode one batch of `docs` for the command document `cmd`.

    Returns the ``(request_id, msg, to_send)`` triple produced by
    ``_do_batched_op_msg``. Callers advance their offset by
    ``len(to_send)``, so `to_send` is presumably the leading documents
    that were included in this batch.

    Raises InvalidOperation when `to_send` is empty.
    """
    namespace = self.db_name + '.$cmd'
    request_id, msg, to_send = _do_batched_op_msg(
        namespace, self.op_type, cmd, docs, self.check_keys,
        self.codec, self)
    if not to_send:
        raise InvalidOperation("cannot do an empty bulk write")
    return request_id, msg, to_send
def execute(self, cmd, docs, client):
    """Send one acknowledged batch and return ``(result, to_send)``.

    `cmd` is the command document for this batch, `docs` the remaining
    operations, and `client` the object whose ``_process_response`` is
    given the server reply together with this context's session.
    `to_send` is the list of documents reported by ``_batch_command``.
    """
    request_id, msg, to_send = self._batch_command(cmd, docs)
    result = self.write_command(cmd, request_id, msg, to_send)
    client._process_response(result, self.session)
    return result, to_send
def execute_unack(self, cmd, docs, client):
    """Send one unacknowledged batch and return the documents sent.

    `cmd` is the command document for this batch and `docs` the
    remaining operations. `client` is accepted for signature parity
    with ``execute`` and is not used here.
    """
    request_id, msg, to_send = self._batch_command(cmd, docs)
    # Though this isn't strictly a "legacy" write, the helper
    # handles publishing commands and sending our message
    # without receiving a result. Send 0 for max_doc_size
    # to disable size checking. Size checking is handled while
    # the documents are encoded to BSON.
    self.unack_write(cmd, request_id, msg, 0, to_send)
    return to_send
@property
@ -778,12 +776,12 @@ class _BulkWriteContext(object):
"""The maximum size of a BSON command before batch splitting."""
return self.max_bson_size
def unack_write(self, request_id, msg, max_doc_size, docs):
def unack_write(self, cmd, request_id, msg, max_doc_size, docs):
"""A proxy for SocketInfo.unack_write that handles event publishing.
"""
if self.publish:
duration = datetime.datetime.now() - self.start_time
cmd = self._start(request_id, docs)
cmd = self._start(cmd, request_id, docs)
start = datetime.datetime.now()
try:
result = self.sock_info.unack_write(msg, max_doc_size)
@ -811,12 +809,12 @@ class _BulkWriteContext(object):
self.start_time = datetime.datetime.now()
return result
def write_command(self, request_id, msg, docs):
def write_command(self, cmd, request_id, msg, docs):
"""A proxy for SocketInfo.write_command that handles event publishing.
"""
if self.publish:
duration = datetime.datetime.now() - self.start_time
self._start(request_id, docs)
self._start(cmd, request_id, docs)
start = datetime.datetime.now()
try:
reply = self.sock_info.write_command(request_id, msg)
@ -836,9 +834,8 @@ class _BulkWriteContext(object):
self.start_time = datetime.datetime.now()
return reply
def _start(self, request_id, docs):
def _start(self, cmd, request_id, docs):
"""Publish a CommandStartedEvent."""
cmd = self.command.copy()
cmd[self.field] = docs
self.listeners.publish_command_start(
cmd, self.db_name,
@ -871,10 +868,10 @@ _MAX_SPLIT_SIZE_ENC = 2097152
class _EncryptedBulkWriteContext(_BulkWriteContext):
__slots__ = ()
def _batch_command(self, docs):
def _batch_command(self, cmd, docs):
namespace = self.db_name + '.$cmd'
msg, to_send = _encode_batched_write_command(
namespace, self.op_type, self.command, docs, self.check_keys,
namespace, self.op_type, cmd, docs, self.check_keys,
self.codec, self)
if not to_send:
raise InvalidOperation("cannot do an empty bulk write")
@ -885,17 +882,18 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
DEFAULT_RAW_BSON_OPTIONS)
return cmd, to_send
def execute(self, cmd, docs, client):
    """Run one batch through ``sock_info.command``.

    `cmd` is the command document for this batch; ``_batch_command``
    turns it and `docs` into `batched_cmd`, which is what is actually
    sent. Returns ``(result, to_send)``.
    """
    # Keep the caller's `cmd` and the batched command under separate
    # names so the document that is sent is unambiguous.
    batched_cmd, to_send = self._batch_command(cmd, docs)
    result = self.sock_info.command(
        self.db_name, batched_cmd,
        codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
        session=self.session, client=client)
    return result, to_send
def execute_unack(self, cmd, docs, client):
    """Run one batch through ``sock_info.command`` with ``w=0``.

    `cmd` is the command document for this batch; ``_batch_command``
    turns it and `docs` into `batched_cmd`, which is what is actually
    sent. Returns the documents reported as sent; the command result
    is discarded.
    """
    batched_cmd, to_send = self._batch_command(cmd, docs)
    self.sock_info.command(
        self.db_name, batched_cmd, write_concern=WriteConcern(w=0),
        session=self.session, client=client)
    return to_send

View File

@ -286,6 +286,36 @@ class TestTransactions(TransactionsBase):
):
op(*args, session=s)
# Require 4.2+ for large (16MB+) transactions.
@client_context.require_version_min(4, 2)
@client_context.require_transactions
def test_transaction_starts_with_batched_write(self):
    # Start a transaction with a batch of operations that needs to be
    # split.
    listener = OvertCommandListener()
    client = rs_client(event_listeners=[listener])
    # Register cleanups before issuing any command so the client is
    # closed even if the setup below raises. Cleanups run in reverse
    # order: the collection is dropped first, then the client closed.
    self.addCleanup(client.close)
    coll = client[self.db.name].test
    self.addCleanup(coll.drop)
    coll.delete_many({})
    # Discard the events recorded during setup.
    listener.reset()

    # Build the 10MB payload once and share it; each operation still
    # gets its own document.
    payload = '1' * (10 * 1024 * 1024)
    ops = [InsertOne({'a': payload}) for _ in range(10)]
    with client.start_session() as session:
        with session.start_transaction():
            coll.bulk_write(ops, session=session)

    # Assert commands were constructed properly: the ten inserts are
    # expected to be split across three insert commands.
    self.assertEqual(['insert', 'insert', 'insert', 'commitTransaction'],
                     listener.started_command_names())
    started = listener.results['started']
    first_cmd = started[0].command
    self.assertTrue(first_cmd['startTransaction'])
    lsid = first_cmd['lsid']
    txn_number = first_cmd['txnNumber']
    # Only the first command may start the transaction; the rest must
    # continue it with the same session id and transaction number.
    for event in started[1:]:
        self.assertNotIn('startTransaction', event.command)
        self.assertEqual(lsid, event.command['lsid'])
        self.assertEqual(txn_number, event.command['txnNumber'])
    self.assertEqual(10, coll.count_documents({}))
class PatchSessionTimeout(object):
"""Patches the client_session's with_transaction timeout for testing."""