PYTHON-2029 Support shorter SCRAM conversation

This commit is contained in:
Shane Harvey 2020-02-10 15:30:07 -08:00
parent 5a1cbd8f20
commit d481363fd5
3 changed files with 48 additions and 22 deletions

View File

@ -259,7 +259,8 @@ def _authenticate_scram(credentials, sock_info, mechanism):
cmd = SON([('saslStart', 1),
('mechanism', mechanism),
('payload', Binary(b"n,," + first_bare)),
('autoAuthorize', 1)])
('autoAuthorize', 1),
('options', {'skipEmptyExchange': True})])
res = sock_info.command(source, cmd)
server_first = res['payload']
@ -304,8 +305,8 @@ def _authenticate_scram(credentials, sock_info, mechanism):
if not compare_digest(parsed[b'v'], server_sig):
raise OperationFailure("Server returned an invalid signature.")
# Depending on how it's configured, Cyrus SASL (which the server uses)
# requires a third empty challenge.
# A third empty challenge may be required if the server does not support
# skipEmptyExchange: SERVER-44857.
if not res['done']:
cmd = SON([('saslContinue', 1),
('conversationId', res['conversationId']),

View File

@ -409,6 +409,31 @@ class TestSCRAM(unittest.TestCase):
client_context.client.testscram.command("dropAllUsersFromDatabase")
client_context.client.drop_database("testscram")
def test_scram_skip_empty_exchange(self):
listener = WhiteListEventListener("saslStart", "saslContinue")
client_context.create_user(
'testscram', 'sha256', 'pwd', roles=['dbOwner'],
mechanisms=['SCRAM-SHA-256'])
client = rs_or_single_client_noauth(
username='sha256', password='pwd', authSource='testscram',
event_listeners=[listener])
client.admin.command('isMaster')
# Assert we sent the skipEmptyExchange option.
first_event = listener.results['started'][0]
self.assertEqual(first_event.command_name, 'saslStart')
self.assertEqual(
first_event.command['options'], {'skipEmptyExchange': True})
# Assert the third exchange was skipped on servers that support it.
started = listener.started_command_names()
if client_context.version.at_least(4, 3, 3):
self.assertEqual(started, ['saslStart', 'saslContinue'])
else:
self.assertEqual(
started, ['saslStart', 'saslContinue', 'saslContinue'])
@ignore_deprecations
def test_scram(self):
host, port = client_context.host, client_context.port

View File

@ -49,25 +49,6 @@ from test import (client_context,
IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50)
class WhiteListEventListener(monitoring.CommandListener):
def __init__(self, *commands):
self.commands = set(commands)
self.results = defaultdict(list)
def started(self, event):
if event.command_name in self.commands:
self.results['started'].append(event)
def succeeded(self, event):
if event.command_name in self.commands:
self.results['succeeded'].append(event)
def failed(self, event):
if event.command_name in self.commands:
self.results['failed'].append(event)
class CMAPListener(ConnectionPoolListener):
def __init__(self):
self.events = []
@ -136,6 +117,25 @@ class EventListener(monitoring.CommandListener):
self.results.clear()
class WhiteListEventListener(EventListener):
def __init__(self, *commands):
self.commands = set(commands)
super(WhiteListEventListener, self).__init__()
def started(self, event):
if event.command_name in self.commands:
super(WhiteListEventListener, self).started(event)
def succeeded(self, event):
if event.command_name in self.commands:
super(WhiteListEventListener, self).succeeded(event)
def failed(self, event):
if event.command_name in self.commands:
super(WhiteListEventListener, self).failed(event)
class OvertCommandListener(EventListener):
"""A CommandListener that ignores sensitive commands."""
def started(self, event):