PYTHON-1273 - Cache SCRAM ClientKey and ServerKey

This commit is contained in:
Bernie Hackett 2018-06-08 14:24:05 -07:00
parent fc4e8558d6
commit bb8130abd8
2 changed files with 75 additions and 9 deletions

View File

@ -58,9 +58,21 @@ MECHANISMS = frozenset(
"""The authentication mechanisms supported by PyMongo."""
class _Cache(object):
__slots__ = ("data",)
def __init__(self):
self.data = None
MongoCredential = namedtuple(
'MongoCredential',
['mechanism', 'source', 'username', 'password', 'mechanism_properties'])
['mechanism',
'source',
'username',
'password',
'mechanism_properties',
'cache'])
"""A hashable namedtuple of values used for authentication."""
@ -88,7 +100,7 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database):
canonicalize_host_name=canonicalize,
service_realm=service_realm)
# Source is always $external.
return MongoCredential(mech, '$external', user, passwd, props)
return MongoCredential(mech, '$external', user, passwd, props, None)
elif mech == 'MONGODB-X509':
if passwd is not None:
raise ConfigurationError(
@ -98,15 +110,16 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database):
"authentication source must be "
"$external or None for MONGODB-X509")
# user can be None.
return MongoCredential(mech, '$external', user, None, None)
return MongoCredential(mech, '$external', user, None, None, None)
elif mech == 'PLAIN':
source_database = source or database or '$external'
return MongoCredential(mech, source_database, user, passwd, None)
return MongoCredential(mech, source_database, user, passwd, None, None)
else:
source_database = source or database or 'admin'
if passwd is None:
raise ConfigurationError("A password is required.")
return MongoCredential(mech, source_database, user, passwd, None)
return MongoCredential(
mech, source_database, user, passwd, None, _Cache())
if PY3:
@ -215,6 +228,7 @@ def _authenticate_scram(credentials, sock_info, mechanism):
digestmod = hashlib.sha1
data = _password_digest(username, credentials.password).encode("utf-8")
source = credentials.source
cache = credentials.cache
# Make local
_hmac = hmac.HMAC
@ -241,16 +255,21 @@ def _authenticate_scram(credentials, sock_info, mechanism):
raise OperationFailure("Server returned an invalid nonce.")
without_proof = b"c=biws,r=" + rnonce
salted_pass = _hi(
digest, data, standard_b64decode(salt), iterations)
client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
keys = cache.data
if keys:
client_key, server_key = keys
else:
salted_pass = _hi(
digest, data, standard_b64decode(salt), iterations)
client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
cache.data = (client_key, server_key)
stored_key = digestmod(client_key).digest()
auth_msg = b",".join((first_bare, server_first, without_proof))
client_sig = _hmac(stored_key, auth_msg, digestmod).digest()
client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig))
client_final = b",".join((without_proof, client_proof))
server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
server_sig = standard_b64encode(
_hmac(server_key, auth_msg, digestmod).digest())

View File

@ -34,6 +34,8 @@ from pymongo.saslprep import HAVE_STRINGPREP
from test import client_context, SkipTest, unittest, Version
from test.utils import (delay,
ignore_deprecations,
single_client,
rs_or_single_client,
rs_or_single_client_noauth,
single_client_noauth,
WhiteListEventListener)
@ -382,6 +384,7 @@ class TestSCRAM(unittest.TestCase):
client_context.client.testscram.command("dropAllUsersFromDatabase")
client_context.client.drop_database("testscram")
@ignore_deprecations
def test_scram(self):
host, port = client_context.host, client_context.port
@ -548,6 +551,50 @@ class TestSCRAM(unittest.TestCase):
'testscram', read_preference=ReadPreference.SECONDARY)
db.command('dbstats')
def test_cache(self):
client = single_client()
# Force authentication.
client.admin.command('ismaster')
all_credentials = client._MongoClient__all_credentials
credentials = all_credentials.get('admin')
cache = credentials.cache
self.assertIsNotNone(cache)
keys = cache.data
self.assertIsNotNone(keys)
self.assertEqual(len(keys), 2)
for elt in keys:
self.assertIsInstance(elt, bytes)
pool = next(iter(client._topology._servers.values()))._pool
with pool.get_socket(all_credentials) as sock_info:
authset = sock_info.authset
cached = set(all_credentials.values())
self.assertEqual(len(cached), 1)
self.assertFalse(authset - cached)
self.assertFalse(cached - authset)
sock_credentials = next(iter(authset))
sock_cache = sock_credentials.cache
self.assertIsNotNone(sock_cache)
self.assertEqual(sock_cache.data, keys)
def test_scram_threaded(self):
coll = client_context.client.db.test
coll.drop()
coll.insert_one({'_id': 1})
# The first thread to call find() will authenticate
coll = rs_or_single_client().db.test
threads = []
for _ in range(4):
threads.append(AutoAuthenticateThread(coll))
for thread in threads:
thread.start()
for thread in threads:
thread.join()
self.assertTrue(thread.success)
class TestAuthURIOptions(unittest.TestCase):