diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index f026dd7f5..08b2043de 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -25,6 +25,7 @@ from pymongo.aggregation import (_CollectionAggregationCommand, from pymongo.collation import validate_collation_or_none from pymongo.command_cursor import CommandCursor from pymongo.errors import (ConnectionFailure, + CursorNotFound, InvalidOperation, OperationFailure, PyMongoError) @@ -32,11 +33,25 @@ from pymongo.errors import (ConnectionFailure, # The change streams spec considers the following server errors from the # getMore command non-resumable. All other getMore errors are resumable. -_NON_RESUMABLE_GETMORE_ERRORS = frozenset([ - 11601, # Interrupted - 136, # CappedPositionLost - 237, # CursorKilled - None, # No error code was returned. +_RESUMABLE_GETMORE_ERRORS = frozenset([ + 6, # HostUnreachable + 7, # HostNotFound + 89, # NetworkTimeout + 91, # ShutdownInProgress + 189, # PrimarySteppedDown + 262, # ExceededTimeLimit + 9001, # SocketException + 10107, # NotMaster + 11600, # InterruptedAtShutdown + 11602, # InterruptedDueToReplStateChange + 13435, # NotMasterNoSlaveOk + 13436, # NotMasterOrSecondary + 63, # StaleShardVersion + 150, # StaleEpoch + 13388, # StaleConfig + 234, # RetryChangeStream + 133, # FailedToSatisfyReadPreference + 216, # ElectionInProgress ]) @@ -283,12 +298,17 @@ class ChangeStream(object): # one resume attempt. try: change = self._cursor._try_next(True) - except ConnectionFailure: + except (ConnectionFailure, CursorNotFound): self._resume() change = self._cursor._try_next(False) except OperationFailure as exc: - if (exc.code in _NON_RESUMABLE_GETMORE_ERRORS or - exc.has_error_label("NonResumableChangeStreamError")): + if exc._max_wire_version is None: + raise + is_resumable = ((exc._max_wire_version >= 9 and + exc.has_error_label("ResumableChangeStreamError")) or + (exc._max_wire_version < 9 and + exc.code in _RESUMABLE_GETMORE_ERRORS)) + if not is_resumable: raise self._resume() change = self._cursor._try_next(False) diff --git a/pymongo/errors.py b/pymongo/errors.py index a309a9e7a..e5d52bfe3 100644 --- a/pymongo/errors.py +++ b/pymongo/errors.py @@ -144,7 +144,7 @@ class OperationFailure(PyMongoError): The :attr:`details` attribute. """ - def __init__(self, error, code=None, details=None): + def __init__(self, error, code=None, details=None, max_wire_version=None): error_labels = None if details is not None: error_labels = details.get('errorLabels') @@ -152,6 +152,11 @@ class OperationFailure(PyMongoError): error, error_labels=error_labels) self.__code = code self.__details = details + self.__max_wire_version = max_wire_version + + @property + def _max_wire_version(self): + return self.__max_wire_version @property def code(self): @@ -177,6 +182,7 @@ class OperationFailure(PyMongoError): return output_str.encode('utf-8', errors='replace') return output_str + class CursorNotFound(OperationFailure): """Raised while iterating query results if the cursor is invalidated on the server. diff --git a/pymongo/helpers.py b/pymongo/helpers.py index 51215b4c4..67b2e1584 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -102,7 +102,8 @@ def _index_document(index_list): return index -def _check_command_response(response, msg=None, allowable_errors=None, +def _check_command_response(response, max_wire_version, msg=None, + allowable_errors=None, parse_write_concern_error=False): """Check the response to a command for errors. """ @@ -110,7 +111,8 @@ def _check_command_response(response, msg=None, allowable_errors=None, # Server didn't recognize our message as a command. raise OperationFailure(response.get("$err"), response.get("code"), - response) + response, + max_wire_version) if parse_write_concern_error and 'writeConcernError' in response: _raise_write_concern_error(response['writeConcernError']) @@ -146,25 +148,30 @@ def _check_command_response(response, msg=None, allowable_errors=None, details.get("assertion", "")) raise OperationFailure(errmsg, details.get("assertionCode"), - response) + response, + max_wire_version) # Other errors # findAndModify with upsert can raise duplicate key error if code in (11000, 11001, 12582): - raise DuplicateKeyError(errmsg, code, response) + raise DuplicateKeyError(errmsg, code, response, + max_wire_version) elif code == 50: - raise ExecutionTimeout(errmsg, code, response) + raise ExecutionTimeout(errmsg, code, response, + max_wire_version) elif code == 43: - raise CursorNotFound(errmsg, code, response) + raise CursorNotFound(errmsg, code, response, + max_wire_version) msg = msg or "%s" - raise OperationFailure(msg % errmsg, code, response) + raise OperationFailure(msg % errmsg, code, response, + max_wire_version) -def _check_gle_response(result): +def _check_gle_response(result, max_wire_version): """Return getlasterror response as a dict, or raise OperationFailure.""" # Did getlasterror itself fail? - _check_command_response(result) + _check_command_response(result, max_wire_version) if result.get("wtimeout", False): # MongoDB versions before 1.8.0 return the error message in an "errmsg" diff --git a/pymongo/network.py b/pymongo/network.py index 3224cf649..d9d645fa9 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -157,7 +157,8 @@ def command(sock_info, dbname, spec, slave_ok, is_mongos, client._process_response(response_doc, session) if check: helpers._check_command_response( - response_doc, None, allowable_errors, + response_doc, sock_info.max_wire_version, None, + allowable_errors, parse_write_concern_error=parse_write_concern_error) except Exception as exc: if publish: diff --git a/pymongo/pool.py b/pymongo/pool.py index 6e848b7bb..b04e4bd33 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -606,7 +606,7 @@ class SocketInfo(object): self.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response() response_doc = unpacked_docs[0] - helpers._check_command_response(response_doc) + helpers._check_command_response(response_doc, self.max_wire_version) return response_doc def command(self, dbname, spec, slave_ok=False, @@ -751,7 +751,8 @@ class SocketInfo(object): self.send_message(msg, max_doc_size) if with_last_error: reply = self.receive_message(request_id) - return helpers._check_gle_response(reply.command_response()) + return helpers._check_gle_response(reply.command_response(), + self.max_wire_version) def write_command(self, request_id, msg): """Send "insert" etc. command, returning response as a dict. @@ -767,7 +768,7 @@ class SocketInfo(object): result = reply.command_response() # Raises NotMasterError or OperationFailure. - helpers._check_command_response(result) + helpers._check_command_response(result, self.max_wire_version) return result def check_auth(self, all_credentials): diff --git a/pymongo/server.py b/pymongo/server.py index 18919b9e2..8ea361e10 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -133,7 +133,8 @@ class Server(object): first = docs[0] operation.client._process_response( first, operation.session) - _check_command_response(first) + _check_command_response( + first, sock_info.max_wire_version) except Exception as exc: if publish: duration = datetime.now() - start diff --git a/test/change_streams/change-streams-errors.json b/test/change_streams/change-streams-errors.json index 5ebbd28f4..7b7cea30a 100644 --- a/test/change_streams/change-streams-errors.json +++ b/test/change_streams/change-streams-errors.json @@ -75,7 +75,6 @@ { "description": "Change Stream should error when _id is projected out", "minServerVersion": "4.1.11", - "maxServerVersion": "4.3.3", "target": "collection", "topology": [ "replicaset", @@ -103,10 +102,54 @@ ], "result": { "error": { - "code": 280, - "errorLabels": [ - "NonResumableChangeStreamError" - ] + "code": 280 + } + } + }, + { + "description": "change stream errors on MaxTimeMSExpired", + "minServerVersion": "4.2", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "getMore" + ], + "errorCode": 50, + "closeConnection": false + } + }, + "target": "collection", + "topology": [ + "replicaset", + "sharded" + ], + "changeStreamPipeline": [ + { + "$project": { + "_id": 0 + } + } + ], + "changeStreamOptions": {}, + "operations": [ + { + "database": "change-stream-tests", + "collection": "test", + "name": "insertOne", + "arguments": { + "document": { + "z": 3 + } + } + } + ], + "result": { + "error": { + "code": 50 } } } diff --git a/test/test_change_stream.py b/test/test_change_stream.py index 669c774a3..7c55c6e96 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -37,7 +37,6 @@ from bson.py3compat import iteritems from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument from pymongo import MongoClient -from pymongo.change_stream import _NON_RESUMABLE_GETMORE_ERRORS from pymongo.command_cursor import CommandCursor from pymongo.errors import (InvalidOperation, OperationFailure, ServerSelectionTimeoutError) @@ -555,47 +554,11 @@ class ProseSpecTestsMixin(object): self.assertEqual(listener.results['started'][0].command_name, 'aggregate') - # Prose test no. 5 - def test_does_not_resume_fatal_errors(self): - """ChangeStream will not attempt to resume fatal server errors.""" - if client_context.supports_failCommand_fail_point: - # failCommand does not support returning no errorCode. - TEST_ERROR_CODES = _NON_RESUMABLE_GETMORE_ERRORS - {None} - @contextmanager - def generate_error(change_stream, code): - fail_point = {'mode': {'times': 1}, 'data': { - 'errorCode': code, 'failCommands': ['getMore']}} - with self.fail_point(fail_point): - yield - else: - TEST_ERROR_CODES = _NON_RESUMABLE_GETMORE_ERRORS - @contextmanager - def generate_error(change_stream, code): - def mock_try_next(*args, **kwargs): - change_stream._cursor.close() - raise OperationFailure('Mock server error', code=code) - - original_cursor = change_stream._cursor - change_stream._cursor._try_next = mock_try_next - try: - yield - finally: - # Un patch the instance. - del original_cursor._try_next - - for code in TEST_ERROR_CODES: - with self.change_stream() as change_stream: - self.watched_collection().insert_one({}) - with generate_error(change_stream, code): - with self.assertRaises(OperationFailure): - next(change_stream) - with self.assertRaises(StopIteration): - next(change_stream) - + # Prose test no. 5 - REMOVED # Prose test no. 6 - SKIPPED - # readPreference is not configurable using the watch() helpers so we can - # skip this test. Also, PyMongo performs server selection for each - # operation which ensure compliance with this prose test. + # Reason: readPreference is not configurable using the watch() helpers + # so we can skip this test. Also, PyMongo performs server selection for + # each operation which ensure compliance with this prose test. # Prose test no. 7 def test_initial_empty_batch(self): @@ -1075,7 +1038,7 @@ class TestAllScenarios(unittest.TestCase): @classmethod @client_context.require_connection def setUpClass(cls): - cls.listener = WhiteListEventListener("aggregate") + cls.listener = WhiteListEventListener("aggregate", "getMore") cls.client = rs_or_single_client(event_listeners=[cls.listener]) @classmethod @@ -1086,14 +1049,66 @@ class TestAllScenarios(unittest.TestCase): self.listener.results.clear() def setUpCluster(self, scenario_dict): - assets = [ - (scenario_dict["database_name"], scenario_dict["collection_name"]), - (scenario_dict["database2_name"], scenario_dict["collection2_name"]), - ] + assets = [(scenario_dict["database_name"], + scenario_dict["collection_name"]), + (scenario_dict.get("database2_name", "db2"), + scenario_dict.get("collection2_name", "coll2"))] for db, coll in assets: self.client.drop_database(db) self.client[db].create_collection(coll) + def setFailPoint(self, scenario_dict): + fail_point = scenario_dict.get("failPoint") + if fail_point is None: + return + + fail_cmd = SON([('configureFailPoint', 'failCommand')]) + fail_cmd.update(fail_point) + client_context.client.admin.command(fail_cmd) + self.addCleanup( + client_context.client.admin.command, + 'configureFailPoint', fail_cmd['configureFailPoint'], mode='off') + + def assert_list_contents_are_subset(self, superlist, sublist): + """Check that each element in sublist is a subset of the corresponding + element in superlist.""" + self.assertEqual(len(superlist), len(sublist)) + for sup, sub in zip(superlist, sublist): + if isinstance(sub, dict): + self.assert_dict_is_subset(sup, sub) + continue + if isinstance(sub, (list, tuple)): + self.assert_list_contents_are_subset(sup, sub) + continue + self.assertEqual(sup, sub) + + def assert_dict_is_subset(self, superdict, subdict): + """Check that subdict is a subset of superdict.""" + exempt_fields = ["documentKey", "_id", "getMore"] + for key, value in iteritems(subdict): + if key not in superdict: + self.fail('Key %s not found in %s' % (key, superdict)) + if isinstance(value, dict): + self.assert_dict_is_subset(superdict[key], value) + continue + if isinstance(value, (list, tuple)): + self.assert_list_contents_are_subset(superdict[key], value) + continue + if key in exempt_fields: + # Only check for presence of these exempt fields, but not value. + self.assertIn(key, superdict) + else: + self.assertEqual(superdict[key], value) + + def check_event(self, event, expectation_dict): + if event is None: + self.fail() + for key, value in iteritems(expectation_dict): + if isinstance(value, dict): + self.assert_dict_is_subset(getattr(event, key), value) + else: + self.assertEqual(getattr(event, key), value) + def tearDown(self): self.listener.results.clear() @@ -1147,36 +1162,11 @@ def run_operation(client, operation): return cmd(**arguments) -def assert_dict_is_subset(superdict, subdict): - """Check that subdict is a subset of superdict.""" - exempt_fields = ["documentKey", "_id"] - for key, value in iteritems(subdict): - if key not in superdict: - assert False - if isinstance(value, dict): - assert_dict_is_subset(superdict[key], value) - continue - if key in exempt_fields: - superdict[key] = "42" - assert superdict[key] == value - - -def check_event(event, expectation_dict): - if event is None: - raise AssertionError - for key, value in iteritems(expectation_dict): - if isinstance(value, dict): - assert_dict_is_subset( - getattr(event, key), value - ) - else: - assert getattr(event, key) == value - - def create_test(scenario_def, test): def run_scenario(self): # Set up self.setUpCluster(scenario_def) + self.setFailPoint(test) is_error = test["result"].get("error", False) try: with get_change_stream( @@ -1202,17 +1192,17 @@ def create_test(scenario_def, test): else: # Check for expected output from change streams for change, expected_changes in zip(changes, test["result"]["success"]): - assert_dict_is_subset(change, expected_changes) + self.assert_dict_is_subset(change, expected_changes) self.assertEqual(len(changes), len(test["result"]["success"])) finally: # Check for expected events results = self.listener.results - for expectation in test.get("expectations", []): - for idx, (event_type, event_desc) in enumerate(iteritems(expectation)): + for idx, expectation in enumerate(test.get("expectations", [])): + for event_type, event_desc in iteritems(expectation): results_key = event_type.split("_")[1] event = results[results_key][idx] if len(results[results_key]) > idx else None - check_event(event, event_desc) + self.check_event(event, event_desc) return run_scenario diff --git a/test/test_database.py b/test/test_database.py index 43fac30b0..18eb322ea 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -954,10 +954,10 @@ class TestDatabase(IntegrationTest): # command document will have no 'ok' field. We should raise # OperationFailure instead of KeyError. self.assertRaises(OperationFailure, - helpers._check_command_response, {}) + helpers._check_command_response, {}, None) try: - helpers._check_command_response({'$err': 'foo'}) + helpers._check_command_response({'$err': 'foo'}, None) except OperationFailure as e: self.assertEqual(e.args[0], 'foo') else: @@ -970,7 +970,7 @@ class TestDatabase(IntegrationTest): 'raw': {'shard0/host0,host1': {'ok': 0, 'errmsg': 'inner'}}} with self.assertRaises(OperationFailure) as context: - helpers._check_command_response(error_document) + helpers._check_command_response(error_document, None) self.assertIn('inner', str(context.exception)) @@ -983,7 +983,7 @@ class TestDatabase(IntegrationTest): 'raw': {'shard0/host0,host1': {}}} with self.assertRaises(OperationFailure) as context: - helpers._check_command_response(error_document) + helpers._check_command_response(error_document, None) self.assertIn('outer', str(context.exception)) @@ -994,7 +994,7 @@ class TestDatabase(IntegrationTest): 'raw': {'shard0/host0,host1': {'ok': 0}}} with self.assertRaises(OperationFailure) as context: - helpers._check_command_response(error_document) + helpers._check_command_response(error_document, None) self.assertIn('outer', str(context.exception)) diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 5f4f6c076..ef97bcc67 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -110,7 +110,7 @@ def got_app_error(topology, app_error): # Pool/SocketInfo. try: if error_type == 'command': - _check_command_response(app_error['response']) + _check_command_response(app_error['response'], max_wire_version) elif error_type == 'network': raise AutoReconnect('mock non-timeout network error') elif error_type == 'timeout': @@ -334,7 +334,8 @@ class TestIntegration(SpecRunner): Assert the given event was published exactly `count` times. """ - self.assertEqual(self._event_count(event), count) + self.assertEqual(self._event_count(event), count, + 'expected %s not %r' % (count, event)) def wait_for_event(self, event, count): """Run the waitForEvent test operation.