PYTHON-3716 OIDC-SASL Follow-Up (#1365)

This commit is contained in:
Steven Silvester 2023-09-28 12:48:36 -05:00 committed by GitHub
parent 9b6f2e18cf
commit 0590ce49ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 349 additions and 726 deletions

View File

@ -79,6 +79,17 @@ functions:
export PATH="$MONGODB_BINARIES:$PATH"
export PROJECT="${project}"
export PIP_QUIET=1
ENSURE_TOOLCHAIN_PYTHON_BINARY: |
# Make sure PYTHON_BINARY is set to a suitable toolchain python.
if [ -z "$PYTHON_BINARY" ]; then
if [ "$(uname -s)" = "Darwin" ]; then
export PYTHON_BINARY=/Library/Frameworks/Python.Framework/Versions/3.9/bin/python3
elif [ "Windows_NT" = "$OS" ]; then # Magic variable in cygwin
export PYTHON_BINARY=/cygdrive/c/python/Python39/python
else
export PYTHON_BINARY=/opt/python/3.9/bin/python3
fi
fi
EOT
# Load the expansion file to make an evergreen variable with the current unique version
@ -655,7 +666,7 @@ functions:
.evergreen/run-mongodb-aws-test.sh
fi
"bootstrap oidc":
"run oidc auth test with aws credentials":
- command: ec2.assume_role
params:
role_arn: ${aws_test_secrets_role}
@ -664,58 +675,11 @@ functions:
params:
working_dir: "src"
shell: bash
include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
script: |
${PREPARE_SHELL}
if [ "${skip_EC2_auth_test}" = "true" ]; then
echo "This platform does not support the oidc auth test, skipping..."
exit 0
fi
cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc
export AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
export AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
export AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN}
export OIDC_TOKEN_DIR=/tmp/tokens
. ./activate-authoidcvenv.sh
python oidc_write_orchestration.py
python oidc_get_tokens.py
"run oidc auth test with aws credentials":
- command: shell.exec
type: test
params:
working_dir: "src"
shell: bash
script: |
${PREPARE_SHELL}
if [ "${skip_EC2_auth_test}" = "true" ]; then
echo "This platform does not support the oidc auth test, skipping..."
exit 0
fi
cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc
mongosh setup_oidc.js
- command: shell.exec
type: test
params:
working_dir: "src"
silent: true
script: |
# DO NOT ECHO WITH XTRACE (which PREPARE_SHELL does)
cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh"
export OIDC_TOKEN_DIR=/tmp/tokens
EOF
- command: shell.exec
type: test
params:
working_dir: "src"
script: |
${PREPARE_SHELL}
if [ "${skip_web_identity_auth_test}" = "true" ]; then
echo "This platform does not support the oidc auth test, skipping..."
exit 0
fi
PYTHON_BINARY=${PYTHON_BINARY} ASSERT_NO_URI_CREDS=true .evergreen/run-mongodb-oidc-test.sh
${ENSURE_TOOLCHAIN_PYTHON_BINARY}
bash .evergreen/run-mongodb-oidc-test.sh
"run aws auth test with aws credentials as environment variables":
- command: shell.exec
@ -2133,16 +2097,7 @@ tasks:
- name: "oidc-auth-test-latest"
commands:
- func: "bootstrap oidc"
- func: "bootstrap mongo-orchestration"
vars:
AUTH: "auth"
ORCHESTRATION_FILE: "auth-oidc.json"
TOPOLOGY: "replica_set"
VERSION: "latest"
- func: "run oidc auth test with aws credentials"
vars:
AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test1
- name: load-balancer-test
commands:
@ -3192,9 +3147,8 @@ buildvariants:
- matrix_name: "oidc-auth-test"
matrix_spec:
platform: [ rhel8 ]
python-version: ["3.9"]
display_name: "MONGODB-OIDC Auth ${platform} ${python-version}"
platform: [ rhel8, macos-1100, windows-64-vsMulti-small ]
display_name: "MONGODB-OIDC Auth ${platform}"
tasks:
- name: "oidc-auth-test-latest"

View File

@ -1,56 +1,48 @@
#!/bin/bash
set -o xtrace
set +x # Disable debug trace
set -o errexit # Exit the script with error if any of the commands fail
############################################
# Main Program #
############################################
# Supported/used environment variables:
# MONGODB_URI Set the URI, including an optional username/password to use
# to connect to the server via MONGODB-OIDC authentication
# mechanism.
# PYTHON_BINARY The Python version to use.
echo "Running MONGODB-OIDC authentication tests"
# ensure no secrets are printed in log files
set +x
# load the script
shopt -s expand_aliases # needed for `urlencode` alias
[ -s "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh"
MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"}
MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC"
MONGODB_URI_MULTIPLE="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true"
if [ -z "${OIDC_TOKEN_DIR}" ]; then
echo "Must specify OIDC_TOKEN_DIR"
# Make sure DRIVERS_TOOLS is set.
if [ -z "$DRIVERS_TOOLS" ]; then
echo "Must specify DRIVERS_TOOLS"
exit 1
fi
export MONGODB_URI_SINGLE="$MONGODB_URI_SINGLE"
export MONGODB_URI_MULTIPLE="$MONGODB_URI_MULTIPLE"
export MONGODB_URI="$MONGODB_URI"
# Get the drivers secrets. Use an existing secrets file first.
if [ ! -f "./secrets-export.sh" ]; then
bash .evergreen/tox.sh -m aws-secrets -- drivers/oidc
fi
source ./secrets-export.sh
echo $MONGODB_URI_SINGLE
echo $MONGODB_URI_MULTIPLE
echo $MONGODB_URI
if [ "$ASSERT_NO_URI_CREDS" = "true" ]; then
if echo "$MONGODB_URI" | grep -q "@"; then
echo "MONGODB_URI unexpectedly contains user credentials!";
exit 1
fi
# # If the file did not have our creds, get them from the vault.
if [ -z "$OIDC_ATLAS_URI_SINGLE" ]; then
bash .evergreen/tox.sh -m aws-secrets -- drivers/oidc
source ./secrets-export.sh
fi
if [ -z "$PYTHON_BINARY" ]; then
echo "Cannot test without specifying PYTHON_BINARY"
exit 1
# Make the OIDC tokens.
set -x
pushd ${DRIVERS_TOOLS}/.evergreen/auth_oidc
. ./oidc_get_tokens.sh
popd
# Set up variables and run the test.
if [ -n "$LOCAL_OIDC_SERVER" ]; then
export MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"}
export MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC"
export MONGODB_URI_MULTI="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true"
else
set +x # turn off xtrace for this portion
export MONGODB_URI="$OIDC_ATLAS_URI_SINGLE"
export MONGODB_URI_SINGLE="$OIDC_ATLAS_URI_SINGLE/?authMechanism=MONGODB-OIDC"
export MONGODB_URI_MULTI="$OIDC_ATLAS_URI_MULTI/?authMechanism=MONGODB-OIDC"
set -x
fi
export TEST_AUTH_OIDC=1
export COVERAGE=1
export AUTH="auth"
export SET_XTRACE_ON=1
bash ./.evergreen/tox.sh -m test-eg

View File

@ -49,6 +49,9 @@ if [ "$AUTH" != "noauth" ]; then
elif [ ! -z "$TEST_SERVERLESS" ]; then
export DB_USER=$SERVERLESS_ATLAS_USER
export DB_PASSWORD=$SERVERLESS_ATLAS_PASSWORD
elif [ ! -z "$TEST_AUTH_OIDC" ]; then
export DB_USER=$OIDC_ALTAS_USER
export DB_PASSWORD=$OIDC_ATLAS_PASSWORD
else
export DB_USER="bob"
export DB_PASSWORD="pwd123"
@ -109,7 +112,7 @@ fi
if [ -n "$TEST_ENCRYPTION" ] || [ -n "$TEST_FLE_AZURE_AUTO" ] || [ -n "$TEST_FLE_GCP_AUTO" ]; then
# Work around for root certifi not being installed.
# TODO: Remove after PYTHON-3827
# TODO: Remove after PYTHON-3952 is deployed.
if [ "$(uname -s)" = "Darwin" ]; then
python -m pip install certifi
CERT_PATH=$(python -c "import certifi; print(certifi.where())")
@ -224,6 +227,17 @@ fi
if [ -n "$TEST_AUTH_OIDC" ]; then
python -m pip install ".[aws]"
# Work around for root certifi not being installed.
# TODO: Remove after PYTHON-3952 is deployed.
if [ "$(uname -s)" = "Darwin" ]; then
python -m pip install certifi
CERT_PATH=$(python -c "import certifi; print(certifi.where())")
export SSL_CERT_FILE=${CERT_PATH}
export REQUESTS_CA_BUNDLE=${CERT_PATH}
export AWS_CA_BUNDLE=${CERT_PATH}
fi
TEST_ARGS="test/auth_oidc/test_auth_oidc.py"
fi

View File

@ -164,11 +164,11 @@ def _build_credentials_tuple(
elif mech == "MONGODB-OIDC":
properties = extra.get("authmechanismproperties", {})
request_token_callback = properties.get("request_token_callback")
refresh_token_callback = properties.get("refresh_token_callback", None)
provider_name = properties.get("PROVIDER_NAME", "")
default_allowed = [
"*.mongodb.net",
"*.mongodb-dev.net",
"*.mongodb-qa.net",
"*.mongodbgov.net",
"localhost",
"127.0.0.1",
@ -181,11 +181,10 @@ def _build_credentials_tuple(
)
oidc_props = _OIDCProperties(
request_token_callback=request_token_callback,
refresh_token_callback=refresh_token_callback,
provider_name=provider_name,
allowed_hosts=allowed_hosts,
)
return MongoCredential(mech, "$external", user, passwd, oidc_props, None)
return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())
elif mech == "PLAIN":
source_database = source or database or "$external"

View File

@ -15,27 +15,14 @@
"""MONGODB-OIDC Authentication helpers."""
from __future__ import annotations
import os
import threading
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
MutableMapping,
Optional,
Tuple,
)
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Optional
import bson
from bson.binary import Binary
from bson.son import SON
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE
if TYPE_CHECKING:
from pymongo.auth import MongoCredential
@ -44,39 +31,27 @@ if TYPE_CHECKING:
@dataclass
class _OIDCProperties:
request_token_callback: Optional[Callable[..., Dict]]
refresh_token_callback: Optional[Callable[..., Dict]]
request_token_callback: Optional[Callable[..., dict]]
provider_name: Optional[str]
allowed_hosts: List[str]
allowed_hosts: list[str]
"""Mechanism properties for MONGODB-OIDC authentication."""
TOKEN_BUFFER_MINUTES = 5
CALLBACK_TIMEOUT_SECONDS = 5 * 60
CACHE_TIMEOUT_MINUTES = 60 * 5
CALLBACK_VERSION = 0
_CACHE: Dict[str, "_OIDCAuthenticator"] = {}
CALLBACK_VERSION = 1
def _get_authenticator(
credentials: MongoCredential, address: Tuple[str, int]
credentials: MongoCredential, address: tuple[str, int]
) -> _OIDCAuthenticator:
# Clear out old items in the cache.
now_utc = datetime.now(timezone.utc)
to_remove = []
for key, value in _CACHE.items():
if value.cache_exp_utc is not None and value.cache_exp_utc < now_utc:
to_remove.append(key)
for key in to_remove:
del _CACHE[key]
if credentials.cache.data:
return credentials.cache.data
# Extract values.
principal_name = credentials.username
properties = credentials.mechanism_properties
request_cb = properties.request_token_callback
refresh_cb = properties.refresh_token_callback
# Validate that the address is allowed.
if not properties.provider_name:
@ -92,150 +67,99 @@ def _get_authenticator(
f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
)
# Get or create the cache item.
cache_key = f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}"
_CACHE.setdefault(cache_key, _OIDCAuthenticator(username=principal_name, properties=properties))
return _CACHE[cache_key]
def _get_cache_exp() -> datetime:
return datetime.now(timezone.utc) + timedelta(minutes=CACHE_TIMEOUT_MINUTES)
# Get or create the cache data.
credentials.cache.data = _OIDCAuthenticator(username=principal_name, properties=properties)
return credentials.cache.data
@dataclass
class _OIDCAuthenticator:
username: str
properties: _OIDCProperties
idp_info: Optional[Dict] = field(default=None)
idp_resp: Optional[Dict] = field(default=None)
reauth_gen_id: int = field(default=0)
idp_info_gen_id: int = field(default=0)
refresh_token: Optional[str] = field(default=None)
access_token: Optional[str] = field(default=None)
idp_info: Optional[dict] = field(default=None)
token_gen_id: int = field(default=0)
token_exp_utc: Optional[datetime] = field(default=None)
cache_exp_utc: datetime = field(default_factory=_get_cache_exp)
lock: threading.Lock = field(default_factory=threading.Lock)
def get_current_token(self, use_callbacks: bool = True) -> Optional[str]:
def get_current_token(self, use_callback: bool = True) -> Optional[str]:
properties = self.properties
request_cb = properties.request_token_callback
refresh_cb = properties.refresh_token_callback
if not use_callbacks:
request_cb = None
refresh_cb = None
# TODO: DRIVERS-2672, handle machine callback here as well.
cb = properties.request_token_callback if use_callback else None
cb_type = "human"
current_valid_token = False
if self.token_exp_utc is not None:
now_utc = datetime.now(timezone.utc)
exp_utc = self.token_exp_utc
buffer_seconds = TOKEN_BUFFER_MINUTES * 60
if (exp_utc - now_utc).total_seconds() >= buffer_seconds:
current_valid_token = True
prev_token = self.access_token
if prev_token:
return prev_token
timeout = CALLBACK_TIMEOUT_SECONDS
if not use_callbacks and not current_valid_token:
if not use_callback and not prev_token:
return None
if not current_valid_token and request_cb is not None:
prev_token = self.idp_resp["access_token"] if self.idp_resp else None
if not prev_token and cb is not None:
with self.lock:
# See if the token was changed while we were waiting for the
# lock.
new_token = self.idp_resp["access_token"] if self.idp_resp else None
new_token = self.access_token
if new_token != prev_token:
return new_token
refresh_token = self.idp_resp and self.idp_resp.get("refresh_token")
refresh_token = refresh_token or ""
context = {
"timeout_seconds": timeout,
"version": CALLBACK_VERSION,
"refresh_token": refresh_token,
}
# TODO: DRIVERS-2672 handle machine callback here.
if cb_type == "human":
context = {
"timeout_seconds": CALLBACK_TIMEOUT_SECONDS,
"version": CALLBACK_VERSION,
"refresh_token": self.refresh_token,
}
resp = cb(self.idp_info, context)
self.validate_request_token_response(resp)
if self.idp_resp is None or refresh_cb is None:
self.idp_resp = request_cb(self.idp_info, context)
elif request_cb is not None:
self.idp_resp = refresh_cb(self.idp_info, context)
cache_exp_utc = datetime.now(timezone.utc) + timedelta(
minutes=CACHE_TIMEOUT_MINUTES
)
self.cache_exp_utc = cache_exp_utc
self.token_gen_id += 1
token_result = self.idp_resp
return self.access_token
def validate_request_token_response(self, resp: Mapping[str, Any]) -> None:
# Validate callback return value.
if not isinstance(token_result, dict):
if not isinstance(resp, dict):
raise ValueError("OIDC callback returned invalid result")
if "access_token" not in token_result:
if "access_token" not in resp:
raise ValueError("OIDC callback did not return an access_token")
expected = ["access_token", "expires_in_seconds", "refesh_token"]
for key in token_result:
expected = ["access_token", "refresh_token", "expires_in_seconds"]
for key in resp:
if key not in expected:
raise ValueError(f'Unexpected field in callback result "{key}"')
token = token_result["access_token"]
self.access_token = resp["access_token"]
self.refresh_token = resp.get("refresh_token")
if "expires_in_seconds" in token_result:
expires_in = int(token_result["expires_in_seconds"])
buffer_seconds = TOKEN_BUFFER_MINUTES * 60
if expires_in >= buffer_seconds:
now_utc = datetime.now(timezone.utc)
exp_utc = now_utc + timedelta(seconds=expires_in)
self.token_exp_utc = exp_utc
return token
def auth_start_cmd(self, use_callbacks: bool = True) -> Optional[SON[str, Any]]:
properties = self.properties
# Handle aws provider credentials.
if properties.provider_name == "aws":
aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"]
with open(aws_identity_file) as fid:
token: Optional[str] = fid.read().strip()
payload = {"jwt": token}
cmd = SON(
[
("saslStart", 1),
("mechanism", "MONGODB-OIDC"),
("payload", Binary(bson.encode(payload))),
]
)
return cmd
def principal_step_cmd(self) -> SON[str, Any]:
"""Get a SASL start command with an optional principal name"""
# Send the SASL start with the optional principal name.
payload = {}
principal_name = self.username
if principal_name:
payload["n"] = principal_name
if self.idp_info is not None:
self.cache_exp_utc = datetime.now(timezone.utc) + timedelta(
minutes=CACHE_TIMEOUT_MINUTES
)
cmd = SON(
[
("saslStart", 1),
("mechanism", "MONGODB-OIDC"),
("payload", Binary(bson.encode(payload))),
("autoAuthorize", 1),
]
)
return cmd
def auth_start_cmd(self, use_callback: bool = True) -> Optional[SON[str, Any]]:
# TODO: DRIVERS-2672, check for provider_name in self.properties here.
if self.idp_info is None:
self.cache_exp_utc = _get_cache_exp()
return self.principal_step_cmd()
if self.idp_info is None:
# Send the SASL start with the optional principal name.
payload = {}
if principal_name:
payload["n"] = principal_name
cmd = SON(
[
("saslStart", 1),
("mechanism", "MONGODB-OIDC"),
("payload", Binary(bson.encode(payload))),
("autoAuthorize", 1),
]
)
return cmd
token = self.get_current_token(use_callbacks)
token = self.get_current_token(use_callback)
if not token:
return None
bin_payload = Binary(bson.encode({"jwt": token}))
@ -247,37 +171,59 @@ class _OIDCAuthenticator:
]
)
def clear(self) -> None:
self.idp_info = None
self.idp_resp = None
self.token_exp_utc = None
def run_command(
self, conn: Connection, cmd: MutableMapping[str, Any]
) -> Optional[Mapping[str, Any]]:
try:
return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
except OperationFailure as exc:
self.clear()
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
if "jwt" in bson.decode(cmd["payload"]):
if self.idp_info_gen_id > self.reauth_gen_id:
raise
return self.authenticate(conn, reauthenticate=True)
except OperationFailure:
self.access_token = None
raise
def authenticate(
self, conn: Connection, reauthenticate: bool = False
) -> Optional[Mapping[str, Any]]:
if reauthenticate:
prev_id = getattr(conn, "oidc_token_gen_id", None)
# Check if we've already changed tokens.
if prev_id == self.token_gen_id:
self.reauth_gen_id = self.idp_info_gen_id
self.token_exp_utc = None
if not self.properties.refresh_token_callback:
self.clear()
def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
"""Handle a reauthenticate from the server."""
# First see if we have the a newer token on the authenticator.
prev_id = conn.oidc_token_gen_id or 0
# If we've already changed tokens, make one optimistic attempt.
if (prev_id < self.token_gen_id) and self.access_token:
try:
return self.authenticate(conn)
except OperationFailure:
pass
self.access_token = None
# TODO: DRIVERS-2672, check for provider_name in self.properties here.
# If so, we clear the access token and return finish_auth.
# Next see if the idp info has changed.
prev_idp_info = self.idp_info
self.idp_info = None
cmd = self.principal_step_cmd()
resp = self.run_command(conn, cmd)
assert resp is not None
server_resp: dict = bson.decode(resp["payload"])
if "issuer" in server_resp:
self.idp_info = server_resp
# Handle the case of changed idp info.
if not self.idp_info == prev_idp_info:
self.access_token = None
self.refresh_token = None
# If we have a refresh token, try using that.
if self.refresh_token:
try:
return self.finish_auth(resp, conn)
except OperationFailure:
self.refresh_token = None
# If that fails, try again without the refresh token.
return self.authenticate(conn)
# If we don't have a refresh token, just try once.
return self.finish_auth(resp, conn)
def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
ctx = conn.auth_ctx
cmd = None
@ -293,12 +239,16 @@ class _OIDCAuthenticator:
conn.oidc_token_gen_id = self.token_gen_id
return None
server_resp: Dict = bson.decode(resp["payload"])
server_resp: dict = bson.decode(resp["payload"])
if "issuer" in server_resp:
self.idp_info = server_resp
self.idp_info_gen_id += 1
conversation_id = resp["conversationId"]
return self.finish_auth(resp, conn)
def finish_auth(
self, orig_resp: Mapping[str, Any], conn: Connection
) -> Optional[Mapping[str, Any]]:
conversation_id = orig_resp["conversationId"]
token = self.get_current_token()
conn.oidc_token_gen_id = self.token_gen_id
bin_payload = Binary(bson.encode({"jwt": token}))
@ -312,7 +262,6 @@ class _OIDCAuthenticator:
resp = self.run_command(conn, cmd)
assert resp is not None
if not resp["done"]:
self.clear()
raise OperationFailure("SASL conversation failed to complete.")
return resp
@ -322,4 +271,7 @@ def _authenticate_oidc(
) -> Optional[Mapping[str, Any]]:
"""Authenticate using MONGODB-OIDC."""
authenticator = _get_authenticator(credentials, conn.address)
return authenticator.authenticate(conn, reauthenticate=reauthenticate)
if reauthenticate:
return authenticator.reauthenticate(conn)
else:
return authenticator.authenticate(conn)

View File

@ -440,8 +440,6 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
signature = inspect.signature(value)
if key == "request_token_callback":
expected_params = 2
elif key == "refresh_token_callback":
expected_params = 2
else:
raise ValueError(f"Unrecognized Auth mechanism function {key}")
if len(signature.parameters) != expected_params:

View File

@ -475,22 +475,6 @@
}
}
},
{
"description": "should recognise the mechanism with request and refresh callback (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
"callback": ["oidcRequest", "oidcRefresh"],
"valid": true,
"credential": {
"username": null,
"password": null,
"source": "$external",
"mechanism": "MONGODB-OIDC",
"mechanism_properties": {
"REQUEST_TOKEN_CALLBACK": true,
"REFRESH_TOKEN_CALLBACK": true
}
}
},
{
"description": "should recognise the mechanism and username with request callback (MONGODB-OIDC)",
"uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC",
@ -554,18 +538,11 @@
"credential": null
},
{
"description": "should throw an exception if neither deviceName nor callbacks specified (MONGODB-OIDC)",
"description": "should throw an exception if neither deviceName nor callback specified (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
"valid": false,
"credential": null
},
{
"description": "should throw an exception when only refresh callback is specified (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
"callback": ["oidcRefresh"],
"valid": false,
"credential": null
},
{
"description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted",
@ -573,4 +550,4 @@
"credential": null
}
]
}
}

View File

@ -16,7 +16,6 @@
import os
import sys
import threading
import time
import unittest
from contextlib import contextmanager
@ -29,7 +28,6 @@ from test.utils import EventListener
from bson import SON
from pymongo import MongoClient
from pymongo.auth import _AUTH_MAP, _authenticate_oidc
from pymongo.auth_oidc import _CACHE as _oidc_cache
from pymongo.cursor import CursorType
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.hello import HelloCompat
@ -45,19 +43,16 @@ class TestAuthOIDC(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.uri_single = os.environ["MONGODB_URI_SINGLE"]
cls.uri_multiple = os.environ["MONGODB_URI_MULTIPLE"]
cls.uri_multiple = os.environ["MONGODB_URI_MULTI"]
cls.uri_admin = os.environ["MONGODB_URI"]
cls.token_dir = os.environ["OIDC_TOKEN_DIR"]
def setUp(self):
self.request_called = 0
self.refresh_called = 0
_oidc_cache.clear()
os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1")
def create_request_cb(self, username="test_user1", expires_in_seconds=None, sleep=0):
def create_request_cb(self, username="test_user1", sleep=0):
token_file = os.path.join(self.token_dir, username)
token_file = os.path.join(self.token_dir, username).replace(os.sep, "/")
def request_token(server_info, context):
# Validate the info.
@ -69,43 +64,14 @@ class TestAuthOIDC(unittest.TestCase):
self.assertEqual(timeout_seconds, 60 * 5)
with open(token_file) as fid:
token = fid.read()
resp = {"access_token": token}
resp = {"access_token": token, "refresh_token": token}
time.sleep(sleep)
if expires_in_seconds is not None:
resp["expires_in_seconds"] = expires_in_seconds
self.request_called += 1
return resp
return request_token
def create_refresh_cb(self, username="test_user1", expires_in_seconds=None):
token_file = os.path.join(self.token_dir, username)
def refresh_token(server_info, context):
with open(token_file) as fid:
token = fid.read()
# Validate the info.
self.assertIn("issuer", server_info)
self.assertIn("clientId", server_info)
# Validate the creds
self.assertIsNotNone(context["refresh_token"])
# Validate the timeout.
self.assertEqual(context["timeout_seconds"], 60 * 5)
resp = {"access_token": token}
if expires_in_seconds is not None:
resp["expires_in_seconds"] = expires_in_seconds
self.refresh_called += 1
return resp
return refresh_token
@contextmanager
def fail_point(self, command_args):
cmd_on = SON([("configureFailPoint", "failCommand")])
@ -117,21 +83,21 @@ class TestAuthOIDC(unittest.TestCase):
finally:
client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off")
def test_connect_callbacks_single_implicit_username(self):
def test_connect_request_callback_single_implicit_username(self):
request_token = self.create_request_cb()
props: Dict = {"request_token_callback": request_token}
client = MongoClient(self.uri_single, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_callbacks_single_explicit_username(self):
def test_connect_request_callback_single_explicit_username(self):
request_token = self.create_request_cb()
props: Dict = {"request_token_callback": request_token}
client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_callbacks_multiple_principal_user1(self):
def test_connect_request_callback_multiple_principal_user1(self):
request_token = self.create_request_cb()
props: Dict = {"request_token_callback": request_token}
client = MongoClient(
@ -140,7 +106,7 @@ class TestAuthOIDC(unittest.TestCase):
client.test.test.find_one()
client.close()
def test_connect_callbacks_multiple_principal_user2(self):
def test_connect_request_callback_multiple_principal_user2(self):
request_token = self.create_request_cb("test_user2")
props: Dict = {"request_token_callback": request_token}
client = MongoClient(
@ -149,7 +115,7 @@ class TestAuthOIDC(unittest.TestCase):
client.test.test.find_one()
client.close()
def test_connect_callbacks_multiple_no_username(self):
def test_connect_request_callback_multiple_no_username(self):
request_token = self.create_request_cb()
props: Dict = {"request_token_callback": request_token}
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
@ -173,38 +139,11 @@ class TestAuthOIDC(unittest.TestCase):
client.test.test.find_one()
client.close()
def test_connect_aws_single_principal(self):
props = {"PROVIDER_NAME": "aws"}
client = MongoClient(self.uri_single, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_aws_multiple_principal_user1(self):
props = {"PROVIDER_NAME": "aws"}
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_aws_multiple_principal_user2(self):
os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2")
props = {"PROVIDER_NAME": "aws"}
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_aws_allowed_hosts_ignored(self):
props = {"PROVIDER_NAME": "aws", "allowed_hosts": []}
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_valid_callbacks(self):
request_cb = self.create_request_cb(expires_in_seconds=60)
refresh_cb = self.create_refresh_cb()
def test_valid_request_token_callback(self):
request_cb = self.create_request_cb()
props: Dict = {
"request_token_callback": request_cb,
"refresh_token_callback": refresh_cb,
}
client = MongoClient(self.uri_single, authmechanismproperties=props)
client.test.test.find_one()
@ -214,31 +153,6 @@ class TestAuthOIDC(unittest.TestCase):
client.test.test.find_one()
client.close()
def test_lock_avoids_extra_callbacks(self):
request_cb = self.create_request_cb(sleep=0.5)
refresh_cb = self.create_refresh_cb()
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
def run_test():
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
t1 = threading.Thread(target=run_test)
t2 = threading.Thread(target=run_test)
t1.start()
t2.start()
t1.join()
t2.join()
self.assertEqual(self.request_called, 1)
self.assertEqual(self.refresh_called, 2)
def test_request_callback_returns_null(self):
def request_token_null(a, b):
return None
@ -249,25 +163,6 @@ class TestAuthOIDC(unittest.TestCase):
client.test.test.find_one()
client.close()
def test_refresh_callback_returns_null(self):
request_cb = self.create_request_cb(expires_in_seconds=60)
def refresh_token_null(a, b):
return None
props: Dict = {
"request_token_callback": request_cb,
"refresh_token_callback": refresh_token_null,
}
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
client = MongoClient(self.uri_single, authMechanismProperties=props)
with self.assertRaises(ValueError):
client.test.test.find_one()
client.close()
def test_request_callback_invalid_result(self):
def request_token_invalid(a, b):
return {}
@ -289,166 +184,10 @@ class TestAuthOIDC(unittest.TestCase):
client.test.test.find_one()
client.close()
def test_refresh_callback_missing_data(self):
request_cb = self.create_request_cb(expires_in_seconds=60)
def refresh_cb_no_token(a, b):
return {}
props: Dict = {
"request_token_callback": request_cb,
"refresh_token_callback": refresh_cb_no_token,
}
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
client = MongoClient(self.uri_single, authMechanismProperties=props)
with self.assertRaises(ValueError):
client.test.test.find_one()
client.close()
def test_refresh_callback_extra_data(self):
request_cb = self.create_request_cb(expires_in_seconds=60)
def refresh_cb_extra_value(server_info, context):
result = self.create_refresh_cb()(server_info, context)
result["foo"] = "bar"
return result
props: Dict = {
"request_token_callback": request_cb,
"refresh_token_callback": refresh_cb_extra_value,
}
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
client = MongoClient(self.uri_single, authMechanismProperties=props)
with self.assertRaises(ValueError):
client.test.test.find_one()
client.close()
def test_cache_with_refresh(self):
# Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token.
# Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute.
request_cb = self.create_request_cb(expires_in_seconds=60)
refresh_cb = self.create_refresh_cb()
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
# Ensure that a ``find`` operation adds credentials to the cache.
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
self.assertEqual(len(_oidc_cache), 1)
# Create a new client with the same request callback and a refresh callback.
# Ensure that a ``find`` operation results in a call to the refresh callback.
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
self.assertEqual(self.refresh_called, 1)
self.assertEqual(len(_oidc_cache), 1)
def test_cache_with_no_refresh(self):
# Create a new client with a request callback callback.
# Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute.
request_cb = self.create_request_cb()
props = {"request_token_callback": request_cb}
client = MongoClient(self.uri_single, authMechanismProperties=props)
# Ensure that a ``find`` operation adds credentials to the cache.
self.request_called = 0
client.test.test.find_one()
client.close()
self.assertEqual(self.request_called, 1)
self.assertEqual(len(_oidc_cache), 1)
# Create a new client with the same request callback.
# Ensure that a ``find`` operation results in a call to the request callback.
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
self.assertEqual(self.request_called, 2)
self.assertEqual(len(_oidc_cache), 1)
def test_cache_key_includes_callback(self):
request_cb = self.create_request_cb()
props: Dict = {"request_token_callback": request_cb}
# Ensure that a ``find`` operation adds a new entry to the cache.
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
# Create a new client with a different request callback.
def request_token_2(a, b):
return request_cb(a, b)
props["request_token_callback"] = request_token_2
client = MongoClient(self.uri_single, authMechanismProperties=props)
# Ensure that a ``find`` operation adds a new entry to the cache.
client.test.test.find_one()
client.close()
self.assertEqual(len(_oidc_cache), 2)
def test_cache_clears_on_error(self):
request_cb = self.create_request_cb()
# Create a new client with a valid request callback that gives credentials that expire within 5 minutes and a refresh callback that gives invalid credentials.
def refresh_cb(a, b):
return {"access_token": "bad"}
# Add a token to the cache that will expire soon.
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
# Create a new client with the same callbacks.
client = MongoClient(self.uri_single, authMechanismProperties=props)
# Ensure that another ``find`` operation results in an error.
with self.assertRaises(OperationFailure):
client.test.test.find_one()
client.close()
# Ensure that the cache has been cleared.
authenticator = list(_oidc_cache.values())[0]
self.assertIsNone(authenticator.idp_info)
def test_cache_is_not_used_in_aws_automatic_workflow(self):
# Create a new client using the AWS device workflow.
# Ensure that a ``find`` operation does not add credentials to the cache.
props = {"PROVIDER_NAME": "aws"}
client = MongoClient(self.uri_single, authmechanismproperties=props)
client.test.test.find_one()
client.close()
# Ensure that the cache has been cleared.
authenticator = list(_oidc_cache.values())[0]
self.assertIsNone(authenticator.idp_info)
def test_speculative_auth_success(self):
# Clear the cache
_oidc_cache.clear()
token_file = os.path.join(self.token_dir, "test_user1")
request_token = self.create_request_cb()
def request_token(a, b):
with open(token_file) as fid:
token = fid.read()
return {"access_token": token, "expires_in_seconds": 1000}
# Create a client with a request callback that returns a valid token
# that will not expire soon.
# Create a client with a request callback that returns a valid token.
props: Dict = {"request_token_callback": request_token}
client = MongoClient(self.uri_single, authmechanismproperties=props)
@ -465,32 +204,14 @@ class TestAuthOIDC(unittest.TestCase):
# Close the client.
client.close()
# Create a new client.
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Set a fail point for saslStart commands.
with self.fail_point(
{
"mode": {"times": 2},
"data": {"failCommands": ["saslStart"], "errorCode": 18},
}
):
# Perform a find operation.
client.test.test.find_one()
# Close the client.
client.close()
def test_reauthenticate_succeeds(self):
listener = EventListener()
# Create request and refresh callbacks that return valid credentials
# that will not expire soon.
# Create request callback that returns valid credentials.
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
# Create a client with the callback.
props: Dict = {"request_token_callback": request_cb}
client = MongoClient(
self.uri_single, event_listeners=[listener], authmechanismproperties=props
)
@ -498,8 +219,8 @@ class TestAuthOIDC(unittest.TestCase):
# Perform a find operation.
client.test.test.find_one()
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
# Assert that the request callback has been called once.
self.assertEqual(self.request_called, 1)
listener.reset()
@ -534,23 +255,109 @@ class TestAuthOIDC(unittest.TestCase):
self.assertEqual(succeeded_events, ["find"])
self.assertEqual(failed_events, ["find"])
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
client.close()
def test_reauthenticate_succeeds_bulk_write(self):
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
def test_reauthenticate_succeeds_no_refresh(self):
cb = self.create_request_cb()
# Create a client with the callbacks.
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
def request_cb(*args, **kwargs):
result = cb(*args, **kwargs)
del result["refresh_token"]
return result
# Create a client with the callback.
props: Dict = {"request_token_callback": request_cb}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform a find operation.
client.test.test.find_one()
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
# Assert that the request callback has been called once.
self.assertEqual(self.request_called, 1)
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Perform a find operation.
client.test.test.find_one()
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
client.close()
def test_reauthenticate_succeeds_after_refresh_fails(self):
# Create request callback that returns valid credentials.
request_cb = self.create_request_cb()
# Create a client with the callback.
props: Dict = {"request_token_callback": request_cb}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform a find operation.
client.test.test.find_one()
# Assert that the request callback has been called once.
self.assertEqual(self.request_called, 1)
with self.fail_point(
{
"mode": {"times": 2},
"data": {"failCommands": ["find", "saslContinue"], "errorCode": 391},
}
):
# Perform a find operation.
client.test.test.find_one()
# Assert that the request callback has been called three times.
self.assertEqual(self.request_called, 3)
def test_reauthenticate_fails(self):
# Create request callback that returns valid credentials.
request_cb = self.create_request_cb()
# Create a client with the callback.
props: Dict = {"request_token_callback": request_cb}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform a find operation.
client.test.test.find_one()
# Assert that the request callback has been called once.
self.assertEqual(self.request_called, 1)
with self.fail_point(
{
"mode": {"times": 2},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Perform a find operation that fails.
with self.assertRaises(OperationFailure):
client.test.test.find_one()
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
client.close()
def test_reauthenticate_succeeds_bulk_write(self):
request_cb = self.create_request_cb()
# Create a client with the callback.
props: Dict = {"request_token_callback": request_cb}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform a find operation.
client.test.test.find_one()
# Assert that the request callback has been called once.
self.assertEqual(self.request_called, 1)
with self.fail_point(
{
@ -561,16 +368,15 @@ class TestAuthOIDC(unittest.TestCase):
# Perform a bulk write operation.
client.test.test.bulk_write([InsertOne({})])
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
client.close()
def test_reauthenticate_succeeds_bulk_read(self):
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
# Create a client with the callback.
props: Dict = {"request_token_callback": request_cb}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform a find operation.
@ -579,8 +385,8 @@ class TestAuthOIDC(unittest.TestCase):
# Perform a bulk write operation.
client.test.test.bulk_write([InsertOne({})])
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
# Assert that the request callback has been called once.
self.assertEqual(self.request_called, 1)
with self.fail_point(
{
@ -592,23 +398,22 @@ class TestAuthOIDC(unittest.TestCase):
cursor = client.test.test.find_raw_batches({})
list(cursor)
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
client.close()
def test_reauthenticate_succeeds_cursor(self):
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
# Create a client with the callback.
props: Dict = {"request_token_callback": request_cb}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform an insert operation.
client.test.test.insert_one({"a": 1})
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
# Assert that the request callback has been called once.
self.assertEqual(self.request_called, 1)
with self.fail_point(
{
@ -620,23 +425,22 @@ class TestAuthOIDC(unittest.TestCase):
cursor = client.test.test.find({"a": 1})
self.assertGreaterEqual(len(list(cursor)), 1)
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
client.close()
def test_reauthenticate_succeeds_get_more(self):
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
# Create a client with the callback.
props: Dict = {"request_token_callback": request_cb}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform an insert operation.
client.test.test.insert_many([{"a": 1}, {"a": 1}])
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
# Assert that the request callback has been called once.
self.assertEqual(self.request_called, 1)
with self.fail_point(
{
@ -648,30 +452,29 @@ class TestAuthOIDC(unittest.TestCase):
cursor = client.test.test.find({"a": 1}, batch_size=1)
self.assertGreaterEqual(len(list(cursor)), 1)
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
client.close()
def test_reauthenticate_succeeds_get_more_exhaust(self):
# Ensure no mongos
props = {"PROVIDER_NAME": "aws"}
props = {"request_token_callback": self.create_request_cb()}
client = MongoClient(self.uri_single, authmechanismproperties=props)
hello = client.admin.command(HelloCompat.LEGACY_CMD)
if hello.get("msg") != "isdbgrid":
raise unittest.SkipTest("Must not be a mongos")
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
# Create a client with the callback.
props: Dict = {"request_token_callback": request_cb}
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform an insert operation.
client.test.test.insert_many([{"a": 1}, {"a": 1}])
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
# Assert that the request callback has been called once.
self.assertEqual(self.request_called, 1)
with self.fail_point(
{
@ -683,16 +486,15 @@ class TestAuthOIDC(unittest.TestCase):
cursor = client.test.test.find({"a": 1}, batch_size=1, cursor_type=CursorType.EXHAUST)
self.assertGreaterEqual(len(list(cursor)), 1)
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
client.close()
def test_reauthenticate_succeeds_command(self):
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
# Create a client with the callback.
props: Dict = {"request_token_callback": request_cb}
print("start of test")
client = MongoClient(self.uri_single, authmechanismproperties=props)
@ -700,8 +502,8 @@ class TestAuthOIDC(unittest.TestCase):
# Perform an insert operation.
client.test.test.insert_one({"a": 1})
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
# Assert that the request callback has been called once.
self.assertEqual(self.request_called, 1)
with self.fail_point(
{
@ -714,112 +516,54 @@ class TestAuthOIDC(unittest.TestCase):
self.assertGreaterEqual(len(list(cursor)), 1)
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
client.close()
def test_reauthenticate_retries_and_succeeds_with_cache(self):
listener = EventListener()
# Create request and refresh callbacks that return valid credentials
# that will not expire soon.
def test_reauthentication_succeeds_multiple_connections(self):
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client = MongoClient(
self.uri_single, event_listeners=[listener], authmechanismproperties=props
# Create a client with the callback.
props: Dict = {"request_token_callback": request_cb}
client1 = MongoClient(self.uri_single, authmechanismproperties=props)
client2 = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform an insert operation.
client1.test.test.insert_many([{"a": 1}, {"a": 1}])
client2.test.test.find_one()
self.assertEqual(self.request_called, 2)
# Use the same authenticator for both clients
# to simulate a race condition with separate connections.
# We should only see one extra callback despite both connections
# needing to reauthenticate.
client2.options.pool_options._credentials.cache.data = (
client1.options.pool_options._credentials.cache.data
)
# Perform a find operation.
client.test.test.find_one()
# Set a fail point for ``saslStart`` commands of the form
with self.fail_point(
{
"mode": {"times": 2},
"data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
}
):
# Perform a find operation that succeeds.
client.test.test.find_one()
# Close the client.
client.close()
def test_reauthenticate_fails_with_no_cache(self):
listener = EventListener()
# Create request and refresh callbacks that return valid credentials
# that will not expire soon.
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client = MongoClient(
self.uri_single, event_listeners=[listener], authmechanismproperties=props
)
# Perform a find operation.
client.test.test.find_one()
# Clear the cache.
_oidc_cache.clear()
with self.fail_point(
{
"mode": {"times": 2},
"data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
}
):
# Perform a find operation that fails.
with self.assertRaises(OperationFailure):
client.test.test.find_one()
client.close()
def test_late_reauth_avoids_callback(self):
# Step 1: connect with both clients
request_cb = self.create_request_cb(expires_in_seconds=1e6)
refresh_cb = self.create_refresh_cb(expires_in_seconds=1e6)
props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
client1 = MongoClient(self.uri_single, authMechanismProperties=props)
client1.test.test.find_one()
client2 = MongoClient(self.uri_single, authMechanismProperties=props)
client2.test.test.find_one()
self.assertEqual(self.refresh_called, 0)
self.assertEqual(self.request_called, 1)
# Step 2: cause a find 391 on the first client
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Perform a find operation that succeeds.
client1.test.test.find_one()
self.assertEqual(self.refresh_called, 1)
self.assertEqual(self.request_called, 1)
self.assertEqual(self.request_called, 3)
# Step 3: cause a find 391 on the second client
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Perform a find operation that succeeds.
client2.test.test.find_one()
self.assertEqual(self.refresh_called, 1)
self.assertEqual(self.request_called, 1)
self.assertEqual(self.request_called, 3)
client1.close()
client2.close()

View File

@ -48,9 +48,6 @@ def create_test(test_case):
if props.get("REQUEST_TOKEN_CALLBACK"):
props["request_token_callback"] = lambda x, y: 1
del props["REQUEST_TOKEN_CALLBACK"]
if props.get("REFRESH_TOKEN_CALLBACK"):
props["refresh_token_callback"] = lambda a, b: 1
del props["REFRESH_TOKEN_CALLBACK"]
client = MongoClient(uri, connect=False, authmechanismproperties=props)
credentials = client.options.pool_options._credentials
if credential is None:
@ -86,10 +83,6 @@ def create_test(test_case):
self.assertEqual(
actual.request_token_callback, expected["request_token_callback"]
)
elif "refresh_token_callback" in expected:
self.assertEqual(
actual.refresh_token_callback, expected["refresh_token_callback"]
)
else:
self.fail(f"Unhandled property: {key}")
else: