diff --git a/tests/test_password_hasher.py b/tests/test_password_hasher.py index 5ad02d1..90adc1a 100644 --- a/tests/test_password_hasher.py +++ b/tests/test_password_hasher.py @@ -1,5 +1,10 @@ # SPDX-License-Identifier: MIT +import secrets +import sys +import threading + +from concurrent.futures import ThreadPoolExecutor from unittest import mock import pytest @@ -194,3 +199,38 @@ class TestPasswordHasher: hash = ph.hash("hello") assert ph.verify(hash, "hello") is True + + +def test_multithreaded_hashing(): + """ + Hash passwords in a thread pool and check for thread safety + """ + hasher = PasswordHasher(parallelism=2) + + num_passwords = 100 + + passwords = [secrets.token_urlsafe(15) for _ in range(num_passwords)] + + def closure(b, passwords): + b.wait() + for password in passwords: + assert hasher.verify(hasher.hash(password), password) + + max_workers = 4 + + chunks = [passwords[i::max_workers] for i in range(max_workers)] + orig_interval = sys.getswitchinterval() + + with ThreadPoolExecutor(max_workers=max_workers) as tpe: + barrier = threading.Barrier(max_workers) + futures = [] + try: + sys.setswitchinterval(0.00001) + for chunk in chunks: + futures.append(tpe.submit(closure, barrier, chunk)) # noqa: PERF401 + finally: + sys.setswitchinterval(orig_interval) + if len(futures) < max_workers: + barrier.abort() + for f in futures: + f.result()