PYTHON-4668 Improve performance of client.bulk_write (#1800)

This commit is contained in:
Shruti Sridhar 2024-08-20 17:18:28 -07:00 committed by GitHub
parent c03721c8f5
commit 4024a1b85d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 103 additions and 63 deletions

View File

@ -145,7 +145,7 @@ For example, a duplicate key error on the third operation below aborts the remai
'idx': 2,
'code': 11000,
'errmsg': 'E11000 duplicate key error ... dup key: { _id: 3 }', ...
'op': {'insert': 'db.test_three', 'document': {'_id': 3}}}]
'op': {'insert': 0, 'document': {'_id': 3}}}]
>>> exception.partial_result.inserted_count
2
>>> exception.partial_result.deleted_count
@ -181,7 +181,7 @@ For example, the fourth and fifth write operations below get executed successful
'idx': 2,
'code': 11000,
'errmsg': 'E11000 duplicate key error ... dup key: { _id: 5 }', ...
'op': {'insert': 'db.test_five', 'document': {'_id': 5}}}]
'op': {'insert': 0, 'document': {'_id': 5}}}]
>>> exception.partial_result.inserted_count
3
>>> exception.partial_result.deleted_count

View File

@ -108,6 +108,7 @@ class _AsyncClientBulk:
self.verbose_results = verbose_results
self.ops: list[tuple[str, Mapping[str, Any]]] = []
self.namespaces: list[str] = []
self.idx_offset: int = 0
self.total_ops: int = 0
@ -132,8 +133,9 @@ class _AsyncClientBulk:
# Generate ObjectId client side.
if not (isinstance(document, RawBSONDocument) or "_id" in document):
document["_id"] = ObjectId()
cmd = {"insert": namespace, "document": document}
cmd = {"insert": -1, "document": document}
self.ops.append(("insert", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
def add_update(
@ -150,7 +152,7 @@ class _AsyncClientBulk:
"""Create an update document and add it to the list of ops."""
validate_ok_for_update(update)
cmd = {
"update": namespace,
"update": -1,
"filter": selector,
"updateMods": update,
"multi": multi,
@ -171,6 +173,7 @@ class _AsyncClientBulk:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
self.ops.append(("update", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
def add_replace(
@ -185,7 +188,7 @@ class _AsyncClientBulk:
"""Create a replace document and add it to the list of ops."""
validate_ok_for_replace(replacement)
cmd = {
"update": namespace,
"update": -1,
"filter": selector,
"updateMods": replacement,
"multi": False,
@ -200,6 +203,7 @@ class _AsyncClientBulk:
self.uses_collation = True
cmd["collation"] = collation
self.ops.append(("replace", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
def add_delete(
@ -211,7 +215,7 @@ class _AsyncClientBulk:
hint: Union[str, dict[str, Any], None] = None,
) -> None:
"""Create a delete document and add it to the list of ops."""
cmd = {"delete": namespace, "filter": selector, "multi": multi}
cmd = {"delete": -1, "filter": selector, "multi": multi}
if hint is not None:
self.uses_hint_delete = True
cmd["hint"] = hint
@ -222,6 +226,7 @@ class _AsyncClientBulk:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
self.ops.append(("delete", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
@_handle_reauth
@ -407,9 +412,10 @@ class _AsyncClientBulk:
bwc: _ClientBulkWriteContext,
cmd: dict[str, Any],
ops: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
) -> tuple[list[Mapping[str, Any]], list[Mapping[str, Any]]]:
"""Executes a batch of bulkWrite server commands (unack)."""
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops)
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
await self.unack_write(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client) # type: ignore[arg-type]
return to_send_ops, to_send_ns
@ -418,9 +424,10 @@ class _AsyncClientBulk:
bwc: _ClientBulkWriteContext,
cmd: dict[str, Any],
ops: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
) -> tuple[dict[str, Any], list[Mapping[str, Any]], list[Mapping[str, Any]]]:
"""Executes a batch of bulkWrite server commands (ack)."""
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops)
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
result = await self.write_command(
bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client
) # type: ignore[arg-type]
@ -540,11 +547,12 @@ class _AsyncClientBulk:
# CSOT: apply timeout before encoding the command.
conn.apply_timeout(self.client, cmd)
ops = islice(self.ops, self.idx_offset, None)
namespaces = islice(self.namespaces, self.idx_offset, None)
# Run as many ops as possible in one server command.
if write_concern.acknowledged:
raw_result, to_send_ops, _ = await self._execute_batch(bwc, cmd, ops) # type: ignore[arg-type]
result = copy.deepcopy(raw_result)
raw_result, to_send_ops, _ = await self._execute_batch(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
result = raw_result
# Top-level server/network error.
if result.get("error"):
@ -600,7 +608,7 @@ class _AsyncClientBulk:
self.started_retryable_write = False
else:
to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops) # type: ignore[arg-type]
to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
self.idx_offset += len(to_send_ops)
@ -697,9 +705,10 @@ class _AsyncClientBulk:
conn.add_server_api(cmd)
ops = islice(self.ops, self.idx_offset, None)
namespaces = islice(self.namespaces, self.idx_offset, None)
# Run as many ops as possible in one server command.
to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops) # type: ignore[arg-type]
to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
self.idx_offset += len(to_send_ops)

View File

@ -21,7 +21,6 @@ MongoDB.
"""
from __future__ import annotations
import copy
import datetime
import random
import struct
@ -950,10 +949,13 @@ class _ClientBulkWriteContext(_BulkWriteContextBase):
)
def batch_command(
self, cmd: MutableMapping[str, Any], operations: list[tuple[str, Mapping[str, Any]]]
self,
cmd: MutableMapping[str, Any],
operations: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
) -> tuple[int, Union[bytes, dict[str, Any]], list[Mapping[str, Any]], list[Mapping[str, Any]]]:
request_id, msg, to_send_ops, to_send_ns = _client_do_batched_op_msg(
cmd, operations, self.codec, self
cmd, operations, namespaces, self.codec, self
)
if not to_send_ops:
raise InvalidOperation("cannot do an empty bulk write")
@ -1035,6 +1037,7 @@ def _client_construct_op_msg(
def _client_batched_op_msg_impl(
command: Mapping[str, Any],
operations: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
ack: bool,
opts: CodecOptions,
ctx: _ClientBulkWriteContext,
@ -1076,14 +1079,14 @@ def _client_batched_op_msg_impl(
ns_info = {}
to_send_ops: list[Mapping[str, Any]] = []
to_send_ns: list[Mapping[str, int]] = []
to_send_ns: list[Mapping[str, str]] = []
to_send_ops_encoded: list[bytes] = []
to_send_ns_encoded: list[bytes] = []
total_ops_length = 0
total_ns_length = 0
idx = 0
for real_op_type, op_doc in operations:
for (real_op_type, op_doc), namespace in zip(operations, namespaces):
op_type = real_op_type
# Check insert/replace document size if unacknowledged.
if real_op_type == "insert":
@ -1096,24 +1099,23 @@ def _client_batched_op_msg_impl(
doc_size = len(_dict_to_bson(op_doc["updateMods"], False, opts))
_check_doc_size_limits(real_op_type, doc_size, max_bson_size)
ns_doc_to_send = None
ns_doc = None
ns_length = 0
namespace = op_doc[op_type]
if namespace not in ns_info:
ns_doc_to_send = {"ns": namespace}
ns_doc = {"ns": namespace}
new_ns_index = len(to_send_ns)
ns_info[namespace] = new_ns_index
# First entry in the operation doc has the operation type as its
# key and the index of its namespace within ns_info as its value.
op_doc_to_send = copy.deepcopy(op_doc)
op_doc_to_send[op_type] = ns_info[namespace] # type: ignore[index]
op_doc[op_type] = ns_info[namespace] # type: ignore[index]
# Encode current operation doc and, if newly added, namespace doc.
op_doc_encoded = _dict_to_bson(op_doc_to_send, False, opts)
op_doc_encoded = _dict_to_bson(op_doc, False, opts)
op_length = len(op_doc_encoded)
if ns_doc_to_send:
ns_doc_encoded = _dict_to_bson(ns_doc_to_send, False, opts)
if ns_doc:
ns_doc_encoded = _dict_to_bson(ns_doc, False, opts)
ns_length = len(ns_doc_encoded)
# Check operation document size if unacknowledged.
@ -1128,11 +1130,11 @@ def _client_batched_op_msg_impl(
break
# Add op and ns documents to this batch.
to_send_ops.append(op_doc_to_send)
to_send_ops.append(op_doc)
to_send_ops_encoded.append(op_doc_encoded)
total_ops_length += op_length
if ns_doc_to_send:
to_send_ns.append(ns_doc_to_send)
if ns_doc:
to_send_ns.append(ns_doc)
to_send_ns_encoded.append(ns_doc_encoded)
total_ns_length += ns_length
@ -1153,6 +1155,7 @@ def _client_batched_op_msg_impl(
def _client_encode_batched_op_msg(
command: Mapping[str, Any],
operations: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
ack: bool,
opts: CodecOptions,
ctx: _ClientBulkWriteContext,
@ -1163,7 +1166,7 @@ def _client_encode_batched_op_msg(
buf = _BytesIO()
to_send_ops, to_send_ns, _ = _client_batched_op_msg_impl(
command, operations, ack, opts, ctx, buf
command, operations, namespaces, ack, opts, ctx, buf
)
return buf.getvalue(), to_send_ops, to_send_ns
@ -1171,6 +1174,7 @@ def _client_encode_batched_op_msg(
def _client_batched_op_msg_compressed(
command: Mapping[str, Any],
operations: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
ack: bool,
opts: CodecOptions,
ctx: _ClientBulkWriteContext,
@ -1179,7 +1183,7 @@ def _client_batched_op_msg_compressed(
with OP_MSG, compressed.
"""
data, to_send_ops, to_send_ns = _client_encode_batched_op_msg(
command, operations, ack, opts, ctx
command, operations, namespaces, ack, opts, ctx
)
assert ctx.conn.compression_context is not None
@ -1190,6 +1194,7 @@ def _client_batched_op_msg_compressed(
def _client_batched_op_msg(
command: Mapping[str, Any],
operations: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
ack: bool,
opts: CodecOptions,
ctx: _ClientBulkWriteContext,
@ -1203,7 +1208,7 @@ def _client_batched_op_msg(
buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00")
to_send_ops, to_send_ns, length = _client_batched_op_msg_impl(
command, operations, ack, opts, ctx, buf
command, operations, namespaces, ack, opts, ctx, buf
)
# Header - request id and message length
@ -1219,6 +1224,7 @@ def _client_batched_op_msg(
def _client_do_batched_op_msg(
command: MutableMapping[str, Any],
operations: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
opts: CodecOptions,
ctx: _ClientBulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
@ -1231,8 +1237,8 @@ def _client_do_batched_op_msg(
else:
ack = True
if ctx.conn.compression_context:
return _client_batched_op_msg_compressed(command, operations, ack, opts, ctx)
return _client_batched_op_msg(command, operations, ack, opts, ctx)
return _client_batched_op_msg_compressed(command, operations, namespaces, ack, opts, ctx)
return _client_batched_op_msg(command, operations, namespaces, ack, opts, ctx)
# End OP_MSG -----------------------------------------------------

View File

@ -108,6 +108,7 @@ class _ClientBulk:
self.verbose_results = verbose_results
self.ops: list[tuple[str, Mapping[str, Any]]] = []
self.namespaces: list[str] = []
self.idx_offset: int = 0
self.total_ops: int = 0
@ -132,8 +133,9 @@ class _ClientBulk:
# Generate ObjectId client side.
if not (isinstance(document, RawBSONDocument) or "_id" in document):
document["_id"] = ObjectId()
cmd = {"insert": namespace, "document": document}
cmd = {"insert": -1, "document": document}
self.ops.append(("insert", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
def add_update(
@ -150,7 +152,7 @@ class _ClientBulk:
"""Create an update document and add it to the list of ops."""
validate_ok_for_update(update)
cmd = {
"update": namespace,
"update": -1,
"filter": selector,
"updateMods": update,
"multi": multi,
@ -171,6 +173,7 @@ class _ClientBulk:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
self.ops.append(("update", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
def add_replace(
@ -185,7 +188,7 @@ class _ClientBulk:
"""Create a replace document and add it to the list of ops."""
validate_ok_for_replace(replacement)
cmd = {
"update": namespace,
"update": -1,
"filter": selector,
"updateMods": replacement,
"multi": False,
@ -200,6 +203,7 @@ class _ClientBulk:
self.uses_collation = True
cmd["collation"] = collation
self.ops.append(("replace", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
def add_delete(
@ -211,7 +215,7 @@ class _ClientBulk:
hint: Union[str, dict[str, Any], None] = None,
) -> None:
"""Create a delete document and add it to the list of ops."""
cmd = {"delete": namespace, "filter": selector, "multi": multi}
cmd = {"delete": -1, "filter": selector, "multi": multi}
if hint is not None:
self.uses_hint_delete = True
cmd["hint"] = hint
@ -222,6 +226,7 @@ class _ClientBulk:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
self.ops.append(("delete", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
@_handle_reauth
@ -407,9 +412,10 @@ class _ClientBulk:
bwc: _ClientBulkWriteContext,
cmd: dict[str, Any],
ops: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
) -> tuple[list[Mapping[str, Any]], list[Mapping[str, Any]]]:
"""Executes a batch of bulkWrite server commands (unack)."""
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops)
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
self.unack_write(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client) # type: ignore[arg-type]
return to_send_ops, to_send_ns
@ -418,9 +424,10 @@ class _ClientBulk:
bwc: _ClientBulkWriteContext,
cmd: dict[str, Any],
ops: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
) -> tuple[dict[str, Any], list[Mapping[str, Any]], list[Mapping[str, Any]]]:
"""Executes a batch of bulkWrite server commands (ack)."""
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops)
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
result = self.write_command(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client) # type: ignore[arg-type]
self.client._process_response(result, bwc.session) # type: ignore[arg-type]
return result, to_send_ops, to_send_ns # type: ignore[return-value]
@ -538,11 +545,12 @@ class _ClientBulk:
# CSOT: apply timeout before encoding the command.
conn.apply_timeout(self.client, cmd)
ops = islice(self.ops, self.idx_offset, None)
namespaces = islice(self.namespaces, self.idx_offset, None)
# Run as many ops as possible in one server command.
if write_concern.acknowledged:
raw_result, to_send_ops, _ = self._execute_batch(bwc, cmd, ops) # type: ignore[arg-type]
result = copy.deepcopy(raw_result)
raw_result, to_send_ops, _ = self._execute_batch(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
result = raw_result
# Top-level server/network error.
if result.get("error"):
@ -598,7 +606,7 @@ class _ClientBulk:
self.started_retryable_write = False
else:
to_send_ops, _ = self._execute_batch_unack(bwc, cmd, ops) # type: ignore[arg-type]
to_send_ops, _ = self._execute_batch_unack(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
self.idx_offset += len(to_send_ops)
@ -695,9 +703,10 @@ class _ClientBulk:
conn.add_server_api(cmd)
ops = islice(self.ops, self.idx_offset, None)
namespaces = islice(self.namespaces, self.idx_offset, None)
# Run as many ops as possible in one server command.
to_send_ops, _ = self._execute_batch_unack(bwc, cmd, ops) # type: ignore[arg-type]
to_send_ops, _ = self._execute_batch_unack(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
self.idx_offset += len(to_send_ops)

View File

@ -48,12 +48,9 @@ _IS_SYNC = False
class TestClientBulkWrite(AsyncIntegrationTest):
@async_client_context.require_version_min(8, 0, 0, -24)
async def test_returns_error_if_no_namespace_provided(self):
client = await async_rs_or_single_client()
self.addAsyncCleanup(client.close)
models = [InsertOne(document={"a": "b"})]
with self.assertRaises(InvalidOperation) as context:
await client.bulk_write(models=models)
await self.client.bulk_write(models=models)
self.assertIn(
"MongoClient.bulk_write requires a namespace to be provided for each write operation",
context.exception._message,
@ -64,15 +61,26 @@ class TestClientBulkWrite(AsyncIntegrationTest):
with patch.object(
_AsyncClientBulk, "write_command", return_value={"error": TypeError("mock type error")}
):
client = await async_rs_or_single_client()
self.addAsyncCleanup(client.close)
models = [InsertOne(namespace="db.coll", document={"a": "b"})]
with self.assertRaises(ClientBulkWriteException) as context:
await client.bulk_write(models=models)
await self.client.bulk_write(models=models)
self.assertIsInstance(context.exception.error, TypeError)
self.assertFalse(hasattr(context.exception.error, "details"))
@async_client_context.require_version_min(8, 0, 0, -24)
async def test_formats_write_error_correctly(self):
models = [
InsertOne(namespace="db.coll", document={"_id": 1}),
InsertOne(namespace="db.coll", document={"_id": 1}),
]
with self.assertRaises(ClientBulkWriteException) as context:
await self.client.bulk_write(models=models)
write_error = context.exception.write_errors[0]
self.assertEqual(write_error["idx"], 1)
self.assertEqual(write_error["op"], {"insert": 0, "document": {"_id": 1}})
# https://github.com/mongodb/specifications/tree/master/source/crud/tests
class TestClientBulkWriteCRUD(AsyncIntegrationTest):

View File

@ -449,7 +449,7 @@ class TestSmallDocClientBulkInsert(SmallDocInsertTest, unittest.TestCase):
super().setUp()
self.models = []
for doc in self.documents:
self.models.append(InsertOne(namespace="perftest.corpus", document=doc.copy()))
self.models.append(InsertOne(namespace="perftest.corpus", document=doc))
@client_context.require_version_min(8, 0, 0, -24)
def do_task(self):
@ -461,7 +461,7 @@ class TestSmallDocBulkMixedOps(SmallDocMixedTest, unittest.TestCase):
super().setUp()
self.models: list[Union[InsertOne, ReplaceOne, DeleteOne]] = []
for doc in self.documents:
self.models.append(InsertOne(document=doc.copy()))
self.models.append(InsertOne(document=doc))
self.models.append(ReplaceOne(filter={}, replacement=doc.copy(), upsert=True))
self.models.append(DeleteOne(filter={}))
@ -475,7 +475,7 @@ class TestSmallDocClientBulkMixedOps(SmallDocMixedTest, unittest.TestCase):
super().setUp()
self.models: list[Union[InsertOne, ReplaceOne, DeleteOne]] = []
for doc in self.documents:
self.models.append(InsertOne(namespace="perftest.corpus", document=doc.copy()))
self.models.append(InsertOne(namespace="perftest.corpus", document=doc))
self.models.append(
ReplaceOne(
namespace="perftest.corpus", filter={}, replacement=doc.copy(), upsert=True
@ -499,7 +499,7 @@ class TestLargeDocClientBulkInsert(LargeDocInsertTest, unittest.TestCase):
super().setUp()
self.models = []
for doc in self.documents:
self.models.append(InsertOne(namespace="perftest.corpus", document=doc.copy()))
self.models.append(InsertOne(namespace="perftest.corpus", document=doc))
@client_context.require_version_min(8, 0, 0, -24)
def do_task(self):

View File

@ -48,12 +48,9 @@ _IS_SYNC = True
class TestClientBulkWrite(IntegrationTest):
@client_context.require_version_min(8, 0, 0, -24)
def test_returns_error_if_no_namespace_provided(self):
client = rs_or_single_client()
self.addCleanup(client.close)
models = [InsertOne(document={"a": "b"})]
with self.assertRaises(InvalidOperation) as context:
client.bulk_write(models=models)
self.client.bulk_write(models=models)
self.assertIn(
"MongoClient.bulk_write requires a namespace to be provided for each write operation",
context.exception._message,
@ -64,15 +61,26 @@ class TestClientBulkWrite(IntegrationTest):
with patch.object(
_ClientBulk, "write_command", return_value={"error": TypeError("mock type error")}
):
client = rs_or_single_client()
self.addCleanup(client.close)
models = [InsertOne(namespace="db.coll", document={"a": "b"})]
with self.assertRaises(ClientBulkWriteException) as context:
client.bulk_write(models=models)
self.client.bulk_write(models=models)
self.assertIsInstance(context.exception.error, TypeError)
self.assertFalse(hasattr(context.exception.error, "details"))
@client_context.require_version_min(8, 0, 0, -24)
def test_formats_write_error_correctly(self):
models = [
InsertOne(namespace="db.coll", document={"_id": 1}),
InsertOne(namespace="db.coll", document={"_id": 1}),
]
with self.assertRaises(ClientBulkWriteException) as context:
self.client.bulk_write(models=models)
write_error = context.exception.write_errors[0]
self.assertEqual(write_error["idx"], 1)
self.assertEqual(write_error["op"], {"insert": 0, "document": {"_id": 1}})
# https://github.com/mongodb/specifications/tree/master/source/crud/tests
class TestClientBulkWriteCRUD(IntegrationTest):