PYTHON-4667 Handle $clusterTime from error responses in client Bulk Write (#1822)

This commit is contained in:
Steven Silvester 2024-09-04 19:40:37 -05:00 committed by GitHub
parent e27b428914
commit 4d4813070d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 50 additions and 7 deletions

View File

@ -281,6 +281,7 @@ class _AsyncBulk:
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
await client._process_response(reply, bwc.session) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
@ -308,6 +309,9 @@ class _AsyncBulk:
if bwc.publish:
bwc._fail(request_id, failure, duration)
# Process the response from the server.
if isinstance(exc, (NotPrimaryError, OperationFailure)):
await client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
raise
finally:
bwc.start_time = datetime.datetime.now()
@ -449,7 +453,6 @@ class _AsyncBulk:
else:
request_id, msg, to_send = bwc.batch_command(cmd, ops)
result = await self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type]
await client._process_response(result, bwc.session) # type: ignore[arg-type]
return result, to_send # type: ignore[return-value]

View File

@ -283,6 +283,8 @@ class _AsyncClientBulk:
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
# Process the response from the server.
await self.client._process_response(reply, bwc.session) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
@ -312,6 +314,11 @@ class _AsyncClientBulk:
bwc._fail(request_id, failure, duration)
# Top-level error will be embedded in ClientBulkWriteException.
reply = {"error": exc}
# Process the response from the server.
if isinstance(exc, OperationFailure):
await self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
else:
await self.client._process_response({}, bwc.session) # type: ignore[arg-type]
finally:
bwc.start_time = datetime.datetime.now()
return reply # type: ignore[return-value]
@ -431,7 +438,6 @@ class _AsyncClientBulk:
result = await self.write_command(
bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client
) # type: ignore[arg-type]
await self.client._process_response(result, bwc.session) # type: ignore[arg-type]
return result, to_send_ops, to_send_ns # type: ignore[return-value]
async def _process_results_cursor(

View File

@ -281,6 +281,7 @@ class _Bulk:
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
client._process_response(reply, bwc.session) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
@ -308,6 +309,9 @@ class _Bulk:
if bwc.publish:
bwc._fail(request_id, failure, duration)
# Process the response from the server.
if isinstance(exc, (NotPrimaryError, OperationFailure)):
client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
raise
finally:
bwc.start_time = datetime.datetime.now()
@ -449,7 +453,6 @@ class _Bulk:
else:
request_id, msg, to_send = bwc.batch_command(cmd, ops)
result = self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type]
client._process_response(result, bwc.session) # type: ignore[arg-type]
return result, to_send # type: ignore[return-value]

View File

@ -283,6 +283,8 @@ class _ClientBulk:
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
# Process the response from the server.
self.client._process_response(reply, bwc.session) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
@ -312,6 +314,11 @@ class _ClientBulk:
bwc._fail(request_id, failure, duration)
# Top-level error will be embedded in ClientBulkWriteException.
reply = {"error": exc}
# Process the response from the server.
if isinstance(exc, OperationFailure):
self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
else:
self.client._process_response({}, bwc.session) # type: ignore[arg-type]
finally:
bwc.start_time = datetime.datetime.now()
return reply # type: ignore[return-value]
@ -429,7 +436,6 @@ class _ClientBulk:
"""Executes a batch of bulkWrite server commands (ack)."""
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
result = self.write_command(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client) # type: ignore[arg-type]
self.client._process_response(result, bwc.session) # type: ignore[arg-type]
return result, to_send_ops, to_send_ns # type: ignore[return-value]
def _process_results_cursor(

View File

@ -29,21 +29,22 @@ except ImportError:
from bson import Timestamp
from pymongo import DeleteMany, InsertOne, MongoClient, UpdateOne
from pymongo.errors import OperationFailure
pytestmark = pytest.mark.mockupdb
class TestClusterTime(unittest.TestCase):
def cluster_time_conversation(self, callback, replies):
def cluster_time_conversation(self, callback, replies, max_wire_version=6):
cluster_time = Timestamp(0, 0)
server = MockupDB()
# First test all commands include $clusterTime with wire version 6.
# First test all commands include $clusterTime with max_wire_version.
_ = server.autoresponds(
"ismaster",
{
"minWireVersion": 0,
"maxWireVersion": 6,
"maxWireVersion": max_wire_version,
"$clusterTime": {"clusterTime": cluster_time},
},
)
@ -166,6 +167,30 @@ class TestClusterTime(unittest.TestCase):
request.reply(reply)
client.close()
def test_collection_bulk_error(self):
def callback(client: MongoClient[dict]) -> None:
with self.assertRaises(OperationFailure):
client.db.collection.bulk_write([InsertOne({}), InsertOne({})])
self.cluster_time_conversation(
callback,
[{"ok": 0, "errmsg": "mock error"}],
)
def test_client_bulk_error(self):
def callback(client: MongoClient[dict]) -> None:
with self.assertRaises(OperationFailure):
client.bulk_write(
[
InsertOne({}, namespace="db.collection"),
InsertOne({}, namespace="db.collection"),
]
)
self.cluster_time_conversation(
callback, [{"ok": 0, "errmsg": "mock error"}], max_wire_version=25
)
if __name__ == "__main__":
unittest.main()