PYTHON-2029 Support shorter SCRAM conversation
This commit is contained in:
parent
5a1cbd8f20
commit
d481363fd5
@ -259,7 +259,8 @@ def _authenticate_scram(credentials, sock_info, mechanism):
|
||||
cmd = SON([('saslStart', 1),
|
||||
('mechanism', mechanism),
|
||||
('payload', Binary(b"n,," + first_bare)),
|
||||
('autoAuthorize', 1)])
|
||||
('autoAuthorize', 1),
|
||||
('options', {'skipEmptyExchange': True})])
|
||||
res = sock_info.command(source, cmd)
|
||||
|
||||
server_first = res['payload']
|
||||
@ -304,8 +305,8 @@ def _authenticate_scram(credentials, sock_info, mechanism):
|
||||
if not compare_digest(parsed[b'v'], server_sig):
|
||||
raise OperationFailure("Server returned an invalid signature.")
|
||||
|
||||
# Depending on how it's configured, Cyrus SASL (which the server uses)
|
||||
# requires a third empty challenge.
|
||||
# A third empty challenge may be required if the server does not support
|
||||
# skipEmptyExchange: SERVER-44857.
|
||||
if not res['done']:
|
||||
cmd = SON([('saslContinue', 1),
|
||||
('conversationId', res['conversationId']),
|
||||
|
||||
@ -409,6 +409,31 @@ class TestSCRAM(unittest.TestCase):
|
||||
client_context.client.testscram.command("dropAllUsersFromDatabase")
|
||||
client_context.client.drop_database("testscram")
|
||||
|
||||
def test_scram_skip_empty_exchange(self):
|
||||
listener = WhiteListEventListener("saslStart", "saslContinue")
|
||||
client_context.create_user(
|
||||
'testscram', 'sha256', 'pwd', roles=['dbOwner'],
|
||||
mechanisms=['SCRAM-SHA-256'])
|
||||
|
||||
client = rs_or_single_client_noauth(
|
||||
username='sha256', password='pwd', authSource='testscram',
|
||||
event_listeners=[listener])
|
||||
client.admin.command('isMaster')
|
||||
|
||||
# Assert we sent the skipEmptyExchange option.
|
||||
first_event = listener.results['started'][0]
|
||||
self.assertEqual(first_event.command_name, 'saslStart')
|
||||
self.assertEqual(
|
||||
first_event.command['options'], {'skipEmptyExchange': True})
|
||||
|
||||
# Assert the third exchange was skipped on servers that support it.
|
||||
started = listener.started_command_names()
|
||||
if client_context.version.at_least(4, 3, 3):
|
||||
self.assertEqual(started, ['saslStart', 'saslContinue'])
|
||||
else:
|
||||
self.assertEqual(
|
||||
started, ['saslStart', 'saslContinue', 'saslContinue'])
|
||||
|
||||
@ignore_deprecations
|
||||
def test_scram(self):
|
||||
host, port = client_context.host, client_context.port
|
||||
|
||||
@ -49,25 +49,6 @@ from test import (client_context,
|
||||
IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50)
|
||||
|
||||
|
||||
class WhiteListEventListener(monitoring.CommandListener):
|
||||
|
||||
def __init__(self, *commands):
|
||||
self.commands = set(commands)
|
||||
self.results = defaultdict(list)
|
||||
|
||||
def started(self, event):
|
||||
if event.command_name in self.commands:
|
||||
self.results['started'].append(event)
|
||||
|
||||
def succeeded(self, event):
|
||||
if event.command_name in self.commands:
|
||||
self.results['succeeded'].append(event)
|
||||
|
||||
def failed(self, event):
|
||||
if event.command_name in self.commands:
|
||||
self.results['failed'].append(event)
|
||||
|
||||
|
||||
class CMAPListener(ConnectionPoolListener):
|
||||
def __init__(self):
|
||||
self.events = []
|
||||
@ -136,6 +117,25 @@ class EventListener(monitoring.CommandListener):
|
||||
self.results.clear()
|
||||
|
||||
|
||||
class WhiteListEventListener(EventListener):
|
||||
|
||||
def __init__(self, *commands):
|
||||
self.commands = set(commands)
|
||||
super(WhiteListEventListener, self).__init__()
|
||||
|
||||
def started(self, event):
|
||||
if event.command_name in self.commands:
|
||||
super(WhiteListEventListener, self).started(event)
|
||||
|
||||
def succeeded(self, event):
|
||||
if event.command_name in self.commands:
|
||||
super(WhiteListEventListener, self).succeeded(event)
|
||||
|
||||
def failed(self, event):
|
||||
if event.command_name in self.commands:
|
||||
super(WhiteListEventListener, self).failed(event)
|
||||
|
||||
|
||||
class OvertCommandListener(EventListener):
|
||||
"""A CommandListener that ignores sensitive commands."""
|
||||
def started(self, event):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user