PYTHON-1613 Invalidate cache on changed salt or iterations

This commit is contained in:
Bernie Hackett 2018-07-14 12:14:56 -07:00
parent 47b0d8ebfd
commit a22719853e
2 changed files with 17 additions and 10 deletions

View File

@ -272,15 +272,19 @@ def _authenticate_scram(credentials, sock_info, mechanism):
raise OperationFailure("Server returned an invalid nonce.")
without_proof = b"c=biws,r=" + rnonce
keys = cache.data
if keys:
client_key, server_key = keys
if cache.data:
client_key, server_key, csalt, citerations = cache.data
else:
client_key, server_key, csalt, citerations = None, None, None, None
# Salt and / or iterations could change for a number of different
# reasons. Either changing invalidates the cache.
if not client_key or salt != csalt or iterations != citerations:
salted_pass = _hi(
digest, data, standard_b64decode(salt), iterations)
client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
cache.data = (client_key, server_key)
cache.data = (client_key, server_key, salt, iterations)
stored_key = digestmod(client_key).digest()
auth_msg = b",".join((first_bare, server_first, without_proof))
client_sig = _hmac(stored_key, auth_msg, digestmod).digest()

View File

@ -584,11 +584,14 @@ class TestSCRAM(unittest.TestCase):
credentials = all_credentials.get('admin')
cache = credentials.cache
self.assertIsNotNone(cache)
keys = cache.data
self.assertIsNotNone(keys)
self.assertEqual(len(keys), 2)
for elt in keys:
self.assertIsInstance(elt, bytes)
data = cache.data
self.assertIsNotNone(data)
self.assertEqual(len(data), 4)
ckey, skey, salt, iterations = data
self.assertIsInstance(ckey, bytes)
self.assertIsInstance(skey, bytes)
self.assertIsInstance(salt, bytes)
self.assertIsInstance(iterations, int)
pool = next(iter(client._topology._servers.values()))._pool
with pool.get_socket(all_credentials) as sock_info:
@ -601,7 +604,7 @@ class TestSCRAM(unittest.TestCase):
sock_credentials = next(iter(authset))
sock_cache = sock_credentials.cache
self.assertIsNotNone(sock_cache)
self.assertEqual(sock_cache.data, keys)
self.assertEqual(sock_cache.data, data)
def test_scram_threaded(self):