PYTHON-2143 Use an allow-list to determine resumable change stream errors (#445)

This commit is contained in:
Prashant Mital 2020-07-01 18:12:05 -07:00 committed by GitHub
commit 26913ea8e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 181 additions and 111 deletions

View File

@ -25,6 +25,7 @@ from pymongo.aggregation import (_CollectionAggregationCommand,
from pymongo.collation import validate_collation_or_none
from pymongo.command_cursor import CommandCursor
from pymongo.errors import (ConnectionFailure,
CursorNotFound,
InvalidOperation,
OperationFailure,
PyMongoError)
@ -32,11 +33,25 @@ from pymongo.errors import (ConnectionFailure,
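# The change streams spec considers the following server error codes from
# the getMore command resumable when the server's max wire version is below
# 9. Newer servers are checked for the "ResumableChangeStreamError" error
# label instead. Any other OperationFailure code is not resumed.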
_NON_RESUMABLE_GETMORE_ERRORS = frozenset([
11601, # Interrupted
136, # CappedPositionLost
237, # CursorKilled
None, # No error code was returned.
_RESUMABLE_GETMORE_ERRORS = frozenset([
6, # HostUnreachable
7, # HostNotFound
89, # NetworkTimeout
91, # ShutdownInProgress
189, # PrimarySteppedDown
262, # ExceededTimeLimit
9001, # SocketException
10107, # NotMaster
11600, # InterruptedAtShutdown
11602, # InterruptedDueToReplStateChange
13435, # NotMasterNoSlaveOk
13436, # NotMasterOrSecondary
63, # StaleShardVersion
150, # StaleEpoch
13388, # StaleConfig
234, # RetryChangeStream
133, # FailedToSatisfyReadPreference
216, # ElectionInProgress
])
@ -283,12 +298,17 @@ class ChangeStream(object):
# one resume attempt.
try:
change = self._cursor._try_next(True)
except ConnectionFailure:
except (ConnectionFailure, CursorNotFound):
self._resume()
change = self._cursor._try_next(False)
except OperationFailure as exc:
if (exc.code in _NON_RESUMABLE_GETMORE_ERRORS or
exc.has_error_label("NonResumableChangeStreamError")):
if exc._max_wire_version is None:
raise
is_resumable = ((exc._max_wire_version >= 9 and
exc.has_error_label("ResumableChangeStreamError")) or
(exc._max_wire_version < 9 and
exc.code in _RESUMABLE_GETMORE_ERRORS))
if not is_resumable:
raise
self._resume()
change = self._cursor._try_next(False)

View File

@ -144,7 +144,7 @@ class OperationFailure(PyMongoError):
The :attr:`details` attribute.
"""
def __init__(self, error, code=None, details=None):
def __init__(self, error, code=None, details=None, max_wire_version=None):
error_labels = None
if details is not None:
error_labels = details.get('errorLabels')
@ -152,6 +152,11 @@ class OperationFailure(PyMongoError):
error, error_labels=error_labels)
self.__code = code
self.__details = details
self.__max_wire_version = max_wire_version
# Private read-only accessor for the ``max_wire_version`` value that was
# passed to ``__init__``; it is None unless the caller supplied one.
@property
def _max_wire_version(self):
return self.__max_wire_version
@property
def code(self):
@ -177,6 +182,7 @@ class OperationFailure(PyMongoError):
return output_str.encode('utf-8', errors='replace')
return output_str
class CursorNotFound(OperationFailure):
"""Raised while iterating query results if the cursor is
invalidated on the server.

View File

@ -102,7 +102,8 @@ def _index_document(index_list):
return index
def _check_command_response(response, msg=None, allowable_errors=None,
def _check_command_response(response, max_wire_version, msg=None,
allowable_errors=None,
parse_write_concern_error=False):
"""Check the response to a command for errors.
"""
@ -110,7 +111,8 @@ def _check_command_response(response, msg=None, allowable_errors=None,
# Server didn't recognize our message as a command.
raise OperationFailure(response.get("$err"),
response.get("code"),
response)
response,
max_wire_version)
if parse_write_concern_error and 'writeConcernError' in response:
_raise_write_concern_error(response['writeConcernError'])
@ -146,25 +148,30 @@ def _check_command_response(response, msg=None, allowable_errors=None,
details.get("assertion", ""))
raise OperationFailure(errmsg,
details.get("assertionCode"),
response)
response,
max_wire_version)
# Other errors
# findAndModify with upsert can raise duplicate key error
if code in (11000, 11001, 12582):
raise DuplicateKeyError(errmsg, code, response)
raise DuplicateKeyError(errmsg, code, response,
max_wire_version)
elif code == 50:
raise ExecutionTimeout(errmsg, code, response)
raise ExecutionTimeout(errmsg, code, response,
max_wire_version)
elif code == 43:
raise CursorNotFound(errmsg, code, response)
raise CursorNotFound(errmsg, code, response,
max_wire_version)
msg = msg or "%s"
raise OperationFailure(msg % errmsg, code, response)
raise OperationFailure(msg % errmsg, code, response,
max_wire_version)
def _check_gle_response(result):
def _check_gle_response(result, max_wire_version):
"""Return getlasterror response as a dict, or raise OperationFailure."""
# Did getlasterror itself fail?
_check_command_response(result)
_check_command_response(result, max_wire_version)
if result.get("wtimeout", False):
# MongoDB versions before 1.8.0 return the error message in an "errmsg"

View File

@ -157,7 +157,8 @@ def command(sock_info, dbname, spec, slave_ok, is_mongos,
client._process_response(response_doc, session)
if check:
helpers._check_command_response(
response_doc, None, allowable_errors,
response_doc, sock_info.max_wire_version, None,
allowable_errors,
parse_write_concern_error=parse_write_concern_error)
except Exception as exc:
if publish:

View File

@ -606,7 +606,7 @@ class SocketInfo(object):
self.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response()
response_doc = unpacked_docs[0]
helpers._check_command_response(response_doc)
helpers._check_command_response(response_doc, self.max_wire_version)
return response_doc
def command(self, dbname, spec, slave_ok=False,
@ -751,7 +751,8 @@ class SocketInfo(object):
self.send_message(msg, max_doc_size)
if with_last_error:
reply = self.receive_message(request_id)
return helpers._check_gle_response(reply.command_response())
return helpers._check_gle_response(reply.command_response(),
self.max_wire_version)
def write_command(self, request_id, msg):
"""Send "insert" etc. command, returning response as a dict.
@ -767,7 +768,7 @@ class SocketInfo(object):
result = reply.command_response()
# Raises NotMasterError or OperationFailure.
helpers._check_command_response(result)
helpers._check_command_response(result, self.max_wire_version)
return result
def check_auth(self, all_credentials):

View File

@ -133,7 +133,8 @@ class Server(object):
first = docs[0]
operation.client._process_response(
first, operation.session)
_check_command_response(first)
_check_command_response(
first, sock_info.max_wire_version)
except Exception as exc:
if publish:
duration = datetime.now() - start

View File

@ -75,7 +75,6 @@
{
"description": "Change Stream should error when _id is projected out",
"minServerVersion": "4.1.11",
"maxServerVersion": "4.3.3",
"target": "collection",
"topology": [
"replicaset",
@ -103,10 +102,54 @@
],
"result": {
"error": {
"code": 280,
"errorLabels": [
"NonResumableChangeStreamError"
]
"code": 280
}
}
},
{
"description": "change stream errors on MaxTimeMSExpired",
"minServerVersion": "4.2",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"getMore"
],
"errorCode": 50,
"closeConnection": false
}
},
"target": "collection",
"topology": [
"replicaset",
"sharded"
],
"changeStreamPipeline": [
{
"$project": {
"_id": 0
}
}
],
"changeStreamOptions": {},
"operations": [
{
"database": "change-stream-tests",
"collection": "test",
"name": "insertOne",
"arguments": {
"document": {
"z": 3
}
}
}
],
"result": {
"error": {
"code": 50
}
}
}

View File

@ -37,7 +37,6 @@ from bson.py3compat import iteritems
from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument
from pymongo import MongoClient
from pymongo.change_stream import _NON_RESUMABLE_GETMORE_ERRORS
from pymongo.command_cursor import CommandCursor
from pymongo.errors import (InvalidOperation, OperationFailure,
ServerSelectionTimeoutError)
@ -555,47 +554,11 @@ class ProseSpecTestsMixin(object):
self.assertEqual(listener.results['started'][0].command_name,
'aggregate')
# Prose test no. 5
def test_does_not_resume_fatal_errors(self):
"""ChangeStream will not attempt to resume fatal server errors."""
if client_context.supports_failCommand_fail_point:
# failCommand does not support returning no errorCode.
TEST_ERROR_CODES = _NON_RESUMABLE_GETMORE_ERRORS - {None}
@contextmanager
def generate_error(change_stream, code):
fail_point = {'mode': {'times': 1}, 'data': {
'errorCode': code, 'failCommands': ['getMore']}}
with self.fail_point(fail_point):
yield
else:
TEST_ERROR_CODES = _NON_RESUMABLE_GETMORE_ERRORS
@contextmanager
def generate_error(change_stream, code):
def mock_try_next(*args, **kwargs):
change_stream._cursor.close()
raise OperationFailure('Mock server error', code=code)
original_cursor = change_stream._cursor
change_stream._cursor._try_next = mock_try_next
try:
yield
finally:
# Un patch the instance.
del original_cursor._try_next
for code in TEST_ERROR_CODES:
with self.change_stream() as change_stream:
self.watched_collection().insert_one({})
with generate_error(change_stream, code):
with self.assertRaises(OperationFailure):
next(change_stream)
with self.assertRaises(StopIteration):
next(change_stream)
# Prose test no. 5 - REMOVED
# Prose test no. 6 - SKIPPED
# readPreference is not configurable using the watch() helpers so we can
# skip this test. Also, PyMongo performs server selection for each
# operation which ensures compliance with this prose test.
# Reason: readPreference is not configurable using the watch() helpers
# so we can skip this test. Also, PyMongo performs server selection for
# each operation which ensures compliance with this prose test.
# Prose test no. 7
def test_initial_empty_batch(self):
@ -1075,7 +1038,7 @@ class TestAllScenarios(unittest.TestCase):
@classmethod
@client_context.require_connection
def setUpClass(cls):
cls.listener = WhiteListEventListener("aggregate")
cls.listener = WhiteListEventListener("aggregate", "getMore")
cls.client = rs_or_single_client(event_listeners=[cls.listener])
@classmethod
@ -1086,14 +1049,66 @@ class TestAllScenarios(unittest.TestCase):
self.listener.results.clear()
def setUpCluster(self, scenario_dict):
assets = [
(scenario_dict["database_name"], scenario_dict["collection_name"]),
(scenario_dict["database2_name"], scenario_dict["collection2_name"]),
]
assets = [(scenario_dict["database_name"],
scenario_dict["collection_name"]),
(scenario_dict.get("database2_name", "db2"),
scenario_dict.get("collection2_name", "coll2"))]
for db, coll in assets:
self.client.drop_database(db)
self.client[db].create_collection(coll)
def setFailPoint(self, scenario_dict):
    """Enable the fail point described by the scenario, if it has one.

    Looks up the optional ``"failPoint"`` entry of ``scenario_dict``. When
    it is present, its fields are merged into a ``configureFailPoint``
    command that is run against the ``admin`` database, and a cleanup is
    registered that runs ``configureFailPoint`` again with ``mode='off'``.
    Scenarios without a ``"failPoint"`` entry are left untouched.
    """
    spec = scenario_dict.get("failPoint")
    if spec is not None:
        # Seed the command document with the 'configureFailPoint' key
        # before merging in the scenario's own fields.
        command = SON([('configureFailPoint', 'failCommand')])
        command.update(spec)
        client_context.client.admin.command(command)
        # The cleanup is registered only after the command above returned.
        self.addCleanup(
            client_context.client.admin.command,
            'configureFailPoint', command['configureFailPoint'],
            mode='off')
def assert_list_contents_are_subset(self, superlist, sublist):
    """Assert that ``sublist`` matches ``superlist`` element by element.

    Both sequences must have the same length. Elements are then compared
    pairwise: a dict element of ``sublist`` only has to be a subset of its
    counterpart (checked with ``assert_dict_is_subset``), a nested
    list/tuple is checked recursively with this method, and any other
    element must compare equal.
    """
    self.assertEqual(len(superlist), len(sublist))
    for actual, expected in zip(superlist, sublist):
        if isinstance(expected, dict):
            self.assert_dict_is_subset(actual, expected)
        elif isinstance(expected, (list, tuple)):
            self.assert_list_contents_are_subset(actual, expected)
        else:
            self.assertEqual(actual, expected)
def assert_dict_is_subset(self, superdict, subdict):
    """Assert that every entry of ``subdict`` is matched by ``superdict``.

    Each key of ``subdict`` must exist in ``superdict``. Dict values are
    compared recursively as subsets and list/tuple values are compared
    element by element via ``assert_list_contents_are_subset``. For any
    other value, the keys "documentKey", "_id" and "getMore" only need to
    be present; all remaining keys must map to an equal value.
    """
    presence_only = ["documentKey", "_id", "getMore"]
    for key, expected in iteritems(subdict):
        if key not in superdict:
            self.fail('Key %s not found in %s' % (key, superdict))
        if isinstance(expected, dict):
            self.assert_dict_is_subset(superdict[key], expected)
        elif isinstance(expected, (list, tuple)):
            self.assert_list_contents_are_subset(superdict[key], expected)
        elif key in presence_only:
            # The value is deliberately not compared for these keys; only
            # their presence in superdict is asserted.
            self.assertIn(key, superdict)
        else:
            self.assertEqual(superdict[key], expected)
def check_event(self, event, expectation_dict):
    """Assert that ``event`` satisfies ``expectation_dict``.

    Fails immediately when ``event`` is None. Otherwise every key of
    ``expectation_dict`` names an attribute of ``event``: a dict
    expectation must be a subset of the attribute's value (checked with
    ``assert_dict_is_subset``), and any other expectation must equal it.
    """
    if event is None:
        self.fail()
    for attr_name, expected in iteritems(expectation_dict):
        actual = getattr(event, attr_name)
        if isinstance(expected, dict):
            self.assert_dict_is_subset(actual, expected)
        else:
            self.assertEqual(actual, expected)
def tearDown(self):
self.listener.results.clear()
@ -1147,36 +1162,11 @@ def run_operation(client, operation):
return cmd(**arguments)
def assert_dict_is_subset(superdict, subdict):
"""Check that subdict is a subset of superdict."""
exempt_fields = ["documentKey", "_id"]
for key, value in iteritems(subdict):
if key not in superdict:
assert False
if isinstance(value, dict):
assert_dict_is_subset(superdict[key], value)
continue
if key in exempt_fields:
superdict[key] = "42"
assert superdict[key] == value
def check_event(event, expectation_dict):
if event is None:
raise AssertionError
for key, value in iteritems(expectation_dict):
if isinstance(value, dict):
assert_dict_is_subset(
getattr(event, key), value
)
else:
assert getattr(event, key) == value
def create_test(scenario_def, test):
def run_scenario(self):
# Set up
self.setUpCluster(scenario_def)
self.setFailPoint(test)
is_error = test["result"].get("error", False)
try:
with get_change_stream(
@ -1202,17 +1192,17 @@ def create_test(scenario_def, test):
else:
# Check for expected output from change streams
for change, expected_changes in zip(changes, test["result"]["success"]):
assert_dict_is_subset(change, expected_changes)
self.assert_dict_is_subset(change, expected_changes)
self.assertEqual(len(changes), len(test["result"]["success"]))
finally:
# Check for expected events
results = self.listener.results
for expectation in test.get("expectations", []):
for idx, (event_type, event_desc) in enumerate(iteritems(expectation)):
for idx, expectation in enumerate(test.get("expectations", [])):
for event_type, event_desc in iteritems(expectation):
results_key = event_type.split("_")[1]
event = results[results_key][idx] if len(results[results_key]) > idx else None
check_event(event, event_desc)
self.check_event(event, event_desc)
return run_scenario

View File

@ -954,10 +954,10 @@ class TestDatabase(IntegrationTest):
# command document will have no 'ok' field. We should raise
# OperationFailure instead of KeyError.
self.assertRaises(OperationFailure,
helpers._check_command_response, {})
helpers._check_command_response, {}, None)
try:
helpers._check_command_response({'$err': 'foo'})
helpers._check_command_response({'$err': 'foo'}, None)
except OperationFailure as e:
self.assertEqual(e.args[0], 'foo')
else:
@ -970,7 +970,7 @@ class TestDatabase(IntegrationTest):
'raw': {'shard0/host0,host1': {'ok': 0, 'errmsg': 'inner'}}}
with self.assertRaises(OperationFailure) as context:
helpers._check_command_response(error_document)
helpers._check_command_response(error_document, None)
self.assertIn('inner', str(context.exception))
@ -983,7 +983,7 @@ class TestDatabase(IntegrationTest):
'raw': {'shard0/host0,host1': {}}}
with self.assertRaises(OperationFailure) as context:
helpers._check_command_response(error_document)
helpers._check_command_response(error_document, None)
self.assertIn('outer', str(context.exception))
@ -994,7 +994,7 @@ class TestDatabase(IntegrationTest):
'raw': {'shard0/host0,host1': {'ok': 0}}}
with self.assertRaises(OperationFailure) as context:
helpers._check_command_response(error_document)
helpers._check_command_response(error_document, None)
self.assertIn('outer', str(context.exception))

View File

@ -110,7 +110,7 @@ def got_app_error(topology, app_error):
# Pool/SocketInfo.
try:
if error_type == 'command':
_check_command_response(app_error['response'])
_check_command_response(app_error['response'], max_wire_version)
elif error_type == 'network':
raise AutoReconnect('mock non-timeout network error')
elif error_type == 'timeout':
@ -334,7 +334,8 @@ class TestIntegration(SpecRunner):
Assert the given event was published exactly `count` times.
"""
self.assertEqual(self._event_count(event), count)
self.assertEqual(self._event_count(event), count,
'expected %s not %r' % (count, event))
def wait_for_event(self, event, count):
"""Run the waitForEvent test operation.