diff --git a/pymongo/auth_aws.py b/pymongo/auth_aws.py index ab83df965..0d253cea1 100644 --- a/pymongo/auth_aws.py +++ b/pymongo/auth_aws.py @@ -44,6 +44,14 @@ def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None: "install with: python -m pip install 'pymongo[aws]'" ) + # Delayed import. + from pymongo_auth_aws.auth import ( # type:ignore[import] + set_cached_credentials, + set_use_cached_credentials, + ) + + set_use_cached_credentials(True) + if conn.max_wire_version < 9: raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later") @@ -87,12 +95,12 @@ def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None: break except pymongo_auth_aws.PyMongoAuthAwsError as exc: # Clear the cached credentials if we hit a failure in auth. - pymongo_auth_aws.set_cached_credentials(None) + set_cached_credentials(None) # Convert to OperationFailure and include pymongo-auth-aws version. raise OperationFailure( f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})" ) from None except Exception: # Clear the cached credentials if we hit a failure in auth. - pymongo_auth_aws.set_cached_credentials(None) + set_cached_credentials(None) raise diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index d0bb41b73..3e5dcec56 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -60,8 +60,13 @@ class TestAuthAWS(unittest.TestCase): def setup_cache(self): if os.environ.get("AWS_ACCESS_KEY_ID", None) or "@" in self.uri: self.skipTest("Not testing cached credentials") - if not hasattr(auth, "set_cached_credentials"): - self.skipTest("Cached credentials not available") + + # Make a connection to ensure that we enable caching. + client = MongoClient(self.uri) + client.get_database().test.find_one() + client.close() + + self.assertTrue(auth.get_use_cached_credentials()) # Ensure cleared credentials. auth.set_cached_credentials(None)