diff --git a/.evergreen/run-mongodb-aws-ecs-test.sh b/.evergreen/run-mongodb-aws-ecs-test.sh index 83f3975e9..fcadea208 100755 --- a/.evergreen/run-mongodb-aws-ecs-test.sh +++ b/.evergreen/run-mongodb-aws-ecs-test.sh @@ -30,7 +30,7 @@ authtest () { $PYTHON -m pip install --upgrade wheel setuptools pip cd src $PYTHON -m pip install '.[aws]' - $PYTHON test/auth_aws/test_auth_aws.py + $PYTHON test/auth_aws/test_auth_aws.py -v cd - } diff --git a/.evergreen/run-mongodb-aws-test.sh b/.evergreen/run-mongodb-aws-test.sh index 9a33507cc..b2a4fd146 100755 --- a/.evergreen/run-mongodb-aws-test.sh +++ b/.evergreen/run-mongodb-aws-test.sh @@ -61,7 +61,7 @@ authtest () { . venvaws/bin/activate fi python -m pip install '.[aws]' - python test/auth_aws/test_auth_aws.py + python test/auth_aws/test_auth_aws.py -v deactivate rm -rf venvaws } diff --git a/doc/changelog.rst b/doc/changelog.rst index eb9d1233b..c11ac9588 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -26,6 +26,9 @@ PyMongo 4.3 brings a number of improvements including: now allow for new types of events (such as DDL and C2C replication events) to be recorded with the new parameter ``show_expanded_events`` that can be passed to methods such as :meth:`~pymongo.collection.Collection.watch`. +- PyMongo now internally caches AWS credentials that it fetches from AWS + endpoints, to avoid rate limitations. The cache is cleared when the + credentials expire or an error is encountered. Bug fixes ......... diff --git a/pymongo/auth_aws.py b/pymongo/auth_aws.py index 4b2af35ea..e84465ea6 100644 --- a/pymongo/auth_aws.py +++ b/pymongo/auth_aws.py @@ -27,6 +27,17 @@ except ImportError: _HAVE_MONGODB_AWS = False +try: + from pymongo_auth_aws.auth import set_cached_credentials, set_use_cached_credentials + + # Enable credential caching. + set_use_cached_credentials(True) +except ImportError: + + def set_cached_credentials(creds): + pass + + import bson from bson.binary import Binary from bson.son import SON @@ -88,7 +99,13 @@ def _authenticate_aws(credentials, sock_info): # SASL complete. break except PyMongoAuthAwsError as exc: + # Clear the cached credentials if we hit a failure in auth. + set_cached_credentials(None) # Convert to OperationFailure and include pymongo-auth-aws version. raise OperationFailure( "%s (pymongo-auth-aws version %s)" % (exc, pymongo_auth_aws.__version__) ) + except Exception: + # Clear the cached credentials if we hit a failure in auth. + set_cached_credentials(None) + raise diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index a63e60718..372806bd2 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -20,6 +20,8 @@ import unittest sys.path[0:0] = [""] +from pymongo_auth_aws import AwsCredential, auth + from pymongo import MongoClient from pymongo.errors import OperationFailure from pymongo.uri_parser import parse_uri @@ -53,6 +55,62 @@ class TestAuthAWS(unittest.TestCase): with MongoClient(self.uri) as client: client.get_database().test.find_one() + def setup_cache(self): + if os.environ.get("AWS_ACCESS_KEY_ID", None) or "@" in self.uri: + self.skipTest("Not testing cached credentials") + if not hasattr(auth, "set_cached_credentials"): + self.skipTest("Cached credentials not available") + + # Ensure cleared credentials. + auth.set_cached_credentials(None) + self.assertEqual(auth.get_cached_credentials(), None) + + client = MongoClient(self.uri) + client.get_database().test.find_one() + client.close() + return auth.get_cached_credentials() + + def test_cache_credentials(self): + creds = self.setup_cache() + self.assertIsNotNone(creds) + + def test_cache_about_to_expire(self): + creds = self.setup_cache() + client = MongoClient(self.uri) + self.addCleanup(client.close) + + # Make the creds about to expire. + creds = auth.get_cached_credentials() + assert creds is not None + + creds = AwsCredential(creds.username, creds.password, creds.token, lambda x: True) + auth.set_cached_credentials(creds) + + client.get_database().test.find_one() + new_creds = auth.get_cached_credentials() + self.assertNotEqual(creds, new_creds) + + def test_poisoned_cache(self): + creds = self.setup_cache() + + client = MongoClient(self.uri) + self.addCleanup(client.close) + + # Poison the creds with invalid password. + assert creds is not None + creds = AwsCredential("a" * 24, "b" * 24, "c" * 24) + auth.set_cached_credentials(creds) + + with self.assertRaises(OperationFailure): + client.get_database().test.find_one() + + # Make sure the cache was cleared. + self.assertEqual(auth.get_cached_credentials(), None) + + # The next attempt should generate a new cred and succeed. + client.get_database().test.find_one() + self.assertNotEqual(auth.get_cached_credentials(), None) + class TestAWSLambdaExamples(unittest.TestCase): def test_shared_client(self):