PYTHON-1613 Invalidate cache on changed salt or iterations
This commit is contained in:
parent
afca040b98
commit
a9f08c573a
@ -272,15 +272,19 @@ def _authenticate_scram(credentials, sock_info, mechanism):
|
||||
raise OperationFailure("Server returned an invalid nonce.")
|
||||
|
||||
without_proof = b"c=biws,r=" + rnonce
|
||||
keys = cache.data
|
||||
if keys:
|
||||
client_key, server_key = keys
|
||||
if cache.data:
|
||||
client_key, server_key, csalt, citerations = cache.data
|
||||
else:
|
||||
client_key, server_key, csalt, citerations = None, None, None, None
|
||||
|
||||
# Salt and / or iterations could change for a number of different
|
||||
# reasons. Either changing invalidates the cache.
|
||||
if not client_key or salt != csalt or iterations != citerations:
|
||||
salted_pass = _hi(
|
||||
digest, data, standard_b64decode(salt), iterations)
|
||||
client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
|
||||
server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
|
||||
cache.data = (client_key, server_key)
|
||||
cache.data = (client_key, server_key, salt, iterations)
|
||||
stored_key = digestmod(client_key).digest()
|
||||
auth_msg = b",".join((first_bare, server_first, without_proof))
|
||||
client_sig = _hmac(stored_key, auth_msg, digestmod).digest()
|
||||
|
||||
@ -584,11 +584,14 @@ class TestSCRAM(unittest.TestCase):
|
||||
credentials = all_credentials.get('admin')
|
||||
cache = credentials.cache
|
||||
self.assertIsNotNone(cache)
|
||||
keys = cache.data
|
||||
self.assertIsNotNone(keys)
|
||||
self.assertEqual(len(keys), 2)
|
||||
for elt in keys:
|
||||
self.assertIsInstance(elt, bytes)
|
||||
data = cache.data
|
||||
self.assertIsNotNone(data)
|
||||
self.assertEqual(len(data), 4)
|
||||
ckey, skey, salt, iterations = data
|
||||
self.assertIsInstance(ckey, bytes)
|
||||
self.assertIsInstance(skey, bytes)
|
||||
self.assertIsInstance(salt, bytes)
|
||||
self.assertIsInstance(iterations, int)
|
||||
|
||||
pool = next(iter(client._topology._servers.values()))._pool
|
||||
with pool.get_socket(all_credentials) as sock_info:
|
||||
@ -601,7 +604,7 @@ class TestSCRAM(unittest.TestCase):
|
||||
sock_credentials = next(iter(authset))
|
||||
sock_cache = sock_credentials.cache
|
||||
self.assertIsNotNone(sock_cache)
|
||||
self.assertEqual(sock_cache.data, keys)
|
||||
self.assertEqual(sock_cache.data, data)
|
||||
|
||||
def test_scram_threaded(self):
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user