From a9f08c573aaad824274387d2057c78c08cb1e05f Mon Sep 17 00:00:00 2001 From: Bernie Hackett Date: Sat, 14 Jul 2018 12:14:56 -0700 Subject: [PATCH] PYTHON-1613 Invalidate cache on changed salt or iterations --- pymongo/auth.py | 12 ++++++++---- test/test_auth.py | 15 +++++++++------ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index e696a50d5..de5a6574c 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -272,15 +272,19 @@ def _authenticate_scram(credentials, sock_info, mechanism): raise OperationFailure("Server returned an invalid nonce.") without_proof = b"c=biws,r=" + rnonce - keys = cache.data - if keys: - client_key, server_key = keys + if cache.data: + client_key, server_key, csalt, citerations = cache.data else: + client_key, server_key, csalt, citerations = None, None, None, None + + # Salt and / or iterations could change for a number of different + # reasons. Either changing invalidates the cache. + if not client_key or salt != csalt or iterations != citerations: salted_pass = _hi( digest, data, standard_b64decode(salt), iterations) client_key = _hmac(salted_pass, b"Client Key", digestmod).digest() server_key = _hmac(salted_pass, b"Server Key", digestmod).digest() - cache.data = (client_key, server_key) + cache.data = (client_key, server_key, salt, iterations) stored_key = digestmod(client_key).digest() auth_msg = b",".join((first_bare, server_first, without_proof)) client_sig = _hmac(stored_key, auth_msg, digestmod).digest() diff --git a/test/test_auth.py b/test/test_auth.py index 092cc71b7..8e41e100f 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -584,11 +584,14 @@ class TestSCRAM(unittest.TestCase): credentials = all_credentials.get('admin') cache = credentials.cache self.assertIsNotNone(cache) - keys = cache.data - self.assertIsNotNone(keys) - self.assertEqual(len(keys), 2) - for elt in keys: - self.assertIsInstance(elt, bytes) + data = cache.data + self.assertIsNotNone(data) + self.assertEqual(len(data), 4) + ckey, skey, salt, iterations = data + self.assertIsInstance(ckey, bytes) + self.assertIsInstance(skey, bytes) + self.assertIsInstance(salt, bytes) + self.assertIsInstance(iterations, int) pool = next(iter(client._topology._servers.values()))._pool with pool.get_socket(all_credentials) as sock_info: @@ -601,7 +604,7 @@ class TestSCRAM(unittest.TestCase): sock_credentials = next(iter(authset)) sock_cache = sock_credentials.cache self.assertIsNotNone(sock_cache) - self.assertEqual(sock_cache.data, keys) + self.assertEqual(sock_cache.data, data) def test_scram_threaded(self):