PYTHON-3313 Cache AWS Credentials Where Possible (#982)

This commit is contained in:
Steven Silvester 2022-10-12 10:21:06 -05:00 committed by GitHub
parent 775c0203ca
commit 4a5e0f6655
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 80 additions and 2 deletions

View File

@ -30,7 +30,7 @@ authtest () {
$PYTHON -m pip install --upgrade wheel setuptools pip
cd src
$PYTHON -m pip install '.[aws]'
$PYTHON test/auth_aws/test_auth_aws.py
$PYTHON test/auth_aws/test_auth_aws.py -v
cd -
}

View File

@ -61,7 +61,7 @@ authtest () {
. venvaws/bin/activate
fi
python -m pip install '.[aws]'
python test/auth_aws/test_auth_aws.py
python test/auth_aws/test_auth_aws.py -v
deactivate
rm -rf venvaws
}

View File

@ -26,6 +26,9 @@ PyMongo 4.3 brings a number of improvements including:
now allow for new types of events (such as DDL and C2C replication events)
to be recorded with the new parameter ``show_expanded_events``
that can be passed to methods such as :meth:`~pymongo.collection.Collection.watch`.
- PyMongo now internally caches AWS credentials that it fetches from AWS
endpoints, to avoid rate limitations. The cache is cleared when the
credentials expire or an error is encountered.
Bug fixes
.........

View File

@ -27,6 +27,17 @@ except ImportError:
_HAVE_MONGODB_AWS = False
try:
from pymongo_auth_aws.auth import set_cached_credentials, set_use_cached_credentials
# Enable credential caching.
set_use_cached_credentials(True)
except ImportError:
def set_cached_credentials(creds):
pass
import bson
from bson.binary import Binary
from bson.son import SON
@ -88,7 +99,13 @@ def _authenticate_aws(credentials, sock_info):
# SASL complete.
break
except PyMongoAuthAwsError as exc:
# Clear the cached credentials if we hit a failure in auth.
set_cached_credentials(None)
# Convert to OperationFailure and include pymongo-auth-aws version.
raise OperationFailure(
"%s (pymongo-auth-aws version %s)" % (exc, pymongo_auth_aws.__version__)
)
except Exception:
# Clear the cached credentials if we hit a failure in auth.
set_cached_credentials(None)
raise

View File

@ -20,6 +20,8 @@ import unittest
sys.path[0:0] = [""]
from pymongo_auth_aws import AwsCredential, auth
from pymongo import MongoClient
from pymongo.errors import OperationFailure
from pymongo.uri_parser import parse_uri
@ -53,6 +55,62 @@ class TestAuthAWS(unittest.TestCase):
with MongoClient(self.uri) as client:
client.get_database().test.find_one()
def setup_cache(self):
if os.environ.get("AWS_ACCESS_KEY_ID", None) or "@" in self.uri:
self.skipTest("Not testing cached credentials")
if not hasattr(auth, "set_cached_credentials"):
self.skipTest("Cached credentials not available")
# Ensure cleared credentials.
auth.set_cached_credentials(None)
self.assertEqual(auth.get_cached_credentials(), None)
client = MongoClient(self.uri)
client.get_database().test.find_one()
client.close()
return auth.get_cached_credentials()
def test_cache_credentials(self):
creds = self.setup_cache()
self.assertIsNotNone(creds)
def test_cache_about_to_expire(self):
creds = self.setup_cache()
client = MongoClient(self.uri)
self.addCleanup(client.close)
# Make the creds about to expire.
creds = auth.get_cached_credentials()
assert creds is not None
creds = AwsCredential(creds.username, creds.password, creds.token, lambda x: True)
auth.set_cached_credentials(creds)
client.get_database().test.find_one()
new_creds = auth.get_cached_credentials()
self.assertNotEqual(creds, new_creds)
def test_poisoned_cache(self):
creds = self.setup_cache()
client = MongoClient(self.uri)
self.addCleanup(client.close)
# Poison the creds with invalid password.
assert creds is not None
creds = AwsCredential("a" * 24, "b" * 24, "c" * 24)
auth.set_cached_credentials(creds)
with self.assertRaises(OperationFailure):
client.get_database().test.find_one()
# Make sure the cache was cleared.
self.assertEqual(auth.get_cached_credentials(), None)
# The next attempt should generate a new cred and succeed.
client.get_database().test.find_one()
self.assertNotEqual(auth.get_cached_credentials(), None)
class TestAWSLambdaExamples(unittest.TestCase):
def test_shared_client(self):