diff --git a/pymongo/auth.py b/pymongo/auth.py index fef4386f1..455717a6a 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -259,7 +259,8 @@ def _authenticate_scram(credentials, sock_info, mechanism): cmd = SON([('saslStart', 1), ('mechanism', mechanism), ('payload', Binary(b"n,," + first_bare)), - ('autoAuthorize', 1)]) + ('autoAuthorize', 1), + ('options', {'skipEmptyExchange': True})]) res = sock_info.command(source, cmd) server_first = res['payload'] @@ -304,8 +305,8 @@ def _authenticate_scram(credentials, sock_info, mechanism): if not compare_digest(parsed[b'v'], server_sig): raise OperationFailure("Server returned an invalid signature.") - # Depending on how it's configured, Cyrus SASL (which the server uses) - # requires a third empty challenge. + # A third empty challenge may be required if the server does not support + # skipEmptyExchange: SERVER-44857. if not res['done']: cmd = SON([('saslContinue', 1), ('conversationId', res['conversationId']), diff --git a/test/test_auth.py b/test/test_auth.py index 8e41e100f..14b2f9439 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -409,6 +409,31 @@ class TestSCRAM(unittest.TestCase): client_context.client.testscram.command("dropAllUsersFromDatabase") client_context.client.drop_database("testscram") + def test_scram_skip_empty_exchange(self): + listener = WhiteListEventListener("saslStart", "saslContinue") + client_context.create_user( + 'testscram', 'sha256', 'pwd', roles=['dbOwner'], + mechanisms=['SCRAM-SHA-256']) + + client = rs_or_single_client_noauth( + username='sha256', password='pwd', authSource='testscram', + event_listeners=[listener]) + client.admin.command('isMaster') + + # Assert we sent the skipEmptyExchange option. + first_event = listener.results['started'][0] + self.assertEqual(first_event.command_name, 'saslStart') + self.assertEqual( + first_event.command['options'], {'skipEmptyExchange': True}) + + # Assert the third exchange was skipped on servers that support it. + started = listener.started_command_names() + if client_context.version.at_least(4, 3, 3): + self.assertEqual(started, ['saslStart', 'saslContinue']) + else: + self.assertEqual( + started, ['saslStart', 'saslContinue', 'saslContinue']) + @ignore_deprecations def test_scram(self): host, port = client_context.host, client_context.port diff --git a/test/utils.py b/test/utils.py index 768a77996..95c088585 100644 --- a/test/utils.py +++ b/test/utils.py @@ -49,25 +49,6 @@ from test import (client_context, IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50) -class WhiteListEventListener(monitoring.CommandListener): - - def __init__(self, *commands): - self.commands = set(commands) - self.results = defaultdict(list) - - def started(self, event): - if event.command_name in self.commands: - self.results['started'].append(event) - - def succeeded(self, event): - if event.command_name in self.commands: - self.results['succeeded'].append(event) - - def failed(self, event): - if event.command_name in self.commands: - self.results['failed'].append(event) - - class CMAPListener(ConnectionPoolListener): def __init__(self): self.events = [] @@ -136,6 +117,25 @@ class EventListener(monitoring.CommandListener): self.results.clear() +class WhiteListEventListener(EventListener): + + def __init__(self, *commands): + self.commands = set(commands) + super(WhiteListEventListener, self).__init__() + + def started(self, event): + if event.command_name in self.commands: + super(WhiteListEventListener, self).started(event) + + def succeeded(self, event): + if event.command_name in self.commands: + super(WhiteListEventListener, self).succeeded(event) + + def failed(self, event): + if event.command_name in self.commands: + super(WhiteListEventListener, self).failed(event) + + class OvertCommandListener(EventListener): """A CommandListener that ignores sensitive commands.""" def started(self, event):