diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 543dc0200..a3afbdb3f 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -13,11 +13,13 @@ # limitations under the License. """MONGODB-OIDC Authentication helpers.""" +from __future__ import annotations + import os import threading from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone -from typing import Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple import bson from bson.binary import Binary @@ -25,6 +27,10 @@ from bson.son import SON from pymongo.errors import ConfigurationError, OperationFailure from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE +if TYPE_CHECKING: + from pymongo.auth import MongoCredential + from pymongo.pool import SocketInfo + @dataclass class _OIDCProperties: @@ -44,7 +50,9 @@ CALLBACK_VERSION = 0 _CACHE: Dict[str, "_OIDCAuthenticator"] = {} -def _get_authenticator(credentials, address): +def _get_authenticator( + credentials: MongoCredential, address: Tuple[str, int] +) -> _OIDCAuthenticator: # Clear out old items in the cache. now_utc = datetime.now(timezone.utc) to_remove = [] @@ -81,7 +89,7 @@ def _get_authenticator(credentials, address): return _CACHE[cache_key] -def _get_cache_exp(): +def _get_cache_exp() -> datetime: return datetime.now(timezone.utc) + timedelta(minutes=CACHE_TIMEOUT_MINUTES) @@ -98,7 +106,7 @@ class _OIDCAuthenticator: cache_exp_utc: datetime = field(default_factory=_get_cache_exp) lock: threading.Lock = field(default_factory=threading.Lock) - def get_current_token(self, use_callbacks=True): + def get_current_token(self, use_callbacks: bool = True) -> Optional[str]: properties = self.properties request_cb = properties.request_token_callback @@ -116,16 +124,15 @@ class _OIDCAuthenticator: current_valid_token = True timeout = CALLBACK_TIMEOUT_SECONDS - if not use_callbacks and not current_valid_token: return None if not current_valid_token and request_cb is not None: - prev_token = self.idp_resp and self.idp_resp["access_token"] + prev_token = self.idp_resp["access_token"] if self.idp_resp else None with self.lock: # See if the token was changed while we were waiting for the # lock. - new_token = self.idp_resp and self.idp_resp["access_token"] + new_token = self.idp_resp["access_token"] if self.idp_resp else None if new_token != prev_token: return new_token @@ -173,14 +180,14 @@ class _OIDCAuthenticator: return token - def auth_start_cmd(self, use_callbacks=True): + def auth_start_cmd(self, use_callbacks: bool = True) -> Optional[SON[str, Any]]: properties = self.properties # Handle aws provider credentials. if properties.provider_name == "aws": aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] with open(aws_identity_file) as fid: - token = fid.read().strip() + token: Optional[str] = fid.read().strip() payload = {"jwt": token} cmd = SON( [ @@ -230,14 +237,16 @@ class _OIDCAuthenticator: ] ) - def clear(self): + def clear(self) -> None: self.idp_info = None self.idp_resp = None self.token_exp_utc = None - def run_command(self, sock_info, cmd): + def run_command( + self, sock_info: SocketInfo, cmd: Mapping[str, Any] + ) -> Optional[Mapping[str, Any]]: try: - return sock_info.command("$external", cmd, no_reauth=True) + return sock_info.command("$external", cmd, no_reauth=True) # type: ignore[call-arg] except OperationFailure as exc: self.clear() if exc.code == _REAUTHENTICATION_REQUIRED_CODE: @@ -247,7 +256,9 @@ class _OIDCAuthenticator: return self.authenticate(sock_info, reauthenticate=True) raise - def authenticate(self, sock_info, reauthenticate=False): + def authenticate( + self, sock_info: SocketInfo, reauthenticate: bool = False + ) -> Optional[Mapping[str, Any]]: if reauthenticate: prev_id = getattr(sock_info, "oidc_token_gen_id", None) # Check if we've already changed tokens. @@ -264,6 +275,7 @@ class _OIDCAuthenticator: resp = ctx.speculative_authenticate else: cmd = self.auth_start_cmd() + assert cmd is not None resp = self.run_command(sock_info, cmd) if resp["done"]: @@ -293,7 +305,9 @@ class _OIDCAuthenticator: return resp -def _authenticate_oidc(credentials, sock_info, reauthenticate): +def _authenticate_oidc( + credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool +) -> Optional[Mapping[str, Any]]: """Authenticate using MONGODB-OIDC.""" authenticator = _get_authenticator(credentials, sock_info.address) return authenticator.authenticate(sock_info, reauthenticate=reauthenticate) diff --git a/pymongo/pool.py b/pymongo/pool.py index 2b498078c..a827d10f9 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -647,6 +647,7 @@ class SocketInfo: self.compression_settings = pool.opts._compression_settings self.compression_context = None self.socket_checker = SocketChecker() + self.oidc_token_gen_id = None # Support for mechanism negotiation on the initial handshake. self.negotiated_mechs = None self.auth_ctx = None