diff --git a/pymongo/auth.py b/pymongo/auth.py index 365d13238..edc01ebf8 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -58,9 +58,21 @@ MECHANISMS = frozenset( """The authentication mechanisms supported by PyMongo.""" +class _Cache(object): + __slots__ = ("data",) + + def __init__(self): + self.data = None + + MongoCredential = namedtuple( 'MongoCredential', - ['mechanism', 'source', 'username', 'password', 'mechanism_properties']) + ['mechanism', + 'source', + 'username', + 'password', + 'mechanism_properties', + 'cache']) """A hashable namedtuple of values used for authentication.""" @@ -88,7 +100,7 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): canonicalize_host_name=canonicalize, service_realm=service_realm) # Source is always $external. - return MongoCredential(mech, '$external', user, passwd, props) + return MongoCredential(mech, '$external', user, passwd, props, None) elif mech == 'MONGODB-X509': if passwd is not None: raise ConfigurationError( @@ -98,15 +110,16 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): "authentication source must be " "$external or None for MONGODB-X509") # user can be None. - return MongoCredential(mech, '$external', user, None, None) + return MongoCredential(mech, '$external', user, None, None, None) elif mech == 'PLAIN': source_database = source or database or '$external' - return MongoCredential(mech, source_database, user, passwd, None) + return MongoCredential(mech, source_database, user, passwd, None, None) else: source_database = source or database or 'admin' if passwd is None: raise ConfigurationError("A password is required.") - return MongoCredential(mech, source_database, user, passwd, None) + return MongoCredential( + mech, source_database, user, passwd, None, _Cache()) if PY3: @@ -215,6 +228,7 @@ def _authenticate_scram(credentials, sock_info, mechanism): digestmod = hashlib.sha1 data = _password_digest(username, credentials.password).encode("utf-8") source = credentials.source + cache = credentials.cache # Make local _hmac = hmac.HMAC @@ -241,16 +255,21 @@ def _authenticate_scram(credentials, sock_info, mechanism): raise OperationFailure("Server returned an invalid nonce.") without_proof = b"c=biws,r=" + rnonce - salted_pass = _hi( - digest, data, standard_b64decode(salt), iterations) - client_key = _hmac(salted_pass, b"Client Key", digestmod).digest() + keys = cache.data + if keys: + client_key, server_key = keys + else: + salted_pass = _hi( + digest, data, standard_b64decode(salt), iterations) + client_key = _hmac(salted_pass, b"Client Key", digestmod).digest() + server_key = _hmac(salted_pass, b"Server Key", digestmod).digest() + cache.data = (client_key, server_key) stored_key = digestmod(client_key).digest() auth_msg = b",".join((first_bare, server_first, without_proof)) client_sig = _hmac(stored_key, auth_msg, digestmod).digest() client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig)) client_final = b",".join((without_proof, client_proof)) - server_key = _hmac(salted_pass, b"Server Key", digestmod).digest() server_sig = standard_b64encode( _hmac(server_key, auth_msg, digestmod).digest()) diff --git a/test/test_auth.py b/test/test_auth.py index 15a68b8ef..a9c69f371 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -34,6 +34,8 @@ from pymongo.saslprep import HAVE_STRINGPREP from test import client_context, SkipTest, unittest, Version from test.utils import (delay, ignore_deprecations, + single_client, + rs_or_single_client, rs_or_single_client_noauth, single_client_noauth, WhiteListEventListener) @@ -382,6 +384,7 @@ class TestSCRAM(unittest.TestCase): client_context.client.testscram.command("dropAllUsersFromDatabase") client_context.client.drop_database("testscram") + @ignore_deprecations def test_scram(self): host, port = client_context.host, client_context.port @@ -548,6 +551,50 @@ class TestSCRAM(unittest.TestCase): 'testscram', read_preference=ReadPreference.SECONDARY) db.command('dbstats') + def test_cache(self): + client = single_client() + # Force authentication. + client.admin.command('ismaster') + all_credentials = client._MongoClient__all_credentials + credentials = all_credentials.get('admin') + cache = credentials.cache + self.assertIsNotNone(cache) + keys = cache.data + self.assertIsNotNone(keys) + self.assertEqual(len(keys), 2) + for elt in keys: + self.assertIsInstance(elt, bytes) + + pool = next(iter(client._topology._servers.values()))._pool + with pool.get_socket(all_credentials) as sock_info: + authset = sock_info.authset + cached = set(all_credentials.values()) + self.assertEqual(len(cached), 1) + self.assertFalse(authset - cached) + self.assertFalse(cached - authset) + + sock_credentials = next(iter(authset)) + sock_cache = sock_credentials.cache + self.assertIsNotNone(sock_cache) + self.assertEqual(sock_cache.data, keys) + + def test_scram_threaded(self): + + coll = client_context.client.db.test + coll.drop() + coll.insert_one({'_id': 1}) + + # The first thread to call find() will authenticate + coll = rs_or_single_client().db.test + threads = [] + for _ in range(4): + threads.append(AutoAuthenticateThread(coll)) + for thread in threads: + thread.start() + for thread in threads: + thread.join() + self.assertTrue(thread.success) + class TestAuthURIOptions(unittest.TestCase):