diff --git a/.evergreen/config.yml b/.evergreen/config.yml index f0840dfce..0277cad14 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -79,6 +79,17 @@ functions: export PATH="$MONGODB_BINARIES:$PATH" export PROJECT="${project}" export PIP_QUIET=1 + ENSURE_TOOLCHAIN_PYTHON_BINARY: | + # Make sure PYTHON_BINARY is set to a suitable toolchain python. + if [ -z "$PYTHON_BINARY" ]; then + if [ "$(uname -s)" = "Darwin" ]; then + export PYTHON_BINARY=/Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 + elif [ "Windows_NT" = "$OS" ]; then # Magic variable in cygwin + export PYTHON_BINARY=/cygdrive/c/python/Python39/python + else + export PYTHON_BINARY=/opt/python/3.9/bin/python3 + fi + fi EOT # Load the expansion file to make an evergreen variable with the current unique version @@ -655,7 +666,7 @@ functions: .evergreen/run-mongodb-aws-test.sh fi - "bootstrap oidc": + "run oidc auth test with aws credentials": - command: ec2.assume_role params: role_arn: ${aws_test_secrets_role} @@ -664,58 +675,11 @@ functions: params: working_dir: "src" shell: bash + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] script: | ${PREPARE_SHELL} - if [ "${skip_EC2_auth_test}" = "true" ]; then - echo "This platform does not support the oidc auth test, skipping..." - exit 0 - fi - - cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc - export AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} - export AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} - export AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN} - export OIDC_TOKEN_DIR=/tmp/tokens - - . ./activate-authoidcvenv.sh - python oidc_write_orchestration.py - python oidc_get_tokens.py - - "run oidc auth test with aws credentials": - - command: shell.exec - type: test - params: - working_dir: "src" - shell: bash - script: | - ${PREPARE_SHELL} - if [ "${skip_EC2_auth_test}" = "true" ]; then - echo "This platform does not support the oidc auth test, skipping..." - exit 0 - fi - cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc - mongosh setup_oidc.js - - command: shell.exec - type: test - params: - working_dir: "src" - silent: true - script: | - # DO NOT ECHO WITH XTRACE (which PREPARE_SHELL does) - cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" - export OIDC_TOKEN_DIR=/tmp/tokens - EOF - - command: shell.exec - type: test - params: - working_dir: "src" - script: | - ${PREPARE_SHELL} - if [ "${skip_web_identity_auth_test}" = "true" ]; then - echo "This platform does not support the oidc auth test, skipping..." - exit 0 - fi - PYTHON_BINARY=${PYTHON_BINARY} ASSERT_NO_URI_CREDS=true .evergreen/run-mongodb-oidc-test.sh + ${ENSURE_TOOLCHAIN_PYTHON_BINARY} + bash .evergreen/run-mongodb-oidc-test.sh "run aws auth test with aws credentials as environment variables": - command: shell.exec @@ -2133,16 +2097,7 @@ tasks: - name: "oidc-auth-test-latest" commands: - - func: "bootstrap oidc" - - func: "bootstrap mongo-orchestration" - vars: - AUTH: "auth" - ORCHESTRATION_FILE: "auth-oidc.json" - TOPOLOGY: "replica_set" - VERSION: "latest" - func: "run oidc auth test with aws credentials" - vars: - AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test1 - name: load-balancer-test commands: @@ -3192,9 +3147,8 @@ buildvariants: - matrix_name: "oidc-auth-test" matrix_spec: - platform: [ rhel8 ] - python-version: ["3.9"] - display_name: "MONGODB-OIDC Auth ${platform} ${python-version}" + platform: [ rhel8, macos-1100, windows-64-vsMulti-small ] + display_name: "MONGODB-OIDC Auth ${platform}" tasks: - name: "oidc-auth-test-latest" diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 8ab560531..cc92a5065 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -1,56 +1,48 @@ #!/bin/bash -set -o xtrace +set +x # Disable debug trace set -o errexit # Exit the script with error if any of the commands fail -############################################ -# Main Program # -############################################ - -# Supported/used environment variables: -# MONGODB_URI Set the URI, including an optional username/password to use -# to connect to the server via MONGODB-OIDC authentication -# mechanism. -# PYTHON_BINARY The Python version to use. - echo "Running MONGODB-OIDC authentication tests" -# ensure no secrets are printed in log files -set +x -# load the script -shopt -s expand_aliases # needed for `urlencode` alias -[ -s "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" - -MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} -MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC" -MONGODB_URI_MULTIPLE="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true" - -if [ -z "${OIDC_TOKEN_DIR}" ]; then - echo "Must specify OIDC_TOKEN_DIR" +# Make sure DRIVERS_TOOLS is set. +if [ -z "$DRIVERS_TOOLS" ]; then + echo "Must specify DRIVERS_TOOLS" exit 1 fi -export MONGODB_URI_SINGLE="$MONGODB_URI_SINGLE" -export MONGODB_URI_MULTIPLE="$MONGODB_URI_MULTIPLE" -export MONGODB_URI="$MONGODB_URI" +# Get the drivers secrets. Use an existing secrets file first. +if [ ! -f "./secrets-export.sh" ]; then + bash .evergreen/tox.sh -m aws-secrets -- drivers/oidc +fi +source ./secrets-export.sh -echo $MONGODB_URI_SINGLE -echo $MONGODB_URI_MULTIPLE -echo $MONGODB_URI - -if [ "$ASSERT_NO_URI_CREDS" = "true" ]; then - if echo "$MONGODB_URI" | grep -q "@"; then - echo "MONGODB_URI unexpectedly contains user credentials!"; - exit 1 - fi +# # If the file did not have our creds, get them from the vault. +if [ -z "$OIDC_ATLAS_URI_SINGLE" ]; then + bash .evergreen/tox.sh -m aws-secrets -- drivers/oidc + source ./secrets-export.sh fi -if [ -z "$PYTHON_BINARY" ]; then - echo "Cannot test without specifying PYTHON_BINARY" - exit 1 +# Make the OIDC tokens. +set -x +pushd ${DRIVERS_TOOLS}/.evergreen/auth_oidc +. ./oidc_get_tokens.sh +popd + +# Set up variables and run the test. +if [ -n "$LOCAL_OIDC_SERVER" ]; then + export MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} + export MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC" + export MONGODB_URI_MULTI="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true" +else + set +x # turn off xtrace for this portion + export MONGODB_URI="$OIDC_ATLAS_URI_SINGLE" + export MONGODB_URI_SINGLE="$OIDC_ATLAS_URI_SINGLE/?authMechanism=MONGODB-OIDC" + export MONGODB_URI_MULTI="$OIDC_ATLAS_URI_MULTI/?authMechanism=MONGODB-OIDC" + set -x fi export TEST_AUTH_OIDC=1 +export COVERAGE=1 export AUTH="auth" -export SET_XTRACE_ON=1 bash ./.evergreen/tox.sh -m test-eg diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 02bd86eaa..21cba87ec 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -49,6 +49,9 @@ if [ "$AUTH" != "noauth" ]; then elif [ ! -z "$TEST_SERVERLESS" ]; then export DB_USER=$SERVERLESS_ATLAS_USER export DB_PASSWORD=$SERVERLESS_ATLAS_PASSWORD + elif [ ! -z "$TEST_AUTH_OIDC" ]; then + export DB_USER=$OIDC_ALTAS_USER + export DB_PASSWORD=$OIDC_ATLAS_PASSWORD else export DB_USER="bob" export DB_PASSWORD="pwd123" @@ -109,7 +112,7 @@ fi if [ -n "$TEST_ENCRYPTION" ] || [ -n "$TEST_FLE_AZURE_AUTO" ] || [ -n "$TEST_FLE_GCP_AUTO" ]; then # Work around for root certifi not being installed. - # TODO: Remove after PYTHON-3827 + # TODO: Remove after PYTHON-3952 is deployed. if [ "$(uname -s)" = "Darwin" ]; then python -m pip install certifi CERT_PATH=$(python -c "import certifi; print(certifi.where())") @@ -224,6 +227,17 @@ fi if [ -n "$TEST_AUTH_OIDC" ]; then python -m pip install ".[aws]" + + # Work around for root certifi not being installed. + # TODO: Remove after PYTHON-3952 is deployed. + if [ "$(uname -s)" = "Darwin" ]; then + python -m pip install certifi + CERT_PATH=$(python -c "import certifi; print(certifi.where())") + export SSL_CERT_FILE=${CERT_PATH} + export REQUESTS_CA_BUNDLE=${CERT_PATH} + export AWS_CA_BUNDLE=${CERT_PATH} + fi + TEST_ARGS="test/auth_oidc/test_auth_oidc.py" fi diff --git a/pymongo/auth.py b/pymongo/auth.py index b20239ac3..d7237b633 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -164,11 +164,11 @@ def _build_credentials_tuple( elif mech == "MONGODB-OIDC": properties = extra.get("authmechanismproperties", {}) request_token_callback = properties.get("request_token_callback") - refresh_token_callback = properties.get("refresh_token_callback", None) provider_name = properties.get("PROVIDER_NAME", "") default_allowed = [ "*.mongodb.net", "*.mongodb-dev.net", + "*.mongodb-qa.net", "*.mongodbgov.net", "localhost", "127.0.0.1", @@ -181,11 +181,10 @@ def _build_credentials_tuple( ) oidc_props = _OIDCProperties( request_token_callback=request_token_callback, - refresh_token_callback=refresh_token_callback, provider_name=provider_name, allowed_hosts=allowed_hosts, ) - return MongoCredential(mech, "$external", user, passwd, oidc_props, None) + return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache()) elif mech == "PLAIN": source_database = source or database or "$external" diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 0ca74fc49..220f91e67 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -15,27 +15,14 @@ """MONGODB-OIDC Authentication helpers.""" from __future__ import annotations -import os import threading from dataclasses import dataclass, field -from datetime import datetime, timedelta, timezone -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Mapping, - MutableMapping, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Optional import bson from bson.binary import Binary from bson.son import SON from pymongo.errors import ConfigurationError, OperationFailure -from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE if TYPE_CHECKING: from pymongo.auth import MongoCredential @@ -44,39 +31,27 @@ if TYPE_CHECKING: @dataclass class _OIDCProperties: - request_token_callback: Optional[Callable[..., Dict]] - refresh_token_callback: Optional[Callable[..., Dict]] + request_token_callback: Optional[Callable[..., dict]] provider_name: Optional[str] - allowed_hosts: List[str] + allowed_hosts: list[str] """Mechanism properties for MONGODB-OIDC authentication.""" TOKEN_BUFFER_MINUTES = 5 CALLBACK_TIMEOUT_SECONDS = 5 * 60 -CACHE_TIMEOUT_MINUTES = 60 * 5 -CALLBACK_VERSION = 0 - -_CACHE: Dict[str, "_OIDCAuthenticator"] = {} +CALLBACK_VERSION = 1 def _get_authenticator( - credentials: MongoCredential, address: Tuple[str, int] + credentials: MongoCredential, address: tuple[str, int] ) -> _OIDCAuthenticator: - # Clear out old items in the cache. - now_utc = datetime.now(timezone.utc) - to_remove = [] - for key, value in _CACHE.items(): - if value.cache_exp_utc is not None and value.cache_exp_utc < now_utc: - to_remove.append(key) - for key in to_remove: - del _CACHE[key] + if credentials.cache.data: + return credentials.cache.data # Extract values. principal_name = credentials.username properties = credentials.mechanism_properties - request_cb = properties.request_token_callback - refresh_cb = properties.refresh_token_callback # Validate that the address is allowed. if not properties.provider_name: @@ -92,150 +67,99 @@ def _get_authenticator( f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}" ) - # Get or create the cache item. - cache_key = f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}" - _CACHE.setdefault(cache_key, _OIDCAuthenticator(username=principal_name, properties=properties)) - - return _CACHE[cache_key] - - -def _get_cache_exp() -> datetime: - return datetime.now(timezone.utc) + timedelta(minutes=CACHE_TIMEOUT_MINUTES) + # Get or create the cache data. + credentials.cache.data = _OIDCAuthenticator(username=principal_name, properties=properties) + return credentials.cache.data @dataclass class _OIDCAuthenticator: username: str properties: _OIDCProperties - idp_info: Optional[Dict] = field(default=None) - idp_resp: Optional[Dict] = field(default=None) - reauth_gen_id: int = field(default=0) - idp_info_gen_id: int = field(default=0) + refresh_token: Optional[str] = field(default=None) + access_token: Optional[str] = field(default=None) + idp_info: Optional[dict] = field(default=None) token_gen_id: int = field(default=0) - token_exp_utc: Optional[datetime] = field(default=None) - cache_exp_utc: datetime = field(default_factory=_get_cache_exp) lock: threading.Lock = field(default_factory=threading.Lock) - def get_current_token(self, use_callbacks: bool = True) -> Optional[str]: + def get_current_token(self, use_callback: bool = True) -> Optional[str]: properties = self.properties - request_cb = properties.request_token_callback - refresh_cb = properties.refresh_token_callback - if not use_callbacks: - request_cb = None - refresh_cb = None + # TODO: DRIVERS-2672, handle machine callback here as well. + cb = properties.request_token_callback if use_callback else None + cb_type = "human" - current_valid_token = False - if self.token_exp_utc is not None: - now_utc = datetime.now(timezone.utc) - exp_utc = self.token_exp_utc - buffer_seconds = TOKEN_BUFFER_MINUTES * 60 - if (exp_utc - now_utc).total_seconds() >= buffer_seconds: - current_valid_token = True + prev_token = self.access_token + if prev_token: + return prev_token - timeout = CALLBACK_TIMEOUT_SECONDS - if not use_callbacks and not current_valid_token: + if not use_callback and not prev_token: return None - if not current_valid_token and request_cb is not None: - prev_token = self.idp_resp["access_token"] if self.idp_resp else None + if not prev_token and cb is not None: with self.lock: # See if the token was changed while we were waiting for the # lock. - new_token = self.idp_resp["access_token"] if self.idp_resp else None + new_token = self.access_token if new_token != prev_token: return new_token - refresh_token = self.idp_resp and self.idp_resp.get("refresh_token") - refresh_token = refresh_token or "" - context = { - "timeout_seconds": timeout, - "version": CALLBACK_VERSION, - "refresh_token": refresh_token, - } + # TODO: DRIVERS-2672 handle machine callback here. + if cb_type == "human": + context = { + "timeout_seconds": CALLBACK_TIMEOUT_SECONDS, + "version": CALLBACK_VERSION, + "refresh_token": self.refresh_token, + } + resp = cb(self.idp_info, context) + + self.validate_request_token_response(resp) - if self.idp_resp is None or refresh_cb is None: - self.idp_resp = request_cb(self.idp_info, context) - elif request_cb is not None: - self.idp_resp = refresh_cb(self.idp_info, context) - cache_exp_utc = datetime.now(timezone.utc) + timedelta( - minutes=CACHE_TIMEOUT_MINUTES - ) - self.cache_exp_utc = cache_exp_utc self.token_gen_id += 1 - token_result = self.idp_resp + return self.access_token + def validate_request_token_response(self, resp: Mapping[str, Any]) -> None: # Validate callback return value. - if not isinstance(token_result, dict): + if not isinstance(resp, dict): raise ValueError("OIDC callback returned invalid result") - if "access_token" not in token_result: + if "access_token" not in resp: raise ValueError("OIDC callback did not return an access_token") - expected = ["access_token", "expires_in_seconds", "refesh_token"] - for key in token_result: + expected = ["access_token", "refresh_token", "expires_in_seconds"] + for key in resp: if key not in expected: raise ValueError(f'Unexpected field in callback result "{key}"') - token = token_result["access_token"] + self.access_token = resp["access_token"] + self.refresh_token = resp.get("refresh_token") - if "expires_in_seconds" in token_result: - expires_in = int(token_result["expires_in_seconds"]) - buffer_seconds = TOKEN_BUFFER_MINUTES * 60 - if expires_in >= buffer_seconds: - now_utc = datetime.now(timezone.utc) - exp_utc = now_utc + timedelta(seconds=expires_in) - self.token_exp_utc = exp_utc - - return token - - def auth_start_cmd(self, use_callbacks: bool = True) -> Optional[SON[str, Any]]: - properties = self.properties - - # Handle aws provider credentials. - if properties.provider_name == "aws": - aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] - with open(aws_identity_file) as fid: - token: Optional[str] = fid.read().strip() - payload = {"jwt": token} - cmd = SON( - [ - ("saslStart", 1), - ("mechanism", "MONGODB-OIDC"), - ("payload", Binary(bson.encode(payload))), - ] - ) - return cmd + def principal_step_cmd(self) -> SON[str, Any]: + """Get a SASL start command with an optional principal name""" + # Send the SASL start with the optional principal name. + payload = {} principal_name = self.username + if principal_name: + payload["n"] = principal_name - if self.idp_info is not None: - self.cache_exp_utc = datetime.now(timezone.utc) + timedelta( - minutes=CACHE_TIMEOUT_MINUTES - ) + cmd = SON( + [ + ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), + ("payload", Binary(bson.encode(payload))), + ("autoAuthorize", 1), + ] + ) + return cmd + def auth_start_cmd(self, use_callback: bool = True) -> Optional[SON[str, Any]]: + # TODO: DRIVERS-2672, check for provider_name in self.properties here. if self.idp_info is None: - self.cache_exp_utc = _get_cache_exp() + return self.principal_step_cmd() - if self.idp_info is None: - # Send the SASL start with the optional principal name. - payload = {} - - if principal_name: - payload["n"] = principal_name - - cmd = SON( - [ - ("saslStart", 1), - ("mechanism", "MONGODB-OIDC"), - ("payload", Binary(bson.encode(payload))), - ("autoAuthorize", 1), - ] - ) - return cmd - - token = self.get_current_token(use_callbacks) + token = self.get_current_token(use_callback) if not token: return None bin_payload = Binary(bson.encode({"jwt": token})) @@ -247,37 +171,59 @@ class _OIDCAuthenticator: ] ) - def clear(self) -> None: - self.idp_info = None - self.idp_resp = None - self.token_exp_utc = None - def run_command( self, conn: Connection, cmd: MutableMapping[str, Any] ) -> Optional[Mapping[str, Any]]: try: return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg] - except OperationFailure as exc: - self.clear() - if exc.code == _REAUTHENTICATION_REQUIRED_CODE: - if "jwt" in bson.decode(cmd["payload"]): - if self.idp_info_gen_id > self.reauth_gen_id: - raise - return self.authenticate(conn, reauthenticate=True) + except OperationFailure: + self.access_token = None raise - def authenticate( - self, conn: Connection, reauthenticate: bool = False - ) -> Optional[Mapping[str, Any]]: - if reauthenticate: - prev_id = getattr(conn, "oidc_token_gen_id", None) - # Check if we've already changed tokens. - if prev_id == self.token_gen_id: - self.reauth_gen_id = self.idp_info_gen_id - self.token_exp_utc = None - if not self.properties.refresh_token_callback: - self.clear() + def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]: + """Handle a reauthenticate from the server.""" + # First see if we have the a newer token on the authenticator. + prev_id = conn.oidc_token_gen_id or 0 + # If we've already changed tokens, make one optimistic attempt. + if (prev_id < self.token_gen_id) and self.access_token: + try: + return self.authenticate(conn) + except OperationFailure: + pass + self.access_token = None + + # TODO: DRIVERS-2672, check for provider_name in self.properties here. + # If so, we clear the access token and return finish_auth. + + # Next see if the idp info has changed. + prev_idp_info = self.idp_info + self.idp_info = None + cmd = self.principal_step_cmd() + resp = self.run_command(conn, cmd) + assert resp is not None + server_resp: dict = bson.decode(resp["payload"]) + if "issuer" in server_resp: + self.idp_info = server_resp + + # Handle the case of changed idp info. + if not self.idp_info == prev_idp_info: + self.access_token = None + self.refresh_token = None + + # If we have a refresh token, try using that. + if self.refresh_token: + try: + return self.finish_auth(resp, conn) + except OperationFailure: + self.refresh_token = None + # If that fails, try again without the refresh token. + return self.authenticate(conn) + + # If we don't have a refresh token, just try once. + return self.finish_auth(resp, conn) + + def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]: ctx = conn.auth_ctx cmd = None @@ -293,12 +239,16 @@ class _OIDCAuthenticator: conn.oidc_token_gen_id = self.token_gen_id return None - server_resp: Dict = bson.decode(resp["payload"]) + server_resp: dict = bson.decode(resp["payload"]) if "issuer" in server_resp: self.idp_info = server_resp - self.idp_info_gen_id += 1 - conversation_id = resp["conversationId"] + return self.finish_auth(resp, conn) + + def finish_auth( + self, orig_resp: Mapping[str, Any], conn: Connection + ) -> Optional[Mapping[str, Any]]: + conversation_id = orig_resp["conversationId"] token = self.get_current_token() conn.oidc_token_gen_id = self.token_gen_id bin_payload = Binary(bson.encode({"jwt": token})) @@ -312,7 +262,6 @@ class _OIDCAuthenticator: resp = self.run_command(conn, cmd) assert resp is not None if not resp["done"]: - self.clear() raise OperationFailure("SASL conversation failed to complete.") return resp @@ -322,4 +271,7 @@ def _authenticate_oidc( ) -> Optional[Mapping[str, Any]]: """Authenticate using MONGODB-OIDC.""" authenticator = _get_authenticator(credentials, conn.address) - return authenticator.authenticate(conn, reauthenticate=reauthenticate) + if reauthenticate: + return authenticator.reauthenticate(conn) + else: + return authenticator.authenticate(conn) diff --git a/pymongo/common.py b/pymongo/common.py index f5c4da71e..fad24030f 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -440,8 +440,6 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni signature = inspect.signature(value) if key == "request_token_callback": expected_params = 2 - elif key == "refresh_token_callback": - expected_params = 2 else: raise ValueError(f"Unrecognized Auth mechanism function {key}") if len(signature.parameters) != expected_params: diff --git a/test/auth/legacy/connection-string.json b/test/auth/legacy/connection-string.json index ca979010a..0463a5141 100644 --- a/test/auth/legacy/connection-string.json +++ b/test/auth/legacy/connection-string.json @@ -475,22 +475,6 @@ } } }, - { - "description": "should recognise the mechanism with request and refresh callback (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRequest", "oidcRefresh"], - "valid": true, - "credential": { - "username": null, - "password": null, - "source": "$external", - "mechanism": "MONGODB-OIDC", - "mechanism_properties": { - "REQUEST_TOKEN_CALLBACK": true, - "REFRESH_TOKEN_CALLBACK": true - } - } - }, { "description": "should recognise the mechanism and username with request callback (MONGODB-OIDC)", "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC", @@ -554,18 +538,11 @@ "credential": null }, { - "description": "should throw an exception if neither deviceName nor callbacks specified (MONGODB-OIDC)", + "description": "should throw an exception if neither deviceName nor callback specified (MONGODB-OIDC)", "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", "valid": false, "credential": null }, - { - "description": "should throw an exception when only refresh callback is specified (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRefresh"], - "valid": false, - "credential": null - }, { "description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)", "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted", @@ -573,4 +550,4 @@ "credential": null } ] -} \ No newline at end of file +} diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index 7b42f98a1..b1315455f 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -16,7 +16,6 @@ import os import sys -import threading import time import unittest from contextlib import contextmanager @@ -29,7 +28,6 @@ from test.utils import EventListener from bson import SON from pymongo import MongoClient from pymongo.auth import _AUTH_MAP, _authenticate_oidc -from pymongo.auth_oidc import _CACHE as _oidc_cache from pymongo.cursor import CursorType from pymongo.errors import ConfigurationError, OperationFailure from pymongo.hello import HelloCompat @@ -45,19 +43,16 @@ class TestAuthOIDC(unittest.TestCase): @classmethod def setUpClass(cls): cls.uri_single = os.environ["MONGODB_URI_SINGLE"] - cls.uri_multiple = os.environ["MONGODB_URI_MULTIPLE"] + cls.uri_multiple = os.environ["MONGODB_URI_MULTI"] cls.uri_admin = os.environ["MONGODB_URI"] cls.token_dir = os.environ["OIDC_TOKEN_DIR"] def setUp(self): self.request_called = 0 - self.refresh_called = 0 - _oidc_cache.clear() - os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") - def create_request_cb(self, username="test_user1", expires_in_seconds=None, sleep=0): + def create_request_cb(self, username="test_user1", sleep=0): - token_file = os.path.join(self.token_dir, username) + token_file = os.path.join(self.token_dir, username).replace(os.sep, "/") def request_token(server_info, context): # Validate the info. @@ -69,43 +64,14 @@ class TestAuthOIDC(unittest.TestCase): self.assertEqual(timeout_seconds, 60 * 5) with open(token_file) as fid: token = fid.read() - resp = {"access_token": token} + resp = {"access_token": token, "refresh_token": token} time.sleep(sleep) - - if expires_in_seconds is not None: - resp["expires_in_seconds"] = expires_in_seconds self.request_called += 1 return resp return request_token - def create_refresh_cb(self, username="test_user1", expires_in_seconds=None): - - token_file = os.path.join(self.token_dir, username) - - def refresh_token(server_info, context): - with open(token_file) as fid: - token = fid.read() - - # Validate the info. - self.assertIn("issuer", server_info) - self.assertIn("clientId", server_info) - - # Validate the creds - self.assertIsNotNone(context["refresh_token"]) - - # Validate the timeout. - self.assertEqual(context["timeout_seconds"], 60 * 5) - - resp = {"access_token": token} - if expires_in_seconds is not None: - resp["expires_in_seconds"] = expires_in_seconds - self.refresh_called += 1 - return resp - - return refresh_token - @contextmanager def fail_point(self, command_args): cmd_on = SON([("configureFailPoint", "failCommand")]) @@ -117,21 +83,21 @@ class TestAuthOIDC(unittest.TestCase): finally: client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off") - def test_connect_callbacks_single_implicit_username(self): + def test_connect_request_callback_single_implicit_username(self): request_token = self.create_request_cb() props: Dict = {"request_token_callback": request_token} client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() - def test_connect_callbacks_single_explicit_username(self): + def test_connect_request_callback_single_explicit_username(self): request_token = self.create_request_cb() props: Dict = {"request_token_callback": request_token} client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props) client.test.test.find_one() client.close() - def test_connect_callbacks_multiple_principal_user1(self): + def test_connect_request_callback_multiple_principal_user1(self): request_token = self.create_request_cb() props: Dict = {"request_token_callback": request_token} client = MongoClient( @@ -140,7 +106,7 @@ class TestAuthOIDC(unittest.TestCase): client.test.test.find_one() client.close() - def test_connect_callbacks_multiple_principal_user2(self): + def test_connect_request_callback_multiple_principal_user2(self): request_token = self.create_request_cb("test_user2") props: Dict = {"request_token_callback": request_token} client = MongoClient( @@ -149,7 +115,7 @@ class TestAuthOIDC(unittest.TestCase): client.test.test.find_one() client.close() - def test_connect_callbacks_multiple_no_username(self): + def test_connect_request_callback_multiple_no_username(self): request_token = self.create_request_cb() props: Dict = {"request_token_callback": request_token} client = MongoClient(self.uri_multiple, authmechanismproperties=props) @@ -173,38 +139,11 @@ class TestAuthOIDC(unittest.TestCase): client.test.test.find_one() client.close() - def test_connect_aws_single_principal(self): - props = {"PROVIDER_NAME": "aws"} - client = MongoClient(self.uri_single, authmechanismproperties=props) - client.test.test.find_one() - client.close() - - def test_connect_aws_multiple_principal_user1(self): - props = {"PROVIDER_NAME": "aws"} - client = MongoClient(self.uri_multiple, authmechanismproperties=props) - client.test.test.find_one() - client.close() - - def test_connect_aws_multiple_principal_user2(self): - os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2") - props = {"PROVIDER_NAME": "aws"} - client = MongoClient(self.uri_multiple, authmechanismproperties=props) - client.test.test.find_one() - client.close() - - def test_connect_aws_allowed_hosts_ignored(self): - props = {"PROVIDER_NAME": "aws", "allowed_hosts": []} - client = MongoClient(self.uri_multiple, authmechanismproperties=props) - client.test.test.find_one() - client.close() - - def test_valid_callbacks(self): - request_cb = self.create_request_cb(expires_in_seconds=60) - refresh_cb = self.create_refresh_cb() + def test_valid_request_token_callback(self): + request_cb = self.create_request_cb() props: Dict = { "request_token_callback": request_cb, - "refresh_token_callback": refresh_cb, } client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() @@ -214,31 +153,6 @@ class TestAuthOIDC(unittest.TestCase): client.test.test.find_one() client.close() - def test_lock_avoids_extra_callbacks(self): - request_cb = self.create_request_cb(sleep=0.5) - refresh_cb = self.create_refresh_cb() - - props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} - - def run_test(): - client = MongoClient(self.uri_single, authMechanismProperties=props) - client.test.test.find_one() - client.close() - - client = MongoClient(self.uri_single, authMechanismProperties=props) - client.test.test.find_one() - client.close() - - t1 = threading.Thread(target=run_test) - t2 = threading.Thread(target=run_test) - t1.start() - t2.start() - t1.join() - t2.join() - - self.assertEqual(self.request_called, 1) - self.assertEqual(self.refresh_called, 2) - def test_request_callback_returns_null(self): def request_token_null(a, b): return None @@ -249,25 +163,6 @@ class TestAuthOIDC(unittest.TestCase): client.test.test.find_one() client.close() - def test_refresh_callback_returns_null(self): - request_cb = self.create_request_cb(expires_in_seconds=60) - - def refresh_token_null(a, b): - return None - - props: Dict = { - "request_token_callback": request_cb, - "refresh_token_callback": refresh_token_null, - } - client = MongoClient(self.uri_single, authMechanismProperties=props) - client.test.test.find_one() - client.close() - - client = MongoClient(self.uri_single, authMechanismProperties=props) - with self.assertRaises(ValueError): - client.test.test.find_one() - client.close() - def test_request_callback_invalid_result(self): def request_token_invalid(a, b): return {} @@ -289,166 +184,10 @@ class TestAuthOIDC(unittest.TestCase): client.test.test.find_one() client.close() - def test_refresh_callback_missing_data(self): - request_cb = self.create_request_cb(expires_in_seconds=60) - - def refresh_cb_no_token(a, b): - return {} - - props: Dict = { - "request_token_callback": request_cb, - "refresh_token_callback": refresh_cb_no_token, - } - client = MongoClient(self.uri_single, authMechanismProperties=props) - client.test.test.find_one() - client.close() - - client = MongoClient(self.uri_single, authMechanismProperties=props) - with self.assertRaises(ValueError): - client.test.test.find_one() - client.close() - - def test_refresh_callback_extra_data(self): - request_cb = self.create_request_cb(expires_in_seconds=60) - - def refresh_cb_extra_value(server_info, context): - result = self.create_refresh_cb()(server_info, context) - result["foo"] = "bar" - return result - - props: Dict = { - "request_token_callback": request_cb, - "refresh_token_callback": refresh_cb_extra_value, - } - client = MongoClient(self.uri_single, authMechanismProperties=props) - client.test.test.find_one() - client.close() - - client = MongoClient(self.uri_single, authMechanismProperties=props) - with self.assertRaises(ValueError): - client.test.test.find_one() - client.close() - - def test_cache_with_refresh(self): - # Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token. - - # Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. - request_cb = self.create_request_cb(expires_in_seconds=60) - refresh_cb = self.create_refresh_cb() - - props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} - - # Ensure that a ``find`` operation adds credentials to the cache. - client = MongoClient(self.uri_single, authMechanismProperties=props) - client.test.test.find_one() - client.close() - - self.assertEqual(len(_oidc_cache), 1) - - # Create a new client with the same request callback and a refresh callback. - # Ensure that a ``find`` operation results in a call to the refresh callback. - client = MongoClient(self.uri_single, authMechanismProperties=props) - client.test.test.find_one() - client.close() - - self.assertEqual(self.refresh_called, 1) - self.assertEqual(len(_oidc_cache), 1) - - def test_cache_with_no_refresh(self): - # Create a new client with a request callback callback. - # Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. - request_cb = self.create_request_cb() - - props = {"request_token_callback": request_cb} - client = MongoClient(self.uri_single, authMechanismProperties=props) - - # Ensure that a ``find`` operation adds credentials to the cache. - self.request_called = 0 - client.test.test.find_one() - client.close() - self.assertEqual(self.request_called, 1) - self.assertEqual(len(_oidc_cache), 1) - - # Create a new client with the same request callback. - # Ensure that a ``find`` operation results in a call to the request callback. - client = MongoClient(self.uri_single, authMechanismProperties=props) - client.test.test.find_one() - client.close() - self.assertEqual(self.request_called, 2) - self.assertEqual(len(_oidc_cache), 1) - - def test_cache_key_includes_callback(self): - request_cb = self.create_request_cb() - - props: Dict = {"request_token_callback": request_cb} - - # Ensure that a ``find`` operation adds a new entry to the cache. - client = MongoClient(self.uri_single, authMechanismProperties=props) - client.test.test.find_one() - client.close() - - # Create a new client with a different request callback. - def request_token_2(a, b): - return request_cb(a, b) - - props["request_token_callback"] = request_token_2 - client = MongoClient(self.uri_single, authMechanismProperties=props) - - # Ensure that a ``find`` operation adds a new entry to the cache. - client.test.test.find_one() - client.close() - self.assertEqual(len(_oidc_cache), 2) - - def test_cache_clears_on_error(self): - request_cb = self.create_request_cb() - - # Create a new client with a valid request callback that gives credentials that expire within 5 minutes and a refresh callback that gives invalid credentials. - def refresh_cb(a, b): - return {"access_token": "bad"} - - # Add a token to the cache that will expire soon. - props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} - client = MongoClient(self.uri_single, authMechanismProperties=props) - client.test.test.find_one() - client.close() - - # Create a new client with the same callbacks. - client = MongoClient(self.uri_single, authMechanismProperties=props) - - # Ensure that another ``find`` operation results in an error. - with self.assertRaises(OperationFailure): - client.test.test.find_one() - - client.close() - - # Ensure that the cache has been cleared. - authenticator = list(_oidc_cache.values())[0] - self.assertIsNone(authenticator.idp_info) - - def test_cache_is_not_used_in_aws_automatic_workflow(self): - # Create a new client using the AWS device workflow. - # Ensure that a ``find`` operation does not add credentials to the cache. - props = {"PROVIDER_NAME": "aws"} - client = MongoClient(self.uri_single, authmechanismproperties=props) - client.test.test.find_one() - client.close() - - # Ensure that the cache has been cleared. - authenticator = list(_oidc_cache.values())[0] - self.assertIsNone(authenticator.idp_info) - def test_speculative_auth_success(self): - # Clear the cache - _oidc_cache.clear() - token_file = os.path.join(self.token_dir, "test_user1") + request_token = self.create_request_cb() - def request_token(a, b): - with open(token_file) as fid: - token = fid.read() - return {"access_token": token, "expires_in_seconds": 1000} - - # Create a client with a request callback that returns a valid token - # that will not expire soon. + # Create a client with a request callback that returns a valid token. props: Dict = {"request_token_callback": request_token} client = MongoClient(self.uri_single, authmechanismproperties=props) @@ -465,32 +204,14 @@ class TestAuthOIDC(unittest.TestCase): # Close the client. client.close() - # Create a new client. - client = MongoClient(self.uri_single, authmechanismproperties=props) - - # Set a fail point for saslStart commands. - with self.fail_point( - { - "mode": {"times": 2}, - "data": {"failCommands": ["saslStart"], "errorCode": 18}, - } - ): - # Perform a find operation. - client.test.test.find_one() - - # Close the client. - client.close() - def test_reauthenticate_succeeds(self): listener = EventListener() - # Create request and refresh callbacks that return valid credentials - # that will not expire soon. + # Create request callback that returns valid credentials. request_cb = self.create_request_cb() - refresh_cb = self.create_refresh_cb() - # Create a client with the callbacks. - props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} + # Create a client with the callback. + props: Dict = {"request_token_callback": request_cb} client = MongoClient( self.uri_single, event_listeners=[listener], authmechanismproperties=props ) @@ -498,8 +219,8 @@ class TestAuthOIDC(unittest.TestCase): # Perform a find operation. client.test.test.find_one() - # Assert that the refresh callback has not been called. - self.assertEqual(self.refresh_called, 0) + # Assert that the request callback has been called once. + self.assertEqual(self.request_called, 1) listener.reset() @@ -534,23 +255,109 @@ class TestAuthOIDC(unittest.TestCase): self.assertEqual(succeeded_events, ["find"]) self.assertEqual(failed_events, ["find"]) - # Assert that the refresh callback has been called. - self.assertEqual(self.refresh_called, 1) + # Assert that the request callback has been called twice. + self.assertEqual(self.request_called, 2) client.close() - def test_reauthenticate_succeeds_bulk_write(self): - request_cb = self.create_request_cb() - refresh_cb = self.create_refresh_cb() + def test_reauthenticate_succeeds_no_refresh(self): + cb = self.create_request_cb() - # Create a client with the callbacks. - props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} + def request_cb(*args, **kwargs): + result = cb(*args, **kwargs) + del result["refresh_token"] + return result + + # Create a client with the callback. + props: Dict = {"request_token_callback": request_cb} client = MongoClient(self.uri_single, authmechanismproperties=props) # Perform a find operation. client.test.test.find_one() - # Assert that the refresh callback has not been called. - self.assertEqual(self.refresh_called, 0) + # Assert that the request callback has been called once. + self.assertEqual(self.request_called, 1) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a find operation. + client.test.test.find_one() + + # Assert that the request callback has been called twice. + self.assertEqual(self.request_called, 2) + client.close() + + def test_reauthenticate_succeeds_after_refresh_fails(self): + + # Create request callback that returns valid credentials. + request_cb = self.create_request_cb() + + # Create a client with the callback. + props: Dict = {"request_token_callback": request_cb} + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform a find operation. + client.test.test.find_one() + + # Assert that the request callback has been called once. + self.assertEqual(self.request_called, 1) + + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["find", "saslContinue"], "errorCode": 391}, + } + ): + # Perform a find operation. + client.test.test.find_one() + + # Assert that the request callback has been called three times. + self.assertEqual(self.request_called, 3) + + def test_reauthenticate_fails(self): + + # Create request callback that returns valid credentials. + request_cb = self.create_request_cb() + + # Create a client with the callback. + props: Dict = {"request_token_callback": request_cb} + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform a find operation. + client.test.test.find_one() + + # Assert that the request callback has been called once. + self.assertEqual(self.request_called, 1) + + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a find operation that fails. + with self.assertRaises(OperationFailure): + client.test.test.find_one() + + # Assert that the request callback has been called twice. + self.assertEqual(self.request_called, 2) + client.close() + + def test_reauthenticate_succeeds_bulk_write(self): + request_cb = self.create_request_cb() + + # Create a client with the callback. + props: Dict = {"request_token_callback": request_cb} + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform a find operation. + client.test.test.find_one() + + # Assert that the request callback has been called once. + self.assertEqual(self.request_called, 1) with self.fail_point( { @@ -561,16 +368,15 @@ class TestAuthOIDC(unittest.TestCase): # Perform a bulk write operation. client.test.test.bulk_write([InsertOne({})]) - # Assert that the refresh callback has been called. - self.assertEqual(self.refresh_called, 1) + # Assert that the request callback has been called twice. + self.assertEqual(self.request_called, 2) client.close() def test_reauthenticate_succeeds_bulk_read(self): request_cb = self.create_request_cb() - refresh_cb = self.create_refresh_cb() - # Create a client with the callbacks. - props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} + # Create a client with the callback. + props: Dict = {"request_token_callback": request_cb} client = MongoClient(self.uri_single, authmechanismproperties=props) # Perform a find operation. @@ -579,8 +385,8 @@ class TestAuthOIDC(unittest.TestCase): # Perform a bulk write operation. client.test.test.bulk_write([InsertOne({})]) - # Assert that the refresh callback has not been called. - self.assertEqual(self.refresh_called, 0) + # Assert that the request callback has been called once. + self.assertEqual(self.request_called, 1) with self.fail_point( { @@ -592,23 +398,22 @@ class TestAuthOIDC(unittest.TestCase): cursor = client.test.test.find_raw_batches({}) list(cursor) - # Assert that the refresh callback has been called. - self.assertEqual(self.refresh_called, 1) + # Assert that the request callback has been called twice. + self.assertEqual(self.request_called, 2) client.close() def test_reauthenticate_succeeds_cursor(self): request_cb = self.create_request_cb() - refresh_cb = self.create_refresh_cb() - # Create a client with the callbacks. - props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} + # Create a client with the callback. + props: Dict = {"request_token_callback": request_cb} client = MongoClient(self.uri_single, authmechanismproperties=props) # Perform an insert operation. client.test.test.insert_one({"a": 1}) - # Assert that the refresh callback has not been called. - self.assertEqual(self.refresh_called, 0) + # Assert that the request callback has been called once. + self.assertEqual(self.request_called, 1) with self.fail_point( { @@ -620,23 +425,22 @@ class TestAuthOIDC(unittest.TestCase): cursor = client.test.test.find({"a": 1}) self.assertGreaterEqual(len(list(cursor)), 1) - # Assert that the refresh callback has been called. - self.assertEqual(self.refresh_called, 1) + # Assert that the request callback has been called twice. + self.assertEqual(self.request_called, 2) client.close() def test_reauthenticate_succeeds_get_more(self): request_cb = self.create_request_cb() - refresh_cb = self.create_refresh_cb() - # Create a client with the callbacks. - props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} + # Create a client with the callback. + props: Dict = {"request_token_callback": request_cb} client = MongoClient(self.uri_single, authmechanismproperties=props) # Perform an insert operation. client.test.test.insert_many([{"a": 1}, {"a": 1}]) - # Assert that the refresh callback has not been called. - self.assertEqual(self.refresh_called, 0) + # Assert that the request callback has been called once. + self.assertEqual(self.request_called, 1) with self.fail_point( { @@ -648,30 +452,29 @@ class TestAuthOIDC(unittest.TestCase): cursor = client.test.test.find({"a": 1}, batch_size=1) self.assertGreaterEqual(len(list(cursor)), 1) - # Assert that the refresh callback has been called. - self.assertEqual(self.refresh_called, 1) + # Assert that the request callback has been called twice. + self.assertEqual(self.request_called, 2) client.close() def test_reauthenticate_succeeds_get_more_exhaust(self): # Ensure no mongos - props = {"PROVIDER_NAME": "aws"} + props = {"request_token_callback": self.create_request_cb()} client = MongoClient(self.uri_single, authmechanismproperties=props) hello = client.admin.command(HelloCompat.LEGACY_CMD) if hello.get("msg") != "isdbgrid": raise unittest.SkipTest("Must not be a mongos") request_cb = self.create_request_cb() - refresh_cb = self.create_refresh_cb() - # Create a client with the callbacks. - props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} + # Create a client with the callback. + props: Dict = {"request_token_callback": request_cb} client = MongoClient(self.uri_single, authmechanismproperties=props) # Perform an insert operation. client.test.test.insert_many([{"a": 1}, {"a": 1}]) - # Assert that the refresh callback has not been called. - self.assertEqual(self.refresh_called, 0) + # Assert that the request callback has been called once. + self.assertEqual(self.request_called, 1) with self.fail_point( { @@ -683,16 +486,15 @@ class TestAuthOIDC(unittest.TestCase): cursor = client.test.test.find({"a": 1}, batch_size=1, cursor_type=CursorType.EXHAUST) self.assertGreaterEqual(len(list(cursor)), 1) - # Assert that the refresh callback has been called. - self.assertEqual(self.refresh_called, 1) + # Assert that the request callback has been called twice. + self.assertEqual(self.request_called, 2) client.close() def test_reauthenticate_succeeds_command(self): request_cb = self.create_request_cb() - refresh_cb = self.create_refresh_cb() - # Create a client with the callbacks. - props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} + # Create a client with the callback. + props: Dict = {"request_token_callback": request_cb} print("start of test") client = MongoClient(self.uri_single, authmechanismproperties=props) @@ -700,8 +502,8 @@ class TestAuthOIDC(unittest.TestCase): # Perform an insert operation. client.test.test.insert_one({"a": 1}) - # Assert that the refresh callback has not been called. - self.assertEqual(self.refresh_called, 0) + # Assert that the request callback has been called once. + self.assertEqual(self.request_called, 1) with self.fail_point( { @@ -714,112 +516,54 @@ class TestAuthOIDC(unittest.TestCase): self.assertGreaterEqual(len(list(cursor)), 1) - # Assert that the refresh callback has been called. - self.assertEqual(self.refresh_called, 1) + # Assert that the request callback has been called twice. + self.assertEqual(self.request_called, 2) client.close() - def test_reauthenticate_retries_and_succeeds_with_cache(self): - listener = EventListener() - - # Create request and refresh callbacks that return valid credentials - # that will not expire soon. + def test_reauthentication_succeeds_multiple_connections(self): request_cb = self.create_request_cb() - refresh_cb = self.create_refresh_cb() - # Create a client with the callbacks. - props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} - client = MongoClient( - self.uri_single, event_listeners=[listener], authmechanismproperties=props + # Create a client with the callback. + props: Dict = {"request_token_callback": request_cb} + + client1 = MongoClient(self.uri_single, authmechanismproperties=props) + client2 = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform an insert operation. + client1.test.test.insert_many([{"a": 1}, {"a": 1}]) + client2.test.test.find_one() + self.assertEqual(self.request_called, 2) + + # Use the same authenticator for both clients + # to simulate a race condition with separate connections. + # We should only see one extra callback despite both connections + # needing to reauthenticate. + client2.options.pool_options._credentials.cache.data = ( + client1.options.pool_options._credentials.cache.data ) - # Perform a find operation. - client.test.test.find_one() - - # Set a fail point for ``saslStart`` commands of the form - with self.fail_point( - { - "mode": {"times": 2}, - "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, - } - ): - # Perform a find operation that succeeds. - client.test.test.find_one() - - # Close the client. - client.close() - - def test_reauthenticate_fails_with_no_cache(self): - listener = EventListener() - - # Create request and refresh callbacks that return valid credentials - # that will not expire soon. - request_cb = self.create_request_cb() - refresh_cb = self.create_refresh_cb() - - # Create a client with the callbacks. - props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} - client = MongoClient( - self.uri_single, event_listeners=[listener], authmechanismproperties=props - ) - - # Perform a find operation. - client.test.test.find_one() - - # Clear the cache. - _oidc_cache.clear() - - with self.fail_point( - { - "mode": {"times": 2}, - "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, - } - ): - # Perform a find operation that fails. - with self.assertRaises(OperationFailure): - client.test.test.find_one() - - client.close() - - def test_late_reauth_avoids_callback(self): - # Step 1: connect with both clients - request_cb = self.create_request_cb(expires_in_seconds=1e6) - refresh_cb = self.create_refresh_cb(expires_in_seconds=1e6) - - props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} - client1 = MongoClient(self.uri_single, authMechanismProperties=props) client1.test.test.find_one() - client2 = MongoClient(self.uri_single, authMechanismProperties=props) client2.test.test.find_one() - self.assertEqual(self.refresh_called, 0) - self.assertEqual(self.request_called, 1) - - # Step 2: cause a find 391 on the first client with self.fail_point( { "mode": {"times": 1}, "data": {"failCommands": ["find"], "errorCode": 391}, } ): - # Perform a find operation that succeeds. client1.test.test.find_one() - self.assertEqual(self.refresh_called, 1) - self.assertEqual(self.request_called, 1) + self.assertEqual(self.request_called, 3) - # Step 3: cause a find 391 on the second client with self.fail_point( { "mode": {"times": 1}, "data": {"failCommands": ["find"], "errorCode": 391}, } ): - # Perform a find operation that succeeds. client2.test.test.find_one() - self.assertEqual(self.refresh_called, 1) - self.assertEqual(self.request_called, 1) - + self.assertEqual(self.request_called, 3) client1.close() client2.close() diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index ebcc4eeb7..6a8ec35ec 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -48,9 +48,6 @@ def create_test(test_case): if props.get("REQUEST_TOKEN_CALLBACK"): props["request_token_callback"] = lambda x, y: 1 del props["REQUEST_TOKEN_CALLBACK"] - if props.get("REFRESH_TOKEN_CALLBACK"): - props["refresh_token_callback"] = lambda a, b: 1 - del props["REFRESH_TOKEN_CALLBACK"] client = MongoClient(uri, connect=False, authmechanismproperties=props) credentials = client.options.pool_options._credentials if credential is None: @@ -86,10 +83,6 @@ def create_test(test_case): self.assertEqual( actual.request_token_callback, expected["request_token_callback"] ) - elif "refresh_token_callback" in expected: - self.assertEqual( - actual.refresh_token_callback, expected["refresh_token_callback"] - ) else: self.fail(f"Unhandled property: {key}") else: