From 47302096f997e168d66595b4afdc50fd95207c3a Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Wed, 25 Oct 2017 15:00:35 -0700 Subject: [PATCH] PYTHON-1339 Retryable multi-statement writes. MongoClient with retryWrites=true works when the cluster does not support retryable writes. --- doc/changelog.rst | 6 +- pymongo/bulk.py | 161 ++++--- pymongo/collection.py | 3 +- pymongo/mongo_client.py | 100 ++-- pymongo/server_description.py | 7 + test/retryable_writes/bulkWrite.json | 654 ++++++++++++++++++++++++++ test/retryable_writes/insertMany.json | 153 ++++++ test/test_crud.py | 14 +- test/test_retryable_writes.py | 268 +++++++---- 9 files changed, 1168 insertions(+), 198 deletions(-) create mode 100644 test/retryable_writes/bulkWrite.json create mode 100644 test/retryable_writes/insertMany.json diff --git a/doc/changelog.rst b/doc/changelog.rst index 1ac4c592d..04622947f 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -31,12 +31,14 @@ Highlights include: :meth:`~pymongo.database.Database.list_collection_names`. - Support for mongodb+srv:// URIs. See :class:`~pymongo.mongo_client.MongoClient` for details. -- Index management helpers ( - :meth:`~pymongo.collection.Collection.create_index`, +- Index management helpers + (:meth:`~pymongo.collection.Collection.create_index`, :meth:`~pymongo.collection.Collection.create_indexes`, :meth:`~pymongo.collection.Collection.drop_index`, :meth:`~pymongo.collection.Collection.drop_indexes`, :meth:`~pymongo.collection.Collection.reindex`) now support maxTimeMS. +- Support for retryable writes and the ``retryWrites`` URI option. See + :class:`~pymongo.mongo_client.MongoClient` for details. Deprecations: diff --git a/pymongo/bulk.py b/pymongo/bulk.py index 5bc468096..804a8d0c2 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -28,8 +28,10 @@ from pymongo.common import (validate_is_mapping, from pymongo.collation import validate_collation_or_none from pymongo.errors import (BulkWriteError, ConfigurationError, + ConnectionFailure, InvalidOperation, - OperationFailure) + OperationFailure, + ServerSelectionTimeoutError) from pymongo.message import (_INSERT, _UPDATE, _DELETE, _do_batched_insert, _do_batched_write_command, @@ -64,6 +66,7 @@ class _Run(object): self.op_type = op_type self.index_map = [] self.ops = [] + self.idx_offset = 0 def index(self, idx): """Get the original index of an operation in this run. @@ -145,6 +148,10 @@ class _Bulk(object): self.bypass_doc_val = bypass_document_validation self.uses_collation = False self.uses_array_filters = False + self.is_retryable = True + self.retrying = False + # Extra state so that we know where to pick up on a retry attempt. + self.current_run = None def add_insert(self, document): """Add an insert document to the list of ops. @@ -169,6 +176,9 @@ class _Bulk(object): if array_filters is not None: self.uses_array_filters = True cmd['arrayFilters'] = array_filters + if multi: + # A bulk_write containing an update_many is not retryable. + self.is_retryable = False self.ops.append((_UPDATE, cmd)) def add_replace(self, selector, replacement, upsert=False, @@ -192,6 +202,9 @@ class _Bulk(object): if collation is not None: self.uses_collation = True cmd['collation'] = collation + if limit == _DELETE_ALL: + # A bulk_write containing a delete_many is not retryable. + self.is_retryable = False self.ops.append((_DELETE, cmd)) def gen_ordered(self): @@ -220,7 +233,70 @@ class _Bulk(object): if run.ops: yield run - def execute_command(self, sock_info, generator, write_concern, session): + def _execute_command(self, generator, write_concern, session, + sock_info, op_id, retryable, full_result): + if sock_info.max_wire_version < 5 and self.uses_collation: + raise ConfigurationError( + 'Must be connected to MongoDB 3.4+ to use a collation.') + if sock_info.max_wire_version < 6 and self.uses_array_filters: + raise ConfigurationError( + 'Must be connected to MongoDB 3.6+ to use arrayFilters.') + + db_name = self.collection.database.name + client = self.collection.database.client + listeners = client._event_listeners + + if not self.current_run: + self.current_run = next(generator) + run = self.current_run + + # sock_info.command validates the session, but we use + # sock_info.write_command. + sock_info.validate_session(client, session) + while run: + cmd = SON([(_COMMANDS[run.op_type], self.collection.name), + ('ordered', self.ordered)]) + if write_concern.document: + cmd['writeConcern'] = write_concern.document + if self.bypass_doc_val and sock_info.max_wire_version >= 4: + cmd['bypassDocumentValidation'] = True + if session: + cmd['lsid'] = session._use_lsid() + bwc = _BulkWriteContext(db_name, cmd, sock_info, op_id, + listeners, session) + + results = [] + while run.idx_offset < len(run.ops): + if session and retryable: + cmd['txnNumber'] = session._transaction_id() + client._send_cluster_time(cmd, session) + check_keys = run.op_type == _INSERT + ops = islice(run.ops, run.idx_offset, None) + # Run as many ops as possible. + request_id, msg, to_send = _do_batched_write_command( + self.namespace, run.op_type, cmd, ops, check_keys, + self.collection.codec_options, bwc) + if not to_send: + raise InvalidOperation("cannot do an empty bulk write") + result = bwc.write_command(request_id, msg, to_send) + client._receive_cluster_time(result, session) + results.append((run.idx_offset, result)) + # We're no longer in a retry once a command succeeds. + self.retrying = False + if self.ordered and "writeErrors" in result: + break + run.idx_offset += len(to_send) + + _merge_command(run, full_result, results) + + # We're supposed to continue if errors are + # at the write concern level (e.g. wtimeout) + if self.ordered and full_result['writeErrors']: + break + # Reset our state + self.current_run = run = next(generator, None) + + def execute_command(self, generator, write_concern, session): """Execute using write commands. """ # nModified is only reported for write commands, not legacy ops. @@ -235,51 +311,16 @@ class _Bulk(object): "upserted": [], } op_id = _randint() - db_name = self.collection.database.name + + def retryable_bulk(session, sock_info, retryable): + self._execute_command( + generator, write_concern, session, sock_info, op_id, + retryable, full_result) + client = self.collection.database.client - listeners = client._event_listeners - - with self.collection.database.client._tmp_session(session) as s: - # sock_info.command validates the session, but we use - # sock_info.write_command. - sock_info.validate_session(client, s) - for run in generator: - cmd = SON([(_COMMANDS[run.op_type], self.collection.name), - ('ordered', self.ordered)]) - if write_concern.document: - cmd['writeConcern'] = write_concern.document - if self.bypass_doc_val and sock_info.max_wire_version >= 4: - cmd['bypassDocumentValidation'] = True - if s: - cmd['lsid'] = s._use_lsid() - bwc = _BulkWriteContext(db_name, cmd, sock_info, op_id, - listeners, s) - - results = [] - idx_offset = 0 - while idx_offset < len(run.ops): - check_keys = run.op_type == _INSERT - ops = islice(run.ops, idx_offset, None) - # Run as many ops as possible. - client._send_cluster_time(cmd, s) - request_id, msg, to_send = _do_batched_write_command( - self.namespace, run.op_type, cmd, ops, check_keys, - self.collection.codec_options, bwc) - if not to_send: - raise InvalidOperation("cannot do an empty bulk write") - result = bwc.write_command(request_id, msg, to_send) - client._receive_cluster_time(result, s) - results.append((idx_offset, result)) - if self.ordered and "writeErrors" in result: - break - idx_offset += len(to_send) - - _merge_command(run, full_result, results) - - # We're supposed to continue if errors are - # at the write concern level (e.g. wtimeout) - if self.ordered and full_result['writeErrors']: - break + with client._tmp_session(session) as s: + client._retry_with_session( + self.is_retryable, retryable_bulk, s, self) if full_result["writeErrors"] or full_result["writeConcernErrors"]: if full_result['writeErrors']: @@ -309,6 +350,12 @@ class _Bulk(object): def execute_no_results(self, sock_info, generator): """Execute all operations, returning no results (w=0). """ + if self.uses_collation: + raise ConfigurationError( + 'Collation is unsupported for unacknowledged writes.') + if self.uses_array_filters: + raise ConfigurationError( + 'arrayFilters is unsupported for unacknowledged writes.') # Cannot have both unacknowledged write and bypass document validation. if self.bypass_doc_val and sock_info.max_wire_version >= 4: raise OperationFailure("Cannot set bypass_document_validation with" @@ -378,25 +425,11 @@ class _Bulk(object): generator = self.gen_unordered() client = self.collection.database.client - with client._socket_for_writes() as sock_info: - if sock_info.max_wire_version < 5 and self.uses_collation: - raise ConfigurationError( - 'Must be connected to MongoDB 3.4+ to use a collation.') - if sock_info.max_wire_version < 6 and self.uses_array_filters: - raise ConfigurationError( - 'Must be connected to MongoDB 3.6+ to use arrayFilters.') - if not write_concern.acknowledged: - if self.uses_collation: - raise ConfigurationError( - 'Collation is unsupported for unacknowledged writes.') - if self.uses_array_filters: - raise ConfigurationError( - 'arrayFilters is unsupported for unacknowledged ' - 'writes.') + if not write_concern.acknowledged: + with client._socket_for_writes() as sock_info: self.execute_no_results(sock_info, generator) - else: - return self.execute_command( - sock_info, generator, write_concern, session) + else: + return self.execute_command(generator, write_concern, session) class BulkUpsertOperation(object): diff --git a/pymongo/collection.py b/pymongo/collection.py index 65fd2b003..6271cd59c 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -738,8 +738,7 @@ class Collection(common.BaseObject): blk = _Bulk(self, ordered, bypass_document_validation) blk.ops = [doc for doc in gen()] - with self.__database.client._tmp_session(session) as s: - blk.execute(self.write_concern.document, session=s) + blk.execute(self.write_concern.document, session=session) return InsertManyResult(inserted_ids, self.write_concern.acknowledged) def _update(self, sock_info, criteria, document, upsert=False, diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index f2a24e2f6..f62d19fdd 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -225,8 +225,26 @@ class MongoClient(common.BaseObject): - `event_listeners`: a list or tuple of event listeners. See :mod:`~pymongo.monitoring` for details. - `retryWrites`: (boolean) Whether supported write operations - executed within this MongoClient will be retried after a network - error, requires MongoDB 3.6+. See + executed within this MongoClient will be retried once after a + network error on MongoDB 3.6+. Defaults to ``False``. + The supported write operations are: + + - :meth:`~pymongo.collection.Collection.bulk_write`, as long as + :class:`~pymongo.operations.UpdateMany` or + :class:`~pymongo.operations.DeleteMany` are not included. + - :meth:`~pymongo.collection.Collection.delete_one` + - :meth:`~pymongo.collection.Collection.insert_one` + - :meth:`~pymongo.collection.Collection.insert_many` + - :meth:`~pymongo.collection.Collection.replace_one` + - :meth:`~pymongo.collection.Collection.update_one` + - :meth:`~pymongo.collection.Collection.find_one_and_delete` + - :meth:`~pymongo.collection.Collection.find_one_and_replace` + - :meth:`~pymongo.collection.Collection.find_one_and_update` + + Unsupported write operations include, but are not limited to, + :meth:`~pymongo.collection.Collection.aggregate` using the ``$out`` + pipeline operator and any operation with an unacknowledged write + concern (e.g. {w: 0})). See https://github.com/mongodb/specifications/blob/master/source/retryable-writes/retryable-writes.rst - `socketKeepAlive`: (boolean) **DEPRECATED** Whether to send periodic keep-alive packets on connected sockets. Defaults to @@ -890,8 +908,7 @@ class MongoClient(common.BaseObject): return self._topology @contextlib.contextmanager - def _get_socket(self, selector): - server = self._get_topology().select_server(selector) + def _get_socket(self, server): try: with server.get_socket(self.__all_credentials) as sock_info: yield sock_info @@ -914,7 +931,8 @@ class MongoClient(common.BaseObject): raise def _socket_for_writes(self): - return self._get_socket(writable_server_selector) + server = self._get_topology().select_server(writable_server_selector) + return self._get_socket(server) @contextlib.contextmanager def _socket_for_reads(self, read_preference): @@ -927,7 +945,8 @@ class MongoClient(common.BaseObject): # Thread safe: if the type is single it cannot change. topology = self._get_topology() single = topology.description.topology_type == TOPOLOGY_TYPE.Single - with self._get_socket(read_preference) as sock_info: + server = topology.select_server(read_preference) + with self._get_socket(server) as sock_info: slave_ok = (single and not sock_info.is_mongos) or ( preference != ReadPreference.PRIMARY) yield sock_info, slave_ok @@ -992,46 +1011,61 @@ class MongoClient(common.BaseObject): self.__reset_server(server.description.address) raise - def _retryable_write(self, retryable_operation, func, session): - """Execute an operation possibly with one retry. + def _retry_with_session(self, retryable, func, session, bulk): + """Execute an operation with at most one consecutive retries - Returns func()'s return value on success. On error retries once. + Returns func()'s return value on success. On error retries the same + command once. Re-raises any exception thrown by func(). """ - retryable = retryable_operation and self.retry_writes - with self._tmp_session(session) as s: + retryable = retryable and self.retry_writes + last_error = None + retrying = False + + def is_retrying(): + return bulk.retrying if bulk else retrying + while True: try: - with self._socket_for_writes() as sock_info: - if retryable and ( - s is None or not sock_info.supports_sessions): - raise ConfigurationError( - 'Retryable writes are not supported by this ' - 'MongoDB deployment') - return func(s, sock_info, retryable) + server = self._get_topology().select_server( + writable_server_selector) + supports_session = ( + session is not None and + server.description.retryable_writes_supported) + with self._get_socket(server) as sock_info: + if retryable and not supports_session: + if is_retrying(): + # A retry is not possible because this server does + # not support sessions raise the last error. + raise last_error + retryable = False + if is_retrying(): + # Reset the transaction id and retry the operation. + session._retry_transaction_id() + return func(session, sock_info, retryable) except ServerSelectionTimeoutError: + if is_retrying(): + # The application may think the write was never attempted + # if we raise ServerSelectionTimeoutError on the retry + # attempt. Raise the original exception instead. + raise last_error # A ServerSelectionTimeoutError error indicates that there may # be a persistent outage. Attempting to retry in this case will # most likely be a waste of time. raise except ConnectionFailure as exc: - if not retryable: + if not retryable or is_retrying(): raise - try: - with self._socket_for_writes() as sock_info: - if sock_info.supports_sessions: - # Reset the transaction id and retry the operation. - s._retry_transaction_id() - return func(s, sock_info, retryable) + if bulk: + bulk.retrying = True + else: + retrying = True + last_error = exc - # A retry was not possible because the new server does - # not support sessions raise the original error. - raise - except ServerSelectionTimeoutError: - # The application may think the write was never attempted - # if we raise ServerSelectionTimeoutError on the retry - # attempt. Raise the original exception instead. - raise exc + def _retryable_write(self, retryable, func, session): + """Internal retryable write helper.""" + with self._tmp_session(session) as s: + return self._retry_with_session(retryable, func, s, None) def __reset_server(self, address): """Clear our connection pool for a server and mark it Unknown.""" diff --git a/pymongo/server_description.py b/pymongo/server_description.py index 841791160..b5ee79791 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -201,5 +201,12 @@ class ServerDescription(object): def is_server_type_known(self): return self.server_type != SERVER_TYPE.Unknown + @property + def retryable_writes_supported(self): + """Checks if this server supports retryable writes.""" + return ( + self._ls_timeout_minutes is not None and + self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)) + # For unittesting only. Use under no circumstances! _host_to_round_trip_time = {} diff --git a/test/retryable_writes/bulkWrite.json b/test/retryable_writes/bulkWrite.json new file mode 100644 index 000000000..7b88ffb3c --- /dev/null +++ b/test/retryable_writes/bulkWrite.json @@ -0,0 +1,654 @@ +{ + "data": [ + { + "_id": 1, + "x": 11 + } + ], + "minServerVersion": "3.6", + "tests": [ + { + "description": "First command is retried", + "failPoint": { + "mode": { + "times": 1 + } + }, + "operation": { + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "name": "insertOne", + "arguments": { + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "name": "updateOne", + "arguments": { + "filter": { + "_id": 2 + }, + "update": { + "$inc": { + "x": 1 + } + } + } + }, + { + "name": "deleteOne", + "arguments": { + "filter": { + "_id": 1 + } + } + } + ], + "options": { + "ordered": true + } + } + }, + "outcome": { + "result": { + "deletedCount": 1, + "insertedIds": { + "0": 2 + }, + "matchedCount": 1, + "modifiedCount": 1, + "upsertedCount": 0, + "upsertedIds": {} + }, + "collection": { + "data": [ + { + "_id": 2, + "x": 23 + } + ] + } + } + }, + { + "description": "All commands are retried", + "failPoint": { + "mode": { + "times": 7 + } + }, + "operation": { + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "name": "insertOne", + "arguments": { + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "name": "updateOne", + "arguments": { + "filter": { + "_id": 2 + }, + "update": { + "$inc": { + "x": 1 + } + } + } + }, + { + "name": "insertOne", + "arguments": { + "document": { + "_id": 3, + "x": 33 + } + } + }, + { + "name": "updateOne", + "arguments": { + "filter": { + "_id": 4, + "x": 44 + }, + "update": { + "$inc": { + "x": 1 + } + }, + "upsert": true + } + }, + { + "name": "insertOne", + "arguments": { + "document": { + "_id": 5, + "x": 55 + } + } + }, + { + "name": "replaceOne", + "arguments": { + "filter": { + "_id": 3 + }, + "replacement": { + "_id": 3, + "x": 333 + } + } + }, + { + "name": "deleteOne", + "arguments": { + "filter": { + "_id": 1 + } + } + } + ], + "options": { + "ordered": true + } + } + }, + "outcome": { + "result": { + "deletedCount": 1, + "insertedIds": { + "0": 2, + "2": 3, + "4": 5 + }, + "matchedCount": 2, + "modifiedCount": 2, + "upsertedCount": 1, + "upsertedIds": { + "3": 4 + } + }, + "collection": { + "data": [ + { + "_id": 2, + "x": 23 + }, + { + "_id": 3, + "x": 333 + }, + { + "_id": 4, + "x": 45 + }, + { + "_id": 5, + "x": 55 + } + ] + } + } + }, + { + "description": "Both commands are retried after their first statement fails", + "failPoint": { + "mode": { + "times": 2 + } + }, + "operation": { + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "name": "insertOne", + "arguments": { + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "name": "updateOne", + "arguments": { + "filter": { + "_id": 1 + }, + "update": { + "$inc": { + "x": 1 + } + } + } + }, + { + "name": "updateOne", + "arguments": { + "filter": { + "_id": 2 + }, + "update": { + "$inc": { + "x": 1 + } + } + } + } + ], + "options": { + "ordered": true + } + } + }, + "outcome": { + "result": { + "deletedCount": 0, + "insertedIds": { + "0": 2 + }, + "matchedCount": 2, + "modifiedCount": 2, + "upsertedCount": 0, + "upsertedIds": {} + }, + "collection": { + "data": [ + { + "_id": 1, + "x": 12 + }, + { + "_id": 2, + "x": 23 + } + ] + } + } + }, + { + "description": "Second command is retried after its second statement fails", + "failPoint": { + "mode": { + "skip": 2 + } + }, + "operation": { + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "name": "insertOne", + "arguments": { + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "name": "updateOne", + "arguments": { + "filter": { + "_id": 1 + }, + "update": { + "$inc": { + "x": 1 + } + } + } + }, + { + "name": "updateOne", + "arguments": { + "filter": { + "_id": 2 + }, + "update": { + "$inc": { + "x": 1 + } + } + } + } + ], + "options": { + "ordered": true + } + } + }, + "outcome": { + "result": { + "deletedCount": 0, + "insertedIds": { + "0": 2 + }, + "matchedCount": 2, + "modifiedCount": 2, + "upsertedCount": 0, + "upsertedIds": {} + }, + "collection": { + "data": [ + { + "_id": 1, + "x": 12 + }, + { + "_id": 2, + "x": 23 + } + ] + } + } + }, + { + "description": "BulkWrite with unordered execution", + "failPoint": { + "mode": { + "times": 1 + } + }, + "operation": { + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "name": "insertOne", + "arguments": { + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "name": "insertOne", + "arguments": { + "document": { + "_id": 3, + "x": 33 + } + } + } + ], + "options": { + "ordered": false + } + } + }, + "outcome": { + "result": { + "deletedCount": 0, + "insertedIds": { + "0": 2, + "1": 3 + }, + "matchedCount": 0, + "modifiedCount": 0, + "upsertedCount": 0, + "upsertedIds": {} + }, + "collection": { + "data": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + } + }, + { + "description": "First insertOne is never committed", + "failPoint": { + "mode": { + "times": 2 + }, + "data": { + "failBeforeCommitExceptionCode": 1 + } + }, + "operation": { + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "name": "insertOne", + "arguments": { + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "name": "updateOne", + "arguments": { + "filter": { + "_id": 2 + }, + "update": { + "$inc": { + "x": 1 + } + } + } + }, + { + "name": "deleteOne", + "arguments": { + "filter": { + "_id": 1 + } + } + } + ], + "options": { + "ordered": true + } + } + }, + "outcome": { + "error": true, + "result": { + "deletedCount": 0, + "insertedIds": {}, + "matchedCount": 0, + "modifiedCount": 0, + "upsertedCount": 0, + "upsertedIds": {} + }, + "collection": { + "data": [ + { + "_id": 1, + "x": 11 + } + ] + } + } + }, + { + "description": "Second updateOne is never committed", + "failPoint": { + "mode": { + "skip": 1 + }, + "data": { + "failBeforeCommitExceptionCode": 1 + } + }, + "operation": { + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "name": "insertOne", + "arguments": { + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "name": "updateOne", + "arguments": { + "filter": { + "_id": 2 + }, + "update": { + "$inc": { + "x": 1 + } + } + } + }, + { + "name": "deleteOne", + "arguments": { + "filter": { + "_id": 1 + } + } + } + ], + "options": { + "ordered": true + } + } + }, + "outcome": { + "error": true, + "result": { + "deletedCount": 0, + "insertedIds": { + "0": 2 + }, + "matchedCount": 0, + "modifiedCount": 0, + "upsertedCount": 0, + "upsertedIds": {} + }, + "collection": { + "data": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + } + ] + } + } + }, + { + "description": "Third updateOne is never committed", + "failPoint": { + "mode": { + "skip": 2 + }, + "data": { + "failBeforeCommitExceptionCode": 1 + } + }, + "operation": { + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "name": "updateOne", + "arguments": { + "filter": { + "_id": 1 + }, + "update": { + "$inc": { + "x": 1 + } + } + } + }, + { + "name": "insertOne", + "arguments": { + "document": { + "_id": 2, + "x": 22 + } + } + }, + { + "name": "updateOne", + "arguments": { + "filter": { + "_id": 2 + }, + "update": { + "$inc": { + "x": 1 + } + } + } + } + ], + "options": { + "ordered": true + } + } + }, + "outcome": { + "error": true, + "result": { + "deletedCount": 0, + "insertedIds": { + "0": 2 + }, + "matchedCount": 1, + "modifiedCount": 1, + "upsertedCount": 0, + "upsertedIds": {} + }, + "collection": { + "data": [ + { + "_id": 1, + "x": 12 + }, + { + "_id": 2, + "x": 22 + } + ] + } + } + } + ] +} diff --git a/test/retryable_writes/insertMany.json b/test/retryable_writes/insertMany.json new file mode 100644 index 000000000..2d71cb911 --- /dev/null +++ b/test/retryable_writes/insertMany.json @@ -0,0 +1,153 @@ +{ + "data": [ + { + "_id": 1, + "x": 11 + } + ], + "minServerVersion": "3.6", + "tests": [ + { + "description": "InsertMany succeeds after one network error", + "failPoint": { + "mode": { + "times": 1 + } + }, + "operation": { + "name": "insertMany", + "arguments": { + "documents": [ + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ], + "options": { + "ordered": true + } + } + }, + "outcome": { + "result": { + "insertedIds": { + "0": 2, + "1": 3 + } + }, + "collection": { + "data": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + } + }, + { + "description": "InsertMany with unordered execution", + "failPoint": { + "mode": { + "times": 1 + } + }, + "operation": { + "name": "insertMany", + "arguments": { + "documents": [ + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ], + "options": { + "ordered": false + } + } + }, + "outcome": { + "result": { + "insertedIds": { + "0": 2, + "1": 3 + } + }, + "collection": { + "data": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + } + }, + { + "description": "InsertMany fails after multiple network errors", + "failPoint": { + "mode": "alwaysOn", + "data": { + "failBeforeCommitExceptionCode": 1 + } + }, + "operation": { + "name": "insertMany", + "arguments": { + "documents": [ + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + }, + { + "_id": 4, + "x": 44 + } + ], + "options": { + "ordered": true + } + } + }, + "outcome": { + "error": true, + "collection": { + "data": [ + { + "_id": 1, + "x": 11 + } + ] + } + } + } + ] +} diff --git a/test/test_crud.py b/test/test_crud.py index 56b8316aa..fbd34a71f 100644 --- a/test/test_crud.py +++ b/test/test_crud.py @@ -74,7 +74,19 @@ def check_result(expected_result, result): # BulkWriteResult does not have inserted_ids. if isinstance(result, BulkWriteResult): return len(expected_result[res]) == result.inserted_count - return expected_result[res] == result.inserted_ids + # InsertManyResult may be compared to [id1] from the + # crud spec or {"0": id1} from the retryable write spec. + ids = expected_result[res] + if isinstance(ids, dict): + ids = [ids[str(i)] for i in range(len(ids))] + return ids == result.inserted_ids + elif prop == "upserted_ids": + # Convert indexes from strings to integers. + ids = expected_result[res] + expected_ids = {} + for str_index in ids: + expected_ids[int(str_index)] = ids[str_index] + return expected_ids == result.upserted_ids elif getattr(result, prop) != expected_result[res]: return False return True diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 371293ad5..1235531a9 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -16,24 +16,30 @@ import json import os -import re import sys sys.path[0:0] = [""] -from bson import SON +from bson.int64 import Int64 +from bson.objectid import ObjectId +from bson.son import SON -from pymongo.errors import (ConfigurationError, - ConnectionFailure, - OperationFailure, + +from pymongo.errors import (ConnectionFailure, ServerSelectionTimeoutError) from pymongo.monitoring import _SENSITIVE_COMMANDS from pymongo.mongo_client import MongoClient +from pymongo.operations import (InsertOne, + DeleteMany, + DeleteOne, + ReplaceOne, + UpdateMany, + UpdateOne) from pymongo.write_concern import WriteConcern -from test import unittest, client_context, IntegrationTest +from test import unittest, client_context, IntegrationTest, SkipTest from test.utils import rs_or_single_client, EventListener, DeprecationFilter -from test.test_crud import run_operation +from test.test_crud import check_result, run_operation # Location of JSON test specifications. _TEST_PATH = os.path.join( @@ -62,6 +68,7 @@ class TestAllScenarios(IntegrationTest): @classmethod @client_context.require_version_min(3, 5) @client_context.require_replica_set + @client_context.require_test_commands def setUpClass(cls): super(TestAllScenarios, cls).setUpClass() cls.client = rs_or_single_client(retryWrites=True) @@ -90,6 +97,7 @@ def create_test(scenario_def, test): test_outcome = test['outcome'] should_fail = test_outcome.get('error') + result = None error = None try: result = run_operation(self.db.test, test) @@ -111,8 +119,11 @@ def create_test(scenario_def, test): db_coll = self.db.test self.assertEqual(list(db_coll.find()), expected_c['data']) expected_result = test_outcome.get('result') - if expected_result is not None: - self.assertTrue(result, expected_result) + # We can't test the expected result when the test should fail because + # the BulkWriteResult is not reported when raising a network error. + if not should_fail: + self.assertTrue(check_result(expected_result, result), + "%r != %r" % (expected_result, result)) return run_scenario @@ -144,24 +155,35 @@ create_tests() def retryable_single_statement_ops(coll): return [ + (coll.bulk_write, [[InsertOne({}), InsertOne({})]], {}), + (coll.bulk_write, [[InsertOne({}), + InsertOne({})]], {'ordered': False}), + (coll.bulk_write, [[ReplaceOne({}, {})]], {}), + (coll.bulk_write, [[ReplaceOne({}, {}), ReplaceOne({}, {})]], {}), + (coll.bulk_write, [[UpdateOne({}, {'$set': {'a': 1}}), + UpdateOne({}, {'$set': {'a': 1}})]], {}), + (coll.bulk_write, [[DeleteOne({})]], {}), + (coll.bulk_write, [[DeleteOne({}), DeleteOne({})]], {}), (coll.insert_one, [{}], {}), + (coll.insert_many, [[{}, {}]], {}), (coll.replace_one, [{}, {}], {}), (coll.update_one, [{}, {'$set': {'a': 1}}], {}), (coll.delete_one, [{}], {}), - # Insert document for find_one_and_*. - (coll.insert_one, [{}], {}), (coll.find_one_and_replace, [{}, {'a': 3}], {}), (coll.find_one_and_update, [{}, {'$set': {'a': 1}}], {}), (coll.find_one_and_delete, [{}, {}], {}), # Deprecated methods. - # Insert document for update. - (coll.insert_one, [{}], {}), + # Insert with single or multiple documents. + (coll.insert, [{}], {}), + (coll.insert, [[{}]], {}), + (coll.insert, [[{}, {}]], {}), + # Save with and without an _id. + (coll.save, [{}], {}), + (coll.save, [{'_id': ObjectId()}], {}), # Non-multi update. (coll.update, [{}, {'$set': {'a': 1}}], {}), # Non-multi remove. (coll.remove, [{}], {'multi': False}), - # Insert document for find_and_modify. - (coll.insert_one, [{}], {}), # Replace. (coll.find_and_modify, [{}, {'a': 3}], {}), # Update. @@ -173,6 +195,9 @@ def retryable_single_statement_ops(coll): def non_retryable_single_statement_ops(coll): return [ + (coll.bulk_write, [[UpdateOne({}, {'$set': {'a': 1}}), + UpdateMany({}, {'$set': {'a': 1}})]], {}), + (coll.bulk_write, [[DeleteOne({}), DeleteMany({})]], {}), (coll.update_many, [{}, {'$set': {'a': 1}}], {}), (coll.delete_many, [{}], {}), # Deprecated methods. @@ -213,8 +238,6 @@ class IgnoreDeprecationsTest(IntegrationTest): class TestRetryableWrites(IgnoreDeprecationsTest): @classmethod - @client_context.require_version_min(3, 5) - @client_context.require_no_standalone def setUpClass(cls): super(TestRetryableWrites, cls).setUpClass() cls.listener = CommandListener() @@ -223,13 +246,15 @@ class TestRetryableWrites(IgnoreDeprecationsTest): cls.db = cls.client.pymongo_test def setUp(self): - if client_context.is_rs: + if (client_context.version.at_least(3, 5) and client_context.is_rs + and client_context.test_commands_enabled): self.client.admin.command(SON([ ('configureFailPoint', 'onPrimaryTransactionalWrite'), ('mode', 'alwaysOn')])) def tearDown(self): - if client_context.is_rs: + if (client_context.version.at_least(3, 5) and client_context.is_rs + and client_context.test_commands_enabled): self.client.admin.command(SON([ ('configureFailPoint', 'onPrimaryTransactionalWrite'), ('mode', 'off')])) @@ -238,54 +263,72 @@ class TestRetryableWrites(IgnoreDeprecationsTest): listener = CommandListener() client = rs_or_single_client( retryWrites=False, event_listeners=[listener]) + self.addCleanup(client.close) for method, args, kwargs in retryable_single_statement_ops( client.db.retryable_write_test): + msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) listener.results.clear() method(*args, **kwargs) for event in listener.results['started']: self.assertNotIn( 'txnNumber', event.command, - '%s sent txnNumber with %s' % ( - method.__name__, event.command_name)) - - def test_supported_single_statement(self): + '%s sent txnNumber with %s' % (msg, event.command_name)) + @client_context.require_version_min(3, 5) + @client_context.require_no_standalone + def test_supported_single_statement_supported_cluster(self): for method, args, kwargs in retryable_single_statement_ops( self.db.retryable_write_test): + msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) self.listener.results.clear() method(*args, **kwargs) commands_started = self.listener.results['started'] - self.assertEqual(len(self.listener.results['succeeded']), 1, - method.__name__) + self.assertEqual(len(self.listener.results['succeeded']), 1, msg) first_attempt = commands_started[0] self.assertIn( 'lsid', first_attempt.command, - '%s sent no lsid with %s' % ( - method.__name__, first_attempt.command_name)) + '%s sent no lsid with %s' % (msg, first_attempt.command_name)) initial_session_id = first_attempt.command['lsid'] self.assertIn( 'txnNumber', first_attempt.command, '%s sent no txnNumber with %s' % ( - method.__name__, first_attempt.command_name)) + msg, first_attempt.command_name)) - # The failpoint is only enabled on a replica set. - if client_context.is_rs: - self.assertEqual(len(self.listener.results['failed']), 1, - method.__name__) - initial_transaction_id = first_attempt.command['txnNumber'] - retry_attempt = commands_started[1] - self.assertIn( - 'lsid', retry_attempt.command, - '%s sent no lsid with %s' % ( - method.__name__, first_attempt.command_name)) - self.assertEqual( - retry_attempt.command['lsid'], initial_session_id) - self.assertIn( - 'txnNumber', retry_attempt.command, - '%s sent no txnNumber with %s' % ( - method.__name__, first_attempt.command_name)) - self.assertEqual(retry_attempt.command['txnNumber'], - initial_transaction_id) + # There should be no retry when the failpoint is not active. + if (client_context.is_mongos or + not client_context.test_commands_enabled): + self.assertEqual(len(commands_started), 1) + continue + + initial_transaction_id = first_attempt.command['txnNumber'] + retry_attempt = commands_started[1] + self.assertIn( + 'lsid', retry_attempt.command, + '%s sent no lsid with %s' % (msg, first_attempt.command_name)) + self.assertEqual( + retry_attempt.command['lsid'], initial_session_id, msg) + self.assertIn( + 'txnNumber', retry_attempt.command, + '%s sent no txnNumber with %s' % ( + msg, first_attempt.command_name)) + self.assertEqual(retry_attempt.command['txnNumber'], + initial_transaction_id, msg) + + def test_supported_single_statement_unsupported_cluster(self): + if client_context.version.at_least(3, 5) and ( + client_context.is_rs or client_context.is_mongos): + raise SkipTest('This cluster supports retryable writes') + + for method, args, kwargs in retryable_single_statement_ops( + self.db.retryable_write_test): + msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) + self.listener.results.clear() + method(*args, **kwargs) + + for event in self.listener.results['started']: + self.assertNotIn( + 'txnNumber', event.command, + '%s sent txnNumber with %s' % (msg, event.command_name)) def test_unsupported_single_statement(self): coll = self.db.retryable_write_test @@ -293,33 +336,36 @@ class TestRetryableWrites(IgnoreDeprecationsTest): coll_w0 = coll.with_options(write_concern=WriteConcern(w=0)) for method, args, kwargs in (non_retryable_single_statement_ops(coll) + retryable_single_statement_ops(coll_w0)): + msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) self.listener.results.clear() method(*args, **kwargs) started_events = self.listener.results['started'] - self.assertEqual(len(self.listener.results['succeeded']), 1, - method.__name__) - self.assertEqual(len(started_events), 1, method.__name__) + self.assertEqual(len(self.listener.results['succeeded']), + len(started_events), msg) + self.assertEqual(len(self.listener.results['failed']), 0, msg) for event in started_events: self.assertNotIn( 'txnNumber', event.command, - '%s sent txnNumber with %s' % ( - method.__name__, event.command_name)) + '%s sent txnNumber with %s' % (msg, event.command_name)) def test_server_selection_timeout_not_retried(self): """A ServerSelectionTimeoutError is not retried.""" listener = CommandListener() client = MongoClient( 'somedomainthatdoesntexist.org', - serverSelectionTimeoutMS=10, + serverSelectionTimeoutMS=1, retryWrites=True, event_listeners=[listener]) for method, args, kwargs in retryable_single_statement_ops( client.db.retryable_write_test): + msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) listener.results.clear() - with self.assertRaises(ServerSelectionTimeoutError): + with self.assertRaises(ServerSelectionTimeoutError, msg=msg): method(*args, **kwargs) - self.assertEqual(len(listener.results['started']), 0) + self.assertEqual(len(listener.results['started']), 0, msg) + @client_context.require_version_min(3, 5) @client_context.require_replica_set + @client_context.require_test_commands def test_retry_timeout_raises_original_error(self): """A ServerSelectionTimeoutError on the retry attempt raises the original error. @@ -327,61 +373,91 @@ class TestRetryableWrites(IgnoreDeprecationsTest): listener = CommandListener() client = rs_or_single_client( retryWrites=True, event_listeners=[listener]) - socket_for_writes = client._socket_for_writes + self.addCleanup(client.close) + topology = client._topology + select_server = topology.select_server - def mock_socket_for_writes(*args, **kwargs): - sock_info = socket_for_writes(*args, **kwargs) + def mock_select_server(*args, **kwargs): + server = select_server(*args, **kwargs) - def raise_error(): + def raise_error(*args, **kwargs): raise ServerSelectionTimeoutError( 'No primary available for writes') # Raise ServerSelectionTimeout on the retry attempt. - client._socket_for_writes = raise_error - return sock_info + topology.select_server = raise_error + return server for method, args, kwargs in retryable_single_statement_ops( client.db.retryable_write_test): + msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) listener.results.clear() - client._socket_for_writes = mock_socket_for_writes - with self.assertRaises(ConnectionFailure): - method(*args, **kwargs) - self.assertEqual(len(listener.results['started']), 1) - - -class TestRetryableWritesNotSupported(IgnoreDeprecationsTest): - - @client_context.require_version_max(3, 5, 0, -1) - def test_raises_error(self): - client = rs_or_single_client(retryWrites=True) - coll = client.pymongo_test.test - # No error running non-retryable operations. - client.admin.command('isMaster') - for method, args, kwargs in non_retryable_single_statement_ops(coll): - method(*args, **kwargs) - - for method, args, kwargs in retryable_single_statement_ops(coll): - with self.assertRaisesRegex( - ConfigurationError, - 'Retryable writes are not supported by this MongoDB ' - 'deployment'): + topology.select_server = mock_select_server + with self.assertRaises(ConnectionFailure, msg=msg): method(*args, **kwargs) + self.assertEqual(len(listener.results['started']), 1, msg) @client_context.require_version_min(3, 5) - @client_context.require_standalone - def test_standalone_raises_error(self): - client = rs_or_single_client(retryWrites=True) - coll = client.pymongo_test.test - # No error running non-retryable operations. - client.admin.command('isMaster') - for method, args, kwargs in non_retryable_single_statement_ops(coll): - method(*args, **kwargs) + @client_context.require_replica_set + @client_context.require_test_commands + def test_batch_splitting(self): + """Test retry succeeds after failures during batch splitting.""" + large = 's' * 1024 * 1024 * 15 + coll = self.db.retryable_write_test + coll.delete_many({}) + self.listener.results.clear() + coll.bulk_write([ + InsertOne({'_id': 1, 'l': large}), + InsertOne({'_id': 2, 'l': large}), + InsertOne({'_id': 3, 'l': large}), + UpdateOne({'_id': 1, 'l': large}, + {'$unset': {'l': 1}, '$inc': {'count': 1}}), + UpdateOne({'_id': 2, 'l': large}, {'$set': {'foo': 'bar'}}), + DeleteOne({'l': large}), + DeleteOne({'l': large})]) + # Each command should fail and be retried. + self.assertEqual(len(self.listener.results['started']), 14) + self.assertEqual(coll.find_one(), {'_id': 1, 'count': 1}) - for method, args, kwargs in retryable_single_statement_ops(coll): - with self.assertRaisesRegex( - OperationFailure, - 'Transaction numbers are only allowed on a replica set ' - 'member or mongos'): - method(*args, **kwargs) + @client_context.require_version_min(3, 5) + @client_context.require_replica_set + @client_context.require_test_commands + def test_batch_splitting_retry_fails(self): + """Test retry fails during batch splitting.""" + large = 's' * 1024 * 1024 * 15 + coll = self.db.retryable_write_test + coll.delete_many({}) + self.client.admin.command(SON([ + ('configureFailPoint', 'onPrimaryTransactionalWrite'), + ('mode', {'skip': 1}), + ('data', {'failBeforeCommitExceptionCode': 1})])) + self.listener.results.clear() + with self.client.start_session() as session: + initial_txn = session._server_session._transaction_id + try: + coll.bulk_write([InsertOne({'_id': 1, 'l': large}), + InsertOne({'_id': 2, 'l': large}), + InsertOne({'_id': 3, 'l': large})], + session=session) + except ConnectionFailure: + pass + else: + self.fail("bulk_write should have failed") + + started = self.listener.results['started'] + self.assertEqual(len(started), 3) + self.assertEqual(len(self.listener.results['succeeded']), 1) + expected_txn = Int64(initial_txn + 1) + self.assertEqual(started[0].command['txnNumber'], expected_txn) + self.assertEqual(started[0].command['lsid'], session.session_id) + expected_txn = Int64(initial_txn + 2) + self.assertEqual(started[1].command['txnNumber'], expected_txn) + self.assertEqual(started[1].command['lsid'], session.session_id) + started[1].command.pop('$clusterTime') + started[2].command.pop('$clusterTime') + self.assertEqual(started[1].command, started[2].command) + final_txn = session._server_session._transaction_id + self.assertEqual(final_txn, expected_txn) + self.assertEqual(coll.find_one(projection={'_id': True}), {'_id': 1}) if __name__ == '__main__':