PYTHON-3716 OIDC-SASL Follow-Up (#1365)
This commit is contained in:
parent
9b6f2e18cf
commit
0590ce49ca
@ -79,6 +79,17 @@ functions:
|
||||
export PATH="$MONGODB_BINARIES:$PATH"
|
||||
export PROJECT="${project}"
|
||||
export PIP_QUIET=1
|
||||
ENSURE_TOOLCHAIN_PYTHON_BINARY: |
|
||||
# Make sure PYTHON_BINARY is set to a suitable toolchain python.
|
||||
if [ -z "$PYTHON_BINARY" ]; then
|
||||
if [ "$(uname -s)" = "Darwin" ]; then
|
||||
export PYTHON_BINARY=/Library/Frameworks/Python.Framework/Versions/3.9/bin/python3
|
||||
elif [ "Windows_NT" = "$OS" ]; then # Magic variable in cygwin
|
||||
export PYTHON_BINARY=/cygdrive/c/python/Python39/python
|
||||
else
|
||||
export PYTHON_BINARY=/opt/python/3.9/bin/python3
|
||||
fi
|
||||
fi
|
||||
EOT
|
||||
|
||||
# Load the expansion file to make an evergreen variable with the current unique version
|
||||
@ -655,7 +666,7 @@ functions:
|
||||
.evergreen/run-mongodb-aws-test.sh
|
||||
fi
|
||||
|
||||
"bootstrap oidc":
|
||||
"run oidc auth test with aws credentials":
|
||||
- command: ec2.assume_role
|
||||
params:
|
||||
role_arn: ${aws_test_secrets_role}
|
||||
@ -664,58 +675,11 @@ functions:
|
||||
params:
|
||||
working_dir: "src"
|
||||
shell: bash
|
||||
include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
|
||||
script: |
|
||||
${PREPARE_SHELL}
|
||||
if [ "${skip_EC2_auth_test}" = "true" ]; then
|
||||
echo "This platform does not support the oidc auth test, skipping..."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc
|
||||
export AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
|
||||
export AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
|
||||
export AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN}
|
||||
export OIDC_TOKEN_DIR=/tmp/tokens
|
||||
|
||||
. ./activate-authoidcvenv.sh
|
||||
python oidc_write_orchestration.py
|
||||
python oidc_get_tokens.py
|
||||
|
||||
"run oidc auth test with aws credentials":
|
||||
- command: shell.exec
|
||||
type: test
|
||||
params:
|
||||
working_dir: "src"
|
||||
shell: bash
|
||||
script: |
|
||||
${PREPARE_SHELL}
|
||||
if [ "${skip_EC2_auth_test}" = "true" ]; then
|
||||
echo "This platform does not support the oidc auth test, skipping..."
|
||||
exit 0
|
||||
fi
|
||||
cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc
|
||||
mongosh setup_oidc.js
|
||||
- command: shell.exec
|
||||
type: test
|
||||
params:
|
||||
working_dir: "src"
|
||||
silent: true
|
||||
script: |
|
||||
# DO NOT ECHO WITH XTRACE (which PREPARE_SHELL does)
|
||||
cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh"
|
||||
export OIDC_TOKEN_DIR=/tmp/tokens
|
||||
EOF
|
||||
- command: shell.exec
|
||||
type: test
|
||||
params:
|
||||
working_dir: "src"
|
||||
script: |
|
||||
${PREPARE_SHELL}
|
||||
if [ "${skip_web_identity_auth_test}" = "true" ]; then
|
||||
echo "This platform does not support the oidc auth test, skipping..."
|
||||
exit 0
|
||||
fi
|
||||
PYTHON_BINARY=${PYTHON_BINARY} ASSERT_NO_URI_CREDS=true .evergreen/run-mongodb-oidc-test.sh
|
||||
${ENSURE_TOOLCHAIN_PYTHON_BINARY}
|
||||
bash .evergreen/run-mongodb-oidc-test.sh
|
||||
|
||||
"run aws auth test with aws credentials as environment variables":
|
||||
- command: shell.exec
|
||||
@ -2133,16 +2097,7 @@ tasks:
|
||||
|
||||
- name: "oidc-auth-test-latest"
|
||||
commands:
|
||||
- func: "bootstrap oidc"
|
||||
- func: "bootstrap mongo-orchestration"
|
||||
vars:
|
||||
AUTH: "auth"
|
||||
ORCHESTRATION_FILE: "auth-oidc.json"
|
||||
TOPOLOGY: "replica_set"
|
||||
VERSION: "latest"
|
||||
- func: "run oidc auth test with aws credentials"
|
||||
vars:
|
||||
AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test1
|
||||
|
||||
- name: load-balancer-test
|
||||
commands:
|
||||
@ -3192,9 +3147,8 @@ buildvariants:
|
||||
|
||||
- matrix_name: "oidc-auth-test"
|
||||
matrix_spec:
|
||||
platform: [ rhel8 ]
|
||||
python-version: ["3.9"]
|
||||
display_name: "MONGODB-OIDC Auth ${platform} ${python-version}"
|
||||
platform: [ rhel8, macos-1100, windows-64-vsMulti-small ]
|
||||
display_name: "MONGODB-OIDC Auth ${platform}"
|
||||
tasks:
|
||||
- name: "oidc-auth-test-latest"
|
||||
|
||||
|
||||
@ -1,56 +1,48 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -o xtrace
|
||||
set +x # Disable debug trace
|
||||
set -o errexit # Exit the script with error if any of the commands fail
|
||||
|
||||
############################################
|
||||
# Main Program #
|
||||
############################################
|
||||
|
||||
# Supported/used environment variables:
|
||||
# MONGODB_URI Set the URI, including an optional username/password to use
|
||||
# to connect to the server via MONGODB-OIDC authentication
|
||||
# mechanism.
|
||||
# PYTHON_BINARY The Python version to use.
|
||||
|
||||
echo "Running MONGODB-OIDC authentication tests"
|
||||
# ensure no secrets are printed in log files
|
||||
set +x
|
||||
|
||||
# load the script
|
||||
shopt -s expand_aliases # needed for `urlencode` alias
|
||||
[ -s "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh"
|
||||
|
||||
MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"}
|
||||
MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC"
|
||||
MONGODB_URI_MULTIPLE="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true"
|
||||
|
||||
if [ -z "${OIDC_TOKEN_DIR}" ]; then
|
||||
echo "Must specify OIDC_TOKEN_DIR"
|
||||
# Make sure DRIVERS_TOOLS is set.
|
||||
if [ -z "$DRIVERS_TOOLS" ]; then
|
||||
echo "Must specify DRIVERS_TOOLS"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export MONGODB_URI_SINGLE="$MONGODB_URI_SINGLE"
|
||||
export MONGODB_URI_MULTIPLE="$MONGODB_URI_MULTIPLE"
|
||||
export MONGODB_URI="$MONGODB_URI"
|
||||
# Get the drivers secrets. Use an existing secrets file first.
|
||||
if [ ! -f "./secrets-export.sh" ]; then
|
||||
bash .evergreen/tox.sh -m aws-secrets -- drivers/oidc
|
||||
fi
|
||||
source ./secrets-export.sh
|
||||
|
||||
echo $MONGODB_URI_SINGLE
|
||||
echo $MONGODB_URI_MULTIPLE
|
||||
echo $MONGODB_URI
|
||||
|
||||
if [ "$ASSERT_NO_URI_CREDS" = "true" ]; then
|
||||
if echo "$MONGODB_URI" | grep -q "@"; then
|
||||
echo "MONGODB_URI unexpectedly contains user credentials!";
|
||||
exit 1
|
||||
fi
|
||||
# # If the file did not have our creds, get them from the vault.
|
||||
if [ -z "$OIDC_ATLAS_URI_SINGLE" ]; then
|
||||
bash .evergreen/tox.sh -m aws-secrets -- drivers/oidc
|
||||
source ./secrets-export.sh
|
||||
fi
|
||||
|
||||
if [ -z "$PYTHON_BINARY" ]; then
|
||||
echo "Cannot test without specifying PYTHON_BINARY"
|
||||
exit 1
|
||||
# Make the OIDC tokens.
|
||||
set -x
|
||||
pushd ${DRIVERS_TOOLS}/.evergreen/auth_oidc
|
||||
. ./oidc_get_tokens.sh
|
||||
popd
|
||||
|
||||
# Set up variables and run the test.
|
||||
if [ -n "$LOCAL_OIDC_SERVER" ]; then
|
||||
export MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"}
|
||||
export MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC"
|
||||
export MONGODB_URI_MULTI="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true"
|
||||
else
|
||||
set +x # turn off xtrace for this portion
|
||||
export MONGODB_URI="$OIDC_ATLAS_URI_SINGLE"
|
||||
export MONGODB_URI_SINGLE="$OIDC_ATLAS_URI_SINGLE/?authMechanism=MONGODB-OIDC"
|
||||
export MONGODB_URI_MULTI="$OIDC_ATLAS_URI_MULTI/?authMechanism=MONGODB-OIDC"
|
||||
set -x
|
||||
fi
|
||||
|
||||
export TEST_AUTH_OIDC=1
|
||||
export COVERAGE=1
|
||||
export AUTH="auth"
|
||||
export SET_XTRACE_ON=1
|
||||
bash ./.evergreen/tox.sh -m test-eg
|
||||
|
||||
@ -49,6 +49,9 @@ if [ "$AUTH" != "noauth" ]; then
|
||||
elif [ ! -z "$TEST_SERVERLESS" ]; then
|
||||
export DB_USER=$SERVERLESS_ATLAS_USER
|
||||
export DB_PASSWORD=$SERVERLESS_ATLAS_PASSWORD
|
||||
elif [ ! -z "$TEST_AUTH_OIDC" ]; then
|
||||
export DB_USER=$OIDC_ALTAS_USER
|
||||
export DB_PASSWORD=$OIDC_ATLAS_PASSWORD
|
||||
else
|
||||
export DB_USER="bob"
|
||||
export DB_PASSWORD="pwd123"
|
||||
@ -109,7 +112,7 @@ fi
|
||||
if [ -n "$TEST_ENCRYPTION" ] || [ -n "$TEST_FLE_AZURE_AUTO" ] || [ -n "$TEST_FLE_GCP_AUTO" ]; then
|
||||
|
||||
# Work around for root certifi not being installed.
|
||||
# TODO: Remove after PYTHON-3827
|
||||
# TODO: Remove after PYTHON-3952 is deployed.
|
||||
if [ "$(uname -s)" = "Darwin" ]; then
|
||||
python -m pip install certifi
|
||||
CERT_PATH=$(python -c "import certifi; print(certifi.where())")
|
||||
@ -224,6 +227,17 @@ fi
|
||||
|
||||
if [ -n "$TEST_AUTH_OIDC" ]; then
|
||||
python -m pip install ".[aws]"
|
||||
|
||||
# Work around for root certifi not being installed.
|
||||
# TODO: Remove after PYTHON-3952 is deployed.
|
||||
if [ "$(uname -s)" = "Darwin" ]; then
|
||||
python -m pip install certifi
|
||||
CERT_PATH=$(python -c "import certifi; print(certifi.where())")
|
||||
export SSL_CERT_FILE=${CERT_PATH}
|
||||
export REQUESTS_CA_BUNDLE=${CERT_PATH}
|
||||
export AWS_CA_BUNDLE=${CERT_PATH}
|
||||
fi
|
||||
|
||||
TEST_ARGS="test/auth_oidc/test_auth_oidc.py"
|
||||
fi
|
||||
|
||||
|
||||
@ -164,11 +164,11 @@ def _build_credentials_tuple(
|
||||
elif mech == "MONGODB-OIDC":
|
||||
properties = extra.get("authmechanismproperties", {})
|
||||
request_token_callback = properties.get("request_token_callback")
|
||||
refresh_token_callback = properties.get("refresh_token_callback", None)
|
||||
provider_name = properties.get("PROVIDER_NAME", "")
|
||||
default_allowed = [
|
||||
"*.mongodb.net",
|
||||
"*.mongodb-dev.net",
|
||||
"*.mongodb-qa.net",
|
||||
"*.mongodbgov.net",
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
@ -181,11 +181,10 @@ def _build_credentials_tuple(
|
||||
)
|
||||
oidc_props = _OIDCProperties(
|
||||
request_token_callback=request_token_callback,
|
||||
refresh_token_callback=refresh_token_callback,
|
||||
provider_name=provider_name,
|
||||
allowed_hosts=allowed_hosts,
|
||||
)
|
||||
return MongoCredential(mech, "$external", user, passwd, oidc_props, None)
|
||||
return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())
|
||||
|
||||
elif mech == "PLAIN":
|
||||
source_database = source or database or "$external"
|
||||
|
||||
@ -15,27 +15,14 @@
|
||||
"""MONGODB-OIDC Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Optional
|
||||
|
||||
import bson
|
||||
from bson.binary import Binary
|
||||
from bson.son import SON
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.auth import MongoCredential
|
||||
@ -44,39 +31,27 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclass
|
||||
class _OIDCProperties:
|
||||
request_token_callback: Optional[Callable[..., Dict]]
|
||||
refresh_token_callback: Optional[Callable[..., Dict]]
|
||||
request_token_callback: Optional[Callable[..., dict]]
|
||||
provider_name: Optional[str]
|
||||
allowed_hosts: List[str]
|
||||
allowed_hosts: list[str]
|
||||
|
||||
|
||||
"""Mechanism properties for MONGODB-OIDC authentication."""
|
||||
|
||||
TOKEN_BUFFER_MINUTES = 5
|
||||
CALLBACK_TIMEOUT_SECONDS = 5 * 60
|
||||
CACHE_TIMEOUT_MINUTES = 60 * 5
|
||||
CALLBACK_VERSION = 0
|
||||
|
||||
_CACHE: Dict[str, "_OIDCAuthenticator"] = {}
|
||||
CALLBACK_VERSION = 1
|
||||
|
||||
|
||||
def _get_authenticator(
|
||||
credentials: MongoCredential, address: Tuple[str, int]
|
||||
credentials: MongoCredential, address: tuple[str, int]
|
||||
) -> _OIDCAuthenticator:
|
||||
# Clear out old items in the cache.
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
to_remove = []
|
||||
for key, value in _CACHE.items():
|
||||
if value.cache_exp_utc is not None and value.cache_exp_utc < now_utc:
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
del _CACHE[key]
|
||||
if credentials.cache.data:
|
||||
return credentials.cache.data
|
||||
|
||||
# Extract values.
|
||||
principal_name = credentials.username
|
||||
properties = credentials.mechanism_properties
|
||||
request_cb = properties.request_token_callback
|
||||
refresh_cb = properties.refresh_token_callback
|
||||
|
||||
# Validate that the address is allowed.
|
||||
if not properties.provider_name:
|
||||
@ -92,150 +67,99 @@ def _get_authenticator(
|
||||
f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
|
||||
)
|
||||
|
||||
# Get or create the cache item.
|
||||
cache_key = f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}"
|
||||
_CACHE.setdefault(cache_key, _OIDCAuthenticator(username=principal_name, properties=properties))
|
||||
|
||||
return _CACHE[cache_key]
|
||||
|
||||
|
||||
def _get_cache_exp() -> datetime:
|
||||
return datetime.now(timezone.utc) + timedelta(minutes=CACHE_TIMEOUT_MINUTES)
|
||||
# Get or create the cache data.
|
||||
credentials.cache.data = _OIDCAuthenticator(username=principal_name, properties=properties)
|
||||
return credentials.cache.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OIDCAuthenticator:
|
||||
username: str
|
||||
properties: _OIDCProperties
|
||||
idp_info: Optional[Dict] = field(default=None)
|
||||
idp_resp: Optional[Dict] = field(default=None)
|
||||
reauth_gen_id: int = field(default=0)
|
||||
idp_info_gen_id: int = field(default=0)
|
||||
refresh_token: Optional[str] = field(default=None)
|
||||
access_token: Optional[str] = field(default=None)
|
||||
idp_info: Optional[dict] = field(default=None)
|
||||
token_gen_id: int = field(default=0)
|
||||
token_exp_utc: Optional[datetime] = field(default=None)
|
||||
cache_exp_utc: datetime = field(default_factory=_get_cache_exp)
|
||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
|
||||
def get_current_token(self, use_callbacks: bool = True) -> Optional[str]:
|
||||
def get_current_token(self, use_callback: bool = True) -> Optional[str]:
|
||||
properties = self.properties
|
||||
|
||||
request_cb = properties.request_token_callback
|
||||
refresh_cb = properties.refresh_token_callback
|
||||
if not use_callbacks:
|
||||
request_cb = None
|
||||
refresh_cb = None
|
||||
# TODO: DRIVERS-2672, handle machine callback here as well.
|
||||
cb = properties.request_token_callback if use_callback else None
|
||||
cb_type = "human"
|
||||
|
||||
current_valid_token = False
|
||||
if self.token_exp_utc is not None:
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
exp_utc = self.token_exp_utc
|
||||
buffer_seconds = TOKEN_BUFFER_MINUTES * 60
|
||||
if (exp_utc - now_utc).total_seconds() >= buffer_seconds:
|
||||
current_valid_token = True
|
||||
prev_token = self.access_token
|
||||
if prev_token:
|
||||
return prev_token
|
||||
|
||||
timeout = CALLBACK_TIMEOUT_SECONDS
|
||||
if not use_callbacks and not current_valid_token:
|
||||
if not use_callback and not prev_token:
|
||||
return None
|
||||
|
||||
if not current_valid_token and request_cb is not None:
|
||||
prev_token = self.idp_resp["access_token"] if self.idp_resp else None
|
||||
if not prev_token and cb is not None:
|
||||
with self.lock:
|
||||
# See if the token was changed while we were waiting for the
|
||||
# lock.
|
||||
new_token = self.idp_resp["access_token"] if self.idp_resp else None
|
||||
new_token = self.access_token
|
||||
if new_token != prev_token:
|
||||
return new_token
|
||||
|
||||
refresh_token = self.idp_resp and self.idp_resp.get("refresh_token")
|
||||
refresh_token = refresh_token or ""
|
||||
context = {
|
||||
"timeout_seconds": timeout,
|
||||
"version": CALLBACK_VERSION,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
# TODO: DRIVERS-2672 handle machine callback here.
|
||||
if cb_type == "human":
|
||||
context = {
|
||||
"timeout_seconds": CALLBACK_TIMEOUT_SECONDS,
|
||||
"version": CALLBACK_VERSION,
|
||||
"refresh_token": self.refresh_token,
|
||||
}
|
||||
resp = cb(self.idp_info, context)
|
||||
|
||||
self.validate_request_token_response(resp)
|
||||
|
||||
if self.idp_resp is None or refresh_cb is None:
|
||||
self.idp_resp = request_cb(self.idp_info, context)
|
||||
elif request_cb is not None:
|
||||
self.idp_resp = refresh_cb(self.idp_info, context)
|
||||
cache_exp_utc = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=CACHE_TIMEOUT_MINUTES
|
||||
)
|
||||
self.cache_exp_utc = cache_exp_utc
|
||||
self.token_gen_id += 1
|
||||
|
||||
token_result = self.idp_resp
|
||||
return self.access_token
|
||||
|
||||
def validate_request_token_response(self, resp: Mapping[str, Any]) -> None:
|
||||
# Validate callback return value.
|
||||
if not isinstance(token_result, dict):
|
||||
if not isinstance(resp, dict):
|
||||
raise ValueError("OIDC callback returned invalid result")
|
||||
|
||||
if "access_token" not in token_result:
|
||||
if "access_token" not in resp:
|
||||
raise ValueError("OIDC callback did not return an access_token")
|
||||
|
||||
expected = ["access_token", "expires_in_seconds", "refesh_token"]
|
||||
for key in token_result:
|
||||
expected = ["access_token", "refresh_token", "expires_in_seconds"]
|
||||
for key in resp:
|
||||
if key not in expected:
|
||||
raise ValueError(f'Unexpected field in callback result "{key}"')
|
||||
|
||||
token = token_result["access_token"]
|
||||
self.access_token = resp["access_token"]
|
||||
self.refresh_token = resp.get("refresh_token")
|
||||
|
||||
if "expires_in_seconds" in token_result:
|
||||
expires_in = int(token_result["expires_in_seconds"])
|
||||
buffer_seconds = TOKEN_BUFFER_MINUTES * 60
|
||||
if expires_in >= buffer_seconds:
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
exp_utc = now_utc + timedelta(seconds=expires_in)
|
||||
self.token_exp_utc = exp_utc
|
||||
|
||||
return token
|
||||
|
||||
def auth_start_cmd(self, use_callbacks: bool = True) -> Optional[SON[str, Any]]:
|
||||
properties = self.properties
|
||||
|
||||
# Handle aws provider credentials.
|
||||
if properties.provider_name == "aws":
|
||||
aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"]
|
||||
with open(aws_identity_file) as fid:
|
||||
token: Optional[str] = fid.read().strip()
|
||||
payload = {"jwt": token}
|
||||
cmd = SON(
|
||||
[
|
||||
("saslStart", 1),
|
||||
("mechanism", "MONGODB-OIDC"),
|
||||
("payload", Binary(bson.encode(payload))),
|
||||
]
|
||||
)
|
||||
return cmd
|
||||
def principal_step_cmd(self) -> SON[str, Any]:
|
||||
"""Get a SASL start command with an optional principal name"""
|
||||
# Send the SASL start with the optional principal name.
|
||||
payload = {}
|
||||
|
||||
principal_name = self.username
|
||||
if principal_name:
|
||||
payload["n"] = principal_name
|
||||
|
||||
if self.idp_info is not None:
|
||||
self.cache_exp_utc = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=CACHE_TIMEOUT_MINUTES
|
||||
)
|
||||
cmd = SON(
|
||||
[
|
||||
("saslStart", 1),
|
||||
("mechanism", "MONGODB-OIDC"),
|
||||
("payload", Binary(bson.encode(payload))),
|
||||
("autoAuthorize", 1),
|
||||
]
|
||||
)
|
||||
return cmd
|
||||
|
||||
def auth_start_cmd(self, use_callback: bool = True) -> Optional[SON[str, Any]]:
|
||||
# TODO: DRIVERS-2672, check for provider_name in self.properties here.
|
||||
if self.idp_info is None:
|
||||
self.cache_exp_utc = _get_cache_exp()
|
||||
return self.principal_step_cmd()
|
||||
|
||||
if self.idp_info is None:
|
||||
# Send the SASL start with the optional principal name.
|
||||
payload = {}
|
||||
|
||||
if principal_name:
|
||||
payload["n"] = principal_name
|
||||
|
||||
cmd = SON(
|
||||
[
|
||||
("saslStart", 1),
|
||||
("mechanism", "MONGODB-OIDC"),
|
||||
("payload", Binary(bson.encode(payload))),
|
||||
("autoAuthorize", 1),
|
||||
]
|
||||
)
|
||||
return cmd
|
||||
|
||||
token = self.get_current_token(use_callbacks)
|
||||
token = self.get_current_token(use_callback)
|
||||
if not token:
|
||||
return None
|
||||
bin_payload = Binary(bson.encode({"jwt": token}))
|
||||
@ -247,37 +171,59 @@ class _OIDCAuthenticator:
|
||||
]
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.idp_info = None
|
||||
self.idp_resp = None
|
||||
self.token_exp_utc = None
|
||||
|
||||
def run_command(
|
||||
self, conn: Connection, cmd: MutableMapping[str, Any]
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
try:
|
||||
return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
|
||||
except OperationFailure as exc:
|
||||
self.clear()
|
||||
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
|
||||
if "jwt" in bson.decode(cmd["payload"]):
|
||||
if self.idp_info_gen_id > self.reauth_gen_id:
|
||||
raise
|
||||
return self.authenticate(conn, reauthenticate=True)
|
||||
except OperationFailure:
|
||||
self.access_token = None
|
||||
raise
|
||||
|
||||
def authenticate(
|
||||
self, conn: Connection, reauthenticate: bool = False
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
if reauthenticate:
|
||||
prev_id = getattr(conn, "oidc_token_gen_id", None)
|
||||
# Check if we've already changed tokens.
|
||||
if prev_id == self.token_gen_id:
|
||||
self.reauth_gen_id = self.idp_info_gen_id
|
||||
self.token_exp_utc = None
|
||||
if not self.properties.refresh_token_callback:
|
||||
self.clear()
|
||||
def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
|
||||
"""Handle a reauthenticate from the server."""
|
||||
# First see if we have the a newer token on the authenticator.
|
||||
prev_id = conn.oidc_token_gen_id or 0
|
||||
# If we've already changed tokens, make one optimistic attempt.
|
||||
if (prev_id < self.token_gen_id) and self.access_token:
|
||||
try:
|
||||
return self.authenticate(conn)
|
||||
except OperationFailure:
|
||||
pass
|
||||
|
||||
self.access_token = None
|
||||
|
||||
# TODO: DRIVERS-2672, check for provider_name in self.properties here.
|
||||
# If so, we clear the access token and return finish_auth.
|
||||
|
||||
# Next see if the idp info has changed.
|
||||
prev_idp_info = self.idp_info
|
||||
self.idp_info = None
|
||||
cmd = self.principal_step_cmd()
|
||||
resp = self.run_command(conn, cmd)
|
||||
assert resp is not None
|
||||
server_resp: dict = bson.decode(resp["payload"])
|
||||
if "issuer" in server_resp:
|
||||
self.idp_info = server_resp
|
||||
|
||||
# Handle the case of changed idp info.
|
||||
if not self.idp_info == prev_idp_info:
|
||||
self.access_token = None
|
||||
self.refresh_token = None
|
||||
|
||||
# If we have a refresh token, try using that.
|
||||
if self.refresh_token:
|
||||
try:
|
||||
return self.finish_auth(resp, conn)
|
||||
except OperationFailure:
|
||||
self.refresh_token = None
|
||||
# If that fails, try again without the refresh token.
|
||||
return self.authenticate(conn)
|
||||
|
||||
# If we don't have a refresh token, just try once.
|
||||
return self.finish_auth(resp, conn)
|
||||
|
||||
def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
|
||||
ctx = conn.auth_ctx
|
||||
cmd = None
|
||||
|
||||
@ -293,12 +239,16 @@ class _OIDCAuthenticator:
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
return None
|
||||
|
||||
server_resp: Dict = bson.decode(resp["payload"])
|
||||
server_resp: dict = bson.decode(resp["payload"])
|
||||
if "issuer" in server_resp:
|
||||
self.idp_info = server_resp
|
||||
self.idp_info_gen_id += 1
|
||||
|
||||
conversation_id = resp["conversationId"]
|
||||
return self.finish_auth(resp, conn)
|
||||
|
||||
def finish_auth(
|
||||
self, orig_resp: Mapping[str, Any], conn: Connection
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
conversation_id = orig_resp["conversationId"]
|
||||
token = self.get_current_token()
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
bin_payload = Binary(bson.encode({"jwt": token}))
|
||||
@ -312,7 +262,6 @@ class _OIDCAuthenticator:
|
||||
resp = self.run_command(conn, cmd)
|
||||
assert resp is not None
|
||||
if not resp["done"]:
|
||||
self.clear()
|
||||
raise OperationFailure("SASL conversation failed to complete.")
|
||||
return resp
|
||||
|
||||
@ -322,4 +271,7 @@ def _authenticate_oidc(
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""Authenticate using MONGODB-OIDC."""
|
||||
authenticator = _get_authenticator(credentials, conn.address)
|
||||
return authenticator.authenticate(conn, reauthenticate=reauthenticate)
|
||||
if reauthenticate:
|
||||
return authenticator.reauthenticate(conn)
|
||||
else:
|
||||
return authenticator.authenticate(conn)
|
||||
|
||||
@ -440,8 +440,6 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
|
||||
signature = inspect.signature(value)
|
||||
if key == "request_token_callback":
|
||||
expected_params = 2
|
||||
elif key == "refresh_token_callback":
|
||||
expected_params = 2
|
||||
else:
|
||||
raise ValueError(f"Unrecognized Auth mechanism function {key}")
|
||||
if len(signature.parameters) != expected_params:
|
||||
|
||||
@ -475,22 +475,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "should recognise the mechanism with request and refresh callback (MONGODB-OIDC)",
|
||||
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
|
||||
"callback": ["oidcRequest", "oidcRefresh"],
|
||||
"valid": true,
|
||||
"credential": {
|
||||
"username": null,
|
||||
"password": null,
|
||||
"source": "$external",
|
||||
"mechanism": "MONGODB-OIDC",
|
||||
"mechanism_properties": {
|
||||
"REQUEST_TOKEN_CALLBACK": true,
|
||||
"REFRESH_TOKEN_CALLBACK": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "should recognise the mechanism and username with request callback (MONGODB-OIDC)",
|
||||
"uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC",
|
||||
@ -554,18 +538,11 @@
|
||||
"credential": null
|
||||
},
|
||||
{
|
||||
"description": "should throw an exception if neither deviceName nor callbacks specified (MONGODB-OIDC)",
|
||||
"description": "should throw an exception if neither deviceName nor callback specified (MONGODB-OIDC)",
|
||||
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
|
||||
"valid": false,
|
||||
"credential": null
|
||||
},
|
||||
{
|
||||
"description": "should throw an exception when only refresh callback is specified (MONGODB-OIDC)",
|
||||
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
|
||||
"callback": ["oidcRefresh"],
|
||||
"valid": false,
|
||||
"credential": null
|
||||
},
|
||||
{
|
||||
"description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)",
|
||||
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted",
|
||||
@ -573,4 +550,4 @@
|
||||
"credential": null
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
@ -29,7 +28,6 @@ from test.utils import EventListener
|
||||
from bson import SON
|
||||
from pymongo import MongoClient
|
||||
from pymongo.auth import _AUTH_MAP, _authenticate_oidc
|
||||
from pymongo.auth_oidc import _CACHE as _oidc_cache
|
||||
from pymongo.cursor import CursorType
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.hello import HelloCompat
|
||||
@ -45,19 +43,16 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.uri_single = os.environ["MONGODB_URI_SINGLE"]
|
||||
cls.uri_multiple = os.environ["MONGODB_URI_MULTIPLE"]
|
||||
cls.uri_multiple = os.environ["MONGODB_URI_MULTI"]
|
||||
cls.uri_admin = os.environ["MONGODB_URI"]
|
||||
cls.token_dir = os.environ["OIDC_TOKEN_DIR"]
|
||||
|
||||
def setUp(self):
|
||||
self.request_called = 0
|
||||
self.refresh_called = 0
|
||||
_oidc_cache.clear()
|
||||
os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1")
|
||||
|
||||
def create_request_cb(self, username="test_user1", expires_in_seconds=None, sleep=0):
|
||||
def create_request_cb(self, username="test_user1", sleep=0):
|
||||
|
||||
token_file = os.path.join(self.token_dir, username)
|
||||
token_file = os.path.join(self.token_dir, username).replace(os.sep, "/")
|
||||
|
||||
def request_token(server_info, context):
|
||||
# Validate the info.
|
||||
@ -69,43 +64,14 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
self.assertEqual(timeout_seconds, 60 * 5)
|
||||
with open(token_file) as fid:
|
||||
token = fid.read()
|
||||
resp = {"access_token": token}
|
||||
resp = {"access_token": token, "refresh_token": token}
|
||||
|
||||
time.sleep(sleep)
|
||||
|
||||
if expires_in_seconds is not None:
|
||||
resp["expires_in_seconds"] = expires_in_seconds
|
||||
self.request_called += 1
|
||||
return resp
|
||||
|
||||
return request_token
|
||||
|
||||
def create_refresh_cb(self, username="test_user1", expires_in_seconds=None):
|
||||
|
||||
token_file = os.path.join(self.token_dir, username)
|
||||
|
||||
def refresh_token(server_info, context):
|
||||
with open(token_file) as fid:
|
||||
token = fid.read()
|
||||
|
||||
# Validate the info.
|
||||
self.assertIn("issuer", server_info)
|
||||
self.assertIn("clientId", server_info)
|
||||
|
||||
# Validate the creds
|
||||
self.assertIsNotNone(context["refresh_token"])
|
||||
|
||||
# Validate the timeout.
|
||||
self.assertEqual(context["timeout_seconds"], 60 * 5)
|
||||
|
||||
resp = {"access_token": token}
|
||||
if expires_in_seconds is not None:
|
||||
resp["expires_in_seconds"] = expires_in_seconds
|
||||
self.refresh_called += 1
|
||||
return resp
|
||||
|
||||
return refresh_token
|
||||
|
||||
@contextmanager
|
||||
def fail_point(self, command_args):
|
||||
cmd_on = SON([("configureFailPoint", "failCommand")])
|
||||
@ -117,21 +83,21 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
finally:
|
||||
client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off")
|
||||
|
||||
def test_connect_callbacks_single_implicit_username(self):
|
||||
def test_connect_request_callback_single_implicit_username(self):
|
||||
request_token = self.create_request_cb()
|
||||
props: Dict = {"request_token_callback": request_token}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_callbacks_single_explicit_username(self):
|
||||
def test_connect_request_callback_single_explicit_username(self):
|
||||
request_token = self.create_request_cb()
|
||||
props: Dict = {"request_token_callback": request_token}
|
||||
client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_callbacks_multiple_principal_user1(self):
|
||||
def test_connect_request_callback_multiple_principal_user1(self):
|
||||
request_token = self.create_request_cb()
|
||||
props: Dict = {"request_token_callback": request_token}
|
||||
client = MongoClient(
|
||||
@ -140,7 +106,7 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_callbacks_multiple_principal_user2(self):
|
||||
def test_connect_request_callback_multiple_principal_user2(self):
|
||||
request_token = self.create_request_cb("test_user2")
|
||||
props: Dict = {"request_token_callback": request_token}
|
||||
client = MongoClient(
|
||||
@ -149,7 +115,7 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_callbacks_multiple_no_username(self):
|
||||
def test_connect_request_callback_multiple_no_username(self):
|
||||
request_token = self.create_request_cb()
|
||||
props: Dict = {"request_token_callback": request_token}
|
||||
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
|
||||
@ -173,38 +139,11 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_aws_single_principal(self):
|
||||
props = {"PROVIDER_NAME": "aws"}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_aws_multiple_principal_user1(self):
|
||||
props = {"PROVIDER_NAME": "aws"}
|
||||
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_aws_multiple_principal_user2(self):
|
||||
os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2")
|
||||
props = {"PROVIDER_NAME": "aws"}
|
||||
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_aws_allowed_hosts_ignored(self):
|
||||
props = {"PROVIDER_NAME": "aws", "allowed_hosts": []}
|
||||
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_valid_callbacks(self):
|
||||
request_cb = self.create_request_cb(expires_in_seconds=60)
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
def test_valid_request_token_callback(self):
|
||||
request_cb = self.create_request_cb()
|
||||
|
||||
props: Dict = {
|
||||
"request_token_callback": request_cb,
|
||||
"refresh_token_callback": refresh_cb,
|
||||
}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
@ -214,31 +153,6 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_lock_avoids_extra_callbacks(self):
|
||||
request_cb = self.create_request_cb(sleep=0.5)
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
|
||||
|
||||
def run_test():
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
t1 = threading.Thread(target=run_test)
|
||||
t2 = threading.Thread(target=run_test)
|
||||
t1.start()
|
||||
t2.start()
|
||||
t1.join()
|
||||
t2.join()
|
||||
|
||||
self.assertEqual(self.request_called, 1)
|
||||
self.assertEqual(self.refresh_called, 2)
|
||||
|
||||
def test_request_callback_returns_null(self):
|
||||
def request_token_null(a, b):
|
||||
return None
|
||||
@ -249,25 +163,6 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_refresh_callback_returns_null(self):
|
||||
request_cb = self.create_request_cb(expires_in_seconds=60)
|
||||
|
||||
def refresh_token_null(a, b):
|
||||
return None
|
||||
|
||||
props: Dict = {
|
||||
"request_token_callback": request_cb,
|
||||
"refresh_token_callback": refresh_token_null,
|
||||
}
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
with self.assertRaises(ValueError):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_request_callback_invalid_result(self):
|
||||
def request_token_invalid(a, b):
|
||||
return {}
|
||||
@ -289,166 +184,10 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_refresh_callback_missing_data(self):
|
||||
request_cb = self.create_request_cb(expires_in_seconds=60)
|
||||
|
||||
def refresh_cb_no_token(a, b):
|
||||
return {}
|
||||
|
||||
props: Dict = {
|
||||
"request_token_callback": request_cb,
|
||||
"refresh_token_callback": refresh_cb_no_token,
|
||||
}
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
with self.assertRaises(ValueError):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_refresh_callback_extra_data(self):
|
||||
request_cb = self.create_request_cb(expires_in_seconds=60)
|
||||
|
||||
def refresh_cb_extra_value(server_info, context):
|
||||
result = self.create_refresh_cb()(server_info, context)
|
||||
result["foo"] = "bar"
|
||||
return result
|
||||
|
||||
props: Dict = {
|
||||
"request_token_callback": request_cb,
|
||||
"refresh_token_callback": refresh_cb_extra_value,
|
||||
}
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
with self.assertRaises(ValueError):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_cache_with_refresh(self):
|
||||
# Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token.
|
||||
|
||||
# Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute.
|
||||
request_cb = self.create_request_cb(expires_in_seconds=60)
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
|
||||
|
||||
# Ensure that a ``find`` operation adds credentials to the cache.
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
self.assertEqual(len(_oidc_cache), 1)
|
||||
|
||||
# Create a new client with the same request callback and a refresh callback.
|
||||
# Ensure that a ``find`` operation results in a call to the refresh callback.
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
self.assertEqual(len(_oidc_cache), 1)
|
||||
|
||||
def test_cache_with_no_refresh(self):
|
||||
# Create a new client with a request callback callback.
|
||||
# Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute.
|
||||
request_cb = self.create_request_cb()
|
||||
|
||||
props = {"request_token_callback": request_cb}
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
|
||||
# Ensure that a ``find`` operation adds credentials to the cache.
|
||||
self.request_called = 0
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
self.assertEqual(self.request_called, 1)
|
||||
self.assertEqual(len(_oidc_cache), 1)
|
||||
|
||||
# Create a new client with the same request callback.
|
||||
# Ensure that a ``find`` operation results in a call to the request callback.
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
self.assertEqual(self.request_called, 2)
|
||||
self.assertEqual(len(_oidc_cache), 1)
|
||||
|
||||
def test_cache_key_includes_callback(self):
|
||||
request_cb = self.create_request_cb()
|
||||
|
||||
props: Dict = {"request_token_callback": request_cb}
|
||||
|
||||
# Ensure that a ``find`` operation adds a new entry to the cache.
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
# Create a new client with a different request callback.
|
||||
def request_token_2(a, b):
|
||||
return request_cb(a, b)
|
||||
|
||||
props["request_token_callback"] = request_token_2
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
|
||||
# Ensure that a ``find`` operation adds a new entry to the cache.
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
self.assertEqual(len(_oidc_cache), 2)
|
||||
|
||||
def test_cache_clears_on_error(self):
|
||||
request_cb = self.create_request_cb()
|
||||
|
||||
# Create a new client with a valid request callback that gives credentials that expire within 5 minutes and a refresh callback that gives invalid credentials.
|
||||
def refresh_cb(a, b):
|
||||
return {"access_token": "bad"}
|
||||
|
||||
# Add a token to the cache that will expire soon.
|
||||
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
# Create a new client with the same callbacks.
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
|
||||
# Ensure that another ``find`` operation results in an error.
|
||||
with self.assertRaises(OperationFailure):
|
||||
client.test.test.find_one()
|
||||
|
||||
client.close()
|
||||
|
||||
# Ensure that the cache has been cleared.
|
||||
authenticator = list(_oidc_cache.values())[0]
|
||||
self.assertIsNone(authenticator.idp_info)
|
||||
|
||||
def test_cache_is_not_used_in_aws_automatic_workflow(self):
|
||||
# Create a new client using the AWS device workflow.
|
||||
# Ensure that a ``find`` operation does not add credentials to the cache.
|
||||
props = {"PROVIDER_NAME": "aws"}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
# Ensure that the cache has been cleared.
|
||||
authenticator = list(_oidc_cache.values())[0]
|
||||
self.assertIsNone(authenticator.idp_info)
|
||||
|
||||
def test_speculative_auth_success(self):
|
||||
# Clear the cache
|
||||
_oidc_cache.clear()
|
||||
token_file = os.path.join(self.token_dir, "test_user1")
|
||||
request_token = self.create_request_cb()
|
||||
|
||||
def request_token(a, b):
|
||||
with open(token_file) as fid:
|
||||
token = fid.read()
|
||||
return {"access_token": token, "expires_in_seconds": 1000}
|
||||
|
||||
# Create a client with a request callback that returns a valid token
|
||||
# that will not expire soon.
|
||||
# Create a client with a request callback that returns a valid token.
|
||||
props: Dict = {"request_token_callback": request_token}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
@ -465,32 +204,14 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
# Close the client.
|
||||
client.close()
|
||||
|
||||
# Create a new client.
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Set a fail point for saslStart commands.
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 2},
|
||||
"data": {"failCommands": ["saslStart"], "errorCode": 18},
|
||||
}
|
||||
):
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Close the client.
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds(self):
|
||||
listener = EventListener()
|
||||
|
||||
# Create request and refresh callbacks that return valid credentials
|
||||
# that will not expire soon.
|
||||
# Create request callback that returns valid credentials.
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
|
||||
# Create a client with the callback.
|
||||
props: Dict = {"request_token_callback": request_cb}
|
||||
client = MongoClient(
|
||||
self.uri_single, event_listeners=[listener], authmechanismproperties=props
|
||||
)
|
||||
@ -498,8 +219,8 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
# Assert that the request callback has been called once.
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
listener.reset()
|
||||
|
||||
@ -534,23 +255,109 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
self.assertEqual(succeeded_events, ["find"])
|
||||
self.assertEqual(failed_events, ["find"])
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_bulk_write(self):
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
def test_reauthenticate_succeeds_no_refresh(self):
|
||||
cb = self.create_request_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
|
||||
def request_cb(*args, **kwargs):
|
||||
result = cb(*args, **kwargs)
|
||||
del result["refresh_token"]
|
||||
return result
|
||||
|
||||
# Create a client with the callback.
|
||||
props: Dict = {"request_token_callback": request_cb}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
# Assert that the request callback has been called once.
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["find"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_after_refresh_fails(self):
|
||||
|
||||
# Create request callback that returns valid credentials.
|
||||
request_cb = self.create_request_cb()
|
||||
|
||||
# Create a client with the callback.
|
||||
props: Dict = {"request_token_callback": request_cb}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Assert that the request callback has been called once.
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 2},
|
||||
"data": {"failCommands": ["find", "saslContinue"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Assert that the request callback has been called three times.
|
||||
self.assertEqual(self.request_called, 3)
|
||||
|
||||
def test_reauthenticate_fails(self):
|
||||
|
||||
# Create request callback that returns valid credentials.
|
||||
request_cb = self.create_request_cb()
|
||||
|
||||
# Create a client with the callback.
|
||||
props: Dict = {"request_token_callback": request_cb}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Assert that the request callback has been called once.
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 2},
|
||||
"data": {"failCommands": ["find"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation that fails.
|
||||
with self.assertRaises(OperationFailure):
|
||||
client.test.test.find_one()
|
||||
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_bulk_write(self):
|
||||
request_cb = self.create_request_cb()
|
||||
|
||||
# Create a client with the callback.
|
||||
props: Dict = {"request_token_callback": request_cb}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Assert that the request callback has been called once.
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
@ -561,16 +368,15 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
# Perform a bulk write operation.
|
||||
client.test.test.bulk_write([InsertOne({})])
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_bulk_read(self):
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
|
||||
# Create a client with the callback.
|
||||
props: Dict = {"request_token_callback": request_cb}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform a find operation.
|
||||
@ -579,8 +385,8 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
# Perform a bulk write operation.
|
||||
client.test.test.bulk_write([InsertOne({})])
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
# Assert that the request callback has been called once.
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
@ -592,23 +398,22 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
cursor = client.test.test.find_raw_batches({})
|
||||
list(cursor)
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_cursor(self):
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
|
||||
# Create a client with the callback.
|
||||
props: Dict = {"request_token_callback": request_cb}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform an insert operation.
|
||||
client.test.test.insert_one({"a": 1})
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
# Assert that the request callback has been called once.
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
@ -620,23 +425,22 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
cursor = client.test.test.find({"a": 1})
|
||||
self.assertGreaterEqual(len(list(cursor)), 1)
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_get_more(self):
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
|
||||
# Create a client with the callback.
|
||||
props: Dict = {"request_token_callback": request_cb}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform an insert operation.
|
||||
client.test.test.insert_many([{"a": 1}, {"a": 1}])
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
# Assert that the request callback has been called once.
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
@ -648,30 +452,29 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
cursor = client.test.test.find({"a": 1}, batch_size=1)
|
||||
self.assertGreaterEqual(len(list(cursor)), 1)
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_get_more_exhaust(self):
|
||||
# Ensure no mongos
|
||||
props = {"PROVIDER_NAME": "aws"}
|
||||
props = {"request_token_callback": self.create_request_cb()}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
hello = client.admin.command(HelloCompat.LEGACY_CMD)
|
||||
if hello.get("msg") != "isdbgrid":
|
||||
raise unittest.SkipTest("Must not be a mongos")
|
||||
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
|
||||
# Create a client with the callback.
|
||||
props: Dict = {"request_token_callback": request_cb}
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform an insert operation.
|
||||
client.test.test.insert_many([{"a": 1}, {"a": 1}])
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
# Assert that the request callback has been called once.
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
@ -683,16 +486,15 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
cursor = client.test.test.find({"a": 1}, batch_size=1, cursor_type=CursorType.EXHAUST)
|
||||
self.assertGreaterEqual(len(list(cursor)), 1)
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_command(self):
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
|
||||
# Create a client with the callback.
|
||||
props: Dict = {"request_token_callback": request_cb}
|
||||
|
||||
print("start of test")
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
@ -700,8 +502,8 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
# Perform an insert operation.
|
||||
client.test.test.insert_one({"a": 1})
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
# Assert that the request callback has been called once.
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
@ -714,112 +516,54 @@ class TestAuthOIDC(unittest.TestCase):
|
||||
|
||||
self.assertGreaterEqual(len(list(cursor)), 1)
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_retries_and_succeeds_with_cache(self):
|
||||
listener = EventListener()
|
||||
|
||||
# Create request and refresh callbacks that return valid credentials
|
||||
# that will not expire soon.
|
||||
def test_reauthentication_succeeds_multiple_connections(self):
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
|
||||
client = MongoClient(
|
||||
self.uri_single, event_listeners=[listener], authmechanismproperties=props
|
||||
# Create a client with the callback.
|
||||
props: Dict = {"request_token_callback": request_cb}
|
||||
|
||||
client1 = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
client2 = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform an insert operation.
|
||||
client1.test.test.insert_many([{"a": 1}, {"a": 1}])
|
||||
client2.test.test.find_one()
|
||||
self.assertEqual(self.request_called, 2)
|
||||
|
||||
# Use the same authenticator for both clients
|
||||
# to simulate a race condition with separate connections.
|
||||
# We should only see one extra callback despite both connections
|
||||
# needing to reauthenticate.
|
||||
client2.options.pool_options._credentials.cache.data = (
|
||||
client1.options.pool_options._credentials.cache.data
|
||||
)
|
||||
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Set a fail point for ``saslStart`` commands of the form
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 2},
|
||||
"data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation that succeeds.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Close the client.
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_fails_with_no_cache(self):
|
||||
listener = EventListener()
|
||||
|
||||
# Create request and refresh callbacks that return valid credentials
|
||||
# that will not expire soon.
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
|
||||
client = MongoClient(
|
||||
self.uri_single, event_listeners=[listener], authmechanismproperties=props
|
||||
)
|
||||
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Clear the cache.
|
||||
_oidc_cache.clear()
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 2},
|
||||
"data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation that fails.
|
||||
with self.assertRaises(OperationFailure):
|
||||
client.test.test.find_one()
|
||||
|
||||
client.close()
|
||||
|
||||
def test_late_reauth_avoids_callback(self):
|
||||
# Step 1: connect with both clients
|
||||
request_cb = self.create_request_cb(expires_in_seconds=1e6)
|
||||
refresh_cb = self.create_refresh_cb(expires_in_seconds=1e6)
|
||||
|
||||
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
|
||||
client1 = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client1.test.test.find_one()
|
||||
client2 = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client2.test.test.find_one()
|
||||
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
# Step 2: cause a find 391 on the first client
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["find"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation that succeeds.
|
||||
client1.test.test.find_one()
|
||||
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
self.assertEqual(self.request_called, 1)
|
||||
self.assertEqual(self.request_called, 3)
|
||||
|
||||
# Step 3: cause a find 391 on the second client
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["find"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation that succeeds.
|
||||
client2.test.test.find_one()
|
||||
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
self.assertEqual(self.request_called, 3)
|
||||
client1.close()
|
||||
client2.close()
|
||||
|
||||
|
||||
@ -48,9 +48,6 @@ def create_test(test_case):
|
||||
if props.get("REQUEST_TOKEN_CALLBACK"):
|
||||
props["request_token_callback"] = lambda x, y: 1
|
||||
del props["REQUEST_TOKEN_CALLBACK"]
|
||||
if props.get("REFRESH_TOKEN_CALLBACK"):
|
||||
props["refresh_token_callback"] = lambda a, b: 1
|
||||
del props["REFRESH_TOKEN_CALLBACK"]
|
||||
client = MongoClient(uri, connect=False, authmechanismproperties=props)
|
||||
credentials = client.options.pool_options._credentials
|
||||
if credential is None:
|
||||
@ -86,10 +83,6 @@ def create_test(test_case):
|
||||
self.assertEqual(
|
||||
actual.request_token_callback, expected["request_token_callback"]
|
||||
)
|
||||
elif "refresh_token_callback" in expected:
|
||||
self.assertEqual(
|
||||
actual.refresh_token_callback, expected["refresh_token_callback"]
|
||||
)
|
||||
else:
|
||||
self.fail(f"Unhandled property: {key}")
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user