PYTHON-3460 Implement OIDC SASL mechanism (#1138)
This commit is contained in:
parent
d504322a74
commit
afd7e1c2cd
@ -749,6 +749,68 @@ functions:
|
||||
fi
|
||||
PYTHON_BINARY=${PYTHON_BINARY} ASSERT_NO_URI_CREDS=true .evergreen/run-mongodb-aws-test.sh
|
||||
|
||||
"bootstrap oidc":
|
||||
- command: ec2.assume_role
|
||||
params:
|
||||
role_arn: ${aws_test_secrets_role}
|
||||
- command: shell.exec
|
||||
type: test
|
||||
params:
|
||||
working_dir: "src"
|
||||
shell: bash
|
||||
script: |
|
||||
${PREPARE_SHELL}
|
||||
if [ "${skip_EC2_auth_test}" = "true" ]; then
|
||||
echo "This platform does not support the oidc auth test, skipping..."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc
|
||||
export AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
|
||||
export AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
|
||||
export AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN}
|
||||
export OIDC_TOKEN_DIR=/tmp/tokens
|
||||
|
||||
. ./activate-authoidcvenv.sh
|
||||
python oidc_write_orchestration.py
|
||||
python oidc_get_tokens.py
|
||||
|
||||
"run oidc auth test with aws credentials":
|
||||
- command: shell.exec
|
||||
type: test
|
||||
params:
|
||||
working_dir: "src"
|
||||
shell: bash
|
||||
script: |
|
||||
${PREPARE_SHELL}
|
||||
if [ "${skip_EC2_auth_test}" = "true" ]; then
|
||||
echo "This platform does not support the oidc auth test, skipping..."
|
||||
exit 0
|
||||
fi
|
||||
cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc
|
||||
mongosh setup_oidc.js
|
||||
- command: shell.exec
|
||||
type: test
|
||||
params:
|
||||
working_dir: "src"
|
||||
silent: true
|
||||
script: |
|
||||
# DO NOT ECHO WITH XTRACE (which PREPARE_SHELL does)
|
||||
cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh"
|
||||
export OIDC_TOKEN_DIR=/tmp/tokens
|
||||
EOF
|
||||
- command: shell.exec
|
||||
type: test
|
||||
params:
|
||||
working_dir: "src"
|
||||
script: |
|
||||
${PREPARE_SHELL}
|
||||
if [ "${skip_web_identity_auth_test}" = "true" ]; then
|
||||
echo "This platform does not support the oidc auth test, skipping..."
|
||||
exit 0
|
||||
fi
|
||||
PYTHON_BINARY=${PYTHON_BINARY} ASSERT_NO_URI_CREDS=true .evergreen/run-mongodb-oidc-test.sh
|
||||
|
||||
"run aws auth test with aws credentials as environment variables":
|
||||
- command: shell.exec
|
||||
type: test
|
||||
@ -2034,6 +2096,19 @@ tasks:
|
||||
- func: "run aws auth test with aws web identity credentials"
|
||||
- func: "run aws ECS auth test"
|
||||
|
||||
- name: "oidc-auth-test-latest"
|
||||
commands:
|
||||
- func: "bootstrap oidc"
|
||||
- func: "bootstrap mongo-orchestration"
|
||||
vars:
|
||||
AUTH: "auth"
|
||||
ORCHESTRATION_FILE: "auth-oidc.json"
|
||||
TOPOLOGY: "replica_set"
|
||||
VERSION: "latest"
|
||||
- func: "run oidc auth test with aws credentials"
|
||||
vars:
|
||||
AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test1
|
||||
|
||||
- name: load-balancer-test
|
||||
commands:
|
||||
- func: "bootstrap mongo-orchestration"
|
||||
@ -3103,6 +3178,14 @@ buildvariants:
|
||||
# macOS MongoDB servers do not staple OCSP responses and only support RSA.
|
||||
- name: ".ocsp-rsa !.ocsp-staple"
|
||||
|
||||
- matrix_name: "oidc-auth-test"
|
||||
matrix_spec:
|
||||
platform: [ ubuntu-20.04 ]
|
||||
python-version: ["3.9"]
|
||||
display_name: "MONGODB-OIDC Auth ${platform} ${python-version}"
|
||||
tasks:
|
||||
- name: "oidc-auth-test-latest"
|
||||
|
||||
- matrix_name: "aws-auth-test"
|
||||
matrix_spec:
|
||||
platform: [ubuntu-20.04]
|
||||
|
||||
@ -70,6 +70,9 @@ for spec in "$@"
|
||||
do
|
||||
# Match the spec dir name, the python test dir name, and/or common abbreviations.
|
||||
case "$spec" in
|
||||
auth)
|
||||
cpjson auth/tests/ auth
|
||||
;;
|
||||
atlas-data-lake-testing|data_lake)
|
||||
cpjson atlas-data-lake-testing/tests/ data_lake
|
||||
;;
|
||||
|
||||
85
.evergreen/run-mongodb-oidc-test.sh
Executable file
85
.evergreen/run-mongodb-oidc-test.sh
Executable file
@ -0,0 +1,85 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -o xtrace
|
||||
set -o errexit # Exit the script with error if any of the commands fail
|
||||
|
||||
############################################
|
||||
# Main Program #
|
||||
############################################
|
||||
|
||||
# Supported/used environment variables:
|
||||
# MONGODB_URI Set the URI, including an optional username/password to use
|
||||
# to connect to the server via MONGODB-OIDC authentication
|
||||
# mechanism.
|
||||
# PYTHON_BINARY The Python version to use.
|
||||
|
||||
echo "Running MONGODB-OIDC authentication tests"
|
||||
# ensure no secrets are printed in log files
|
||||
set +x
|
||||
|
||||
# load the script
|
||||
shopt -s expand_aliases # needed for `urlencode` alias
|
||||
[ -s "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh"
|
||||
|
||||
MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"}
|
||||
MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC"
|
||||
MONGODB_URI_MULTIPLE="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true"
|
||||
|
||||
if [ -z "${OIDC_TOKEN_DIR}" ]; then
|
||||
echo "Must specify OIDC_TOKEN_DIR"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export MONGODB_URI_SINGLE="$MONGODB_URI_SINGLE"
|
||||
export MONGODB_URI_MULTIPLE="$MONGODB_URI_MULTIPLE"
|
||||
export MONGODB_URI="$MONGODB_URI"
|
||||
|
||||
echo $MONGODB_URI_SINGLE
|
||||
echo $MONGODB_URI_MULTIPLE
|
||||
echo $MONGODB_URI
|
||||
|
||||
if [ "$ASSERT_NO_URI_CREDS" = "true" ]; then
|
||||
if echo "$MONGODB_URI" | grep -q "@"; then
|
||||
echo "MONGODB_URI unexpectedly contains user credentials!";
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# show test output
|
||||
set -x
|
||||
|
||||
# Workaround macOS python 3.9 incompatibility with system virtualenv.
|
||||
if [ "$(uname -s)" = "Darwin" ]; then
|
||||
VIRTUALENV="/Library/Frameworks/Python.framework/Versions/3.9/bin/python3 -m virtualenv"
|
||||
else
|
||||
VIRTUALENV=$(command -v virtualenv)
|
||||
fi
|
||||
|
||||
authtest () {
|
||||
if [ "Windows_NT" = "$OS" ]; then
|
||||
PYTHON=$(cygpath -m $PYTHON)
|
||||
fi
|
||||
|
||||
echo "Running MONGODB-OIDC authentication tests with $PYTHON"
|
||||
$PYTHON --version
|
||||
|
||||
$VIRTUALENV -p $PYTHON --never-download venvoidc
|
||||
if [ "Windows_NT" = "$OS" ]; then
|
||||
. venvoidc/Scripts/activate
|
||||
else
|
||||
. venvoidc/bin/activate
|
||||
fi
|
||||
python -m pip install -U pip setuptools
|
||||
python -m pip install '.[aws]'
|
||||
python test/auth_aws/test_auth_oidc.py -v
|
||||
deactivate
|
||||
rm -rf venvoidc
|
||||
}
|
||||
|
||||
PYTHON=${PYTHON_BINARY:-}
|
||||
if [ -z "$PYTHON" ]; then
|
||||
echo "Cannot test without specifying PYTHON_BINARY"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
authtest
|
||||
@ -27,6 +27,7 @@ from urllib.parse import quote
|
||||
from bson.binary import Binary
|
||||
from bson.son import SON
|
||||
from pymongo.auth_aws import _authenticate_aws
|
||||
from pymongo.auth_oidc import _authenticate_oidc, _get_authenticator, _OIDCProperties
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.saslprep import saslprep
|
||||
|
||||
@ -48,6 +49,7 @@ MECHANISMS = frozenset(
|
||||
[
|
||||
"GSSAPI",
|
||||
"MONGODB-CR",
|
||||
"MONGODB-OIDC",
|
||||
"MONGODB-X509",
|
||||
"MONGODB-AWS",
|
||||
"PLAIN",
|
||||
@ -101,7 +103,7 @@ _AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
|
||||
|
||||
def _build_credentials_tuple(mech, source, user, passwd, extra, database):
|
||||
"""Build and return a mechanism specific credentials tuple."""
|
||||
if mech not in ("MONGODB-X509", "MONGODB-AWS") and user is None:
|
||||
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
|
||||
raise ConfigurationError("%s requires a username." % (mech,))
|
||||
if mech == "GSSAPI":
|
||||
if source is not None and source != "$external":
|
||||
@ -137,6 +139,32 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database):
|
||||
aws_props = _AWSProperties(aws_session_token=aws_session_token)
|
||||
# user can be None for temporary link-local EC2 credentials.
|
||||
return MongoCredential(mech, "$external", user, passwd, aws_props, None)
|
||||
elif mech == "MONGODB-OIDC":
|
||||
properties = extra.get("authmechanismproperties", {})
|
||||
request_token_callback = properties.get("request_token_callback")
|
||||
refresh_token_callback = properties.get("refresh_token_callback", None)
|
||||
provider_name = properties.get("PROVIDER_NAME", "")
|
||||
default_allowed = [
|
||||
"*.mongodb.net",
|
||||
"*.mongodb-dev.net",
|
||||
"*.mongodbgov.net",
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"::1",
|
||||
]
|
||||
allowed_hosts = properties.get("allowed_hosts", default_allowed)
|
||||
if not request_token_callback and provider_name != "aws":
|
||||
raise ConfigurationError(
|
||||
"authentication with MONGODB-OIDC requires providing an request_token_callback or a provider_name of 'aws'"
|
||||
)
|
||||
oidc_props = _OIDCProperties(
|
||||
request_token_callback=request_token_callback,
|
||||
refresh_token_callback=refresh_token_callback,
|
||||
provider_name=provider_name,
|
||||
allowed_hosts=allowed_hosts,
|
||||
)
|
||||
return MongoCredential(mech, "$external", user, passwd, oidc_props, None)
|
||||
|
||||
elif mech == "PLAIN":
|
||||
source_database = source or database or "$external"
|
||||
return MongoCredential(mech, source_database, user, passwd, None, None)
|
||||
@ -439,7 +467,7 @@ def _authenticate_x509(credentials, sock_info):
|
||||
# MONGODB-X509 is done after the speculative auth step.
|
||||
return
|
||||
|
||||
cmd = _X509Context(credentials).speculate_command()
|
||||
cmd = _X509Context(credentials, sock_info.address).speculate_command()
|
||||
sock_info.command("$external", cmd)
|
||||
|
||||
|
||||
@ -482,6 +510,7 @@ _AUTH_MAP: Mapping[str, Callable] = {
|
||||
"MONGODB-CR": _authenticate_mongo_cr,
|
||||
"MONGODB-X509": _authenticate_x509,
|
||||
"MONGODB-AWS": _authenticate_aws,
|
||||
"MONGODB-OIDC": _authenticate_oidc,
|
||||
"PLAIN": _authenticate_plain,
|
||||
"SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"),
|
||||
"SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"),
|
||||
@ -490,15 +519,16 @@ _AUTH_MAP: Mapping[str, Callable] = {
|
||||
|
||||
|
||||
class _AuthContext(object):
|
||||
def __init__(self, credentials):
|
||||
def __init__(self, credentials, address):
|
||||
self.credentials = credentials
|
||||
self.speculative_authenticate = None
|
||||
self.address = address
|
||||
|
||||
@staticmethod
|
||||
def from_credentials(creds):
|
||||
def from_credentials(creds, address):
|
||||
spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
|
||||
if spec_cls:
|
||||
return spec_cls(creds)
|
||||
return spec_cls(creds, address)
|
||||
return None
|
||||
|
||||
def speculate_command(self):
|
||||
@ -512,8 +542,8 @@ class _AuthContext(object):
|
||||
|
||||
|
||||
class _ScramContext(_AuthContext):
|
||||
def __init__(self, credentials, mechanism):
|
||||
super(_ScramContext, self).__init__(credentials)
|
||||
def __init__(self, credentials, address, mechanism):
|
||||
super(_ScramContext, self).__init__(credentials, address)
|
||||
self.scram_data = None
|
||||
self.mechanism = mechanism
|
||||
|
||||
@ -534,16 +564,30 @@ class _X509Context(_AuthContext):
|
||||
return cmd
|
||||
|
||||
|
||||
class _OIDCContext(_AuthContext):
|
||||
def speculate_command(self):
|
||||
authenticator = _get_authenticator(self.credentials, self.address)
|
||||
cmd = authenticator.auth_start_cmd(False)
|
||||
if cmd is None:
|
||||
return
|
||||
cmd["db"] = self.credentials.source
|
||||
return cmd
|
||||
|
||||
|
||||
_SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = {
|
||||
"MONGODB-X509": _X509Context,
|
||||
"SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"),
|
||||
"SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
|
||||
"MONGODB-OIDC": _OIDCContext,
|
||||
"DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
|
||||
}
|
||||
|
||||
|
||||
def authenticate(credentials, sock_info):
|
||||
def authenticate(credentials, sock_info, reauthenticate=False):
|
||||
"""Authenticate sock_info."""
|
||||
mechanism = credentials.mechanism
|
||||
auth_func = _AUTH_MAP[mechanism]
|
||||
auth_func(credentials, sock_info)
|
||||
if mechanism == "MONGODB-OIDC":
|
||||
_authenticate_oidc(credentials, sock_info, reauthenticate)
|
||||
else:
|
||||
auth_func(credentials, sock_info)
|
||||
|
||||
299
pymongo/auth_oidc.py
Normal file
299
pymongo/auth_oidc.py
Normal file
@ -0,0 +1,299 @@
|
||||
# Copyright 2023-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MONGODB-OIDC Authentication helpers."""
|
||||
import os
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import bson
|
||||
from bson.binary import Binary
|
||||
from bson.son import SON
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OIDCProperties:
|
||||
request_token_callback: Optional[Callable[..., Dict]]
|
||||
refresh_token_callback: Optional[Callable[..., Dict]]
|
||||
provider_name: Optional[str]
|
||||
allowed_hosts: List[str]
|
||||
|
||||
|
||||
"""Mechanism properties for MONGODB-OIDC authentication."""
|
||||
|
||||
TOKEN_BUFFER_MINUTES = 5
|
||||
CALLBACK_TIMEOUT_SECONDS = 5 * 60
|
||||
CACHE_TIMEOUT_MINUTES = 60 * 5
|
||||
CALLBACK_VERSION = 0
|
||||
|
||||
_CACHE: Dict[str, "_OIDCAuthenticator"] = {}
|
||||
|
||||
|
||||
def _get_authenticator(credentials, address):
|
||||
# Clear out old items in the cache.
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
to_remove = []
|
||||
for key, value in _CACHE.items():
|
||||
if value.cache_exp_utc is not None and value.cache_exp_utc < now_utc:
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
del _CACHE[key]
|
||||
|
||||
# Extract values.
|
||||
principal_name = credentials.username
|
||||
properties = credentials.mechanism_properties
|
||||
request_cb = properties.request_token_callback
|
||||
refresh_cb = properties.refresh_token_callback
|
||||
|
||||
# Validate that the address is allowed.
|
||||
if not properties.provider_name:
|
||||
found = False
|
||||
allowed_hosts = properties.allowed_hosts
|
||||
for patt in allowed_hosts:
|
||||
if patt == address[0]:
|
||||
found = True
|
||||
elif patt.startswith("*.") and address[0].endswith(patt[1:]):
|
||||
found = True
|
||||
if not found:
|
||||
raise ConfigurationError(
|
||||
f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
|
||||
)
|
||||
|
||||
# Get or create the cache item.
|
||||
cache_key = f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}"
|
||||
_CACHE.setdefault(cache_key, _OIDCAuthenticator(username=principal_name, properties=properties))
|
||||
|
||||
return _CACHE[cache_key]
|
||||
|
||||
|
||||
def _get_cache_exp():
|
||||
return datetime.now(timezone.utc) + timedelta(minutes=CACHE_TIMEOUT_MINUTES)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OIDCAuthenticator:
|
||||
username: str
|
||||
properties: _OIDCProperties
|
||||
idp_info: Optional[Dict] = field(default=None)
|
||||
idp_resp: Optional[Dict] = field(default=None)
|
||||
reauth_gen_id: int = field(default=0)
|
||||
idp_info_gen_id: int = field(default=0)
|
||||
token_gen_id: int = field(default=0)
|
||||
token_exp_utc: Optional[datetime] = field(default=None)
|
||||
cache_exp_utc: datetime = field(default_factory=_get_cache_exp)
|
||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
|
||||
def get_current_token(self, use_callbacks=True):
|
||||
properties = self.properties
|
||||
|
||||
request_cb = properties.request_token_callback
|
||||
refresh_cb = properties.refresh_token_callback
|
||||
if not use_callbacks:
|
||||
request_cb = None
|
||||
refresh_cb = None
|
||||
|
||||
current_valid_token = False
|
||||
if self.token_exp_utc is not None:
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
exp_utc = self.token_exp_utc
|
||||
buffer_seconds = TOKEN_BUFFER_MINUTES * 60
|
||||
if (exp_utc - now_utc).total_seconds() >= buffer_seconds:
|
||||
current_valid_token = True
|
||||
|
||||
timeout = CALLBACK_TIMEOUT_SECONDS
|
||||
|
||||
if not use_callbacks and not current_valid_token:
|
||||
return None
|
||||
|
||||
if not current_valid_token and request_cb is not None:
|
||||
prev_token = self.idp_resp and self.idp_resp["access_token"]
|
||||
with self.lock:
|
||||
# See if the token was changed while we were waiting for the
|
||||
# lock.
|
||||
new_token = self.idp_resp and self.idp_resp["access_token"]
|
||||
if new_token != prev_token:
|
||||
return new_token
|
||||
|
||||
refresh_token = self.idp_resp and self.idp_resp.get("refresh_token")
|
||||
refresh_token = refresh_token or ""
|
||||
context = dict(
|
||||
timeout_seconds=timeout,
|
||||
version=CALLBACK_VERSION,
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
|
||||
if self.idp_resp is None or refresh_cb is None:
|
||||
self.idp_resp = request_cb(self.idp_info, context)
|
||||
elif request_cb is not None:
|
||||
self.idp_resp = refresh_cb(self.idp_info, context)
|
||||
cache_exp_utc = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=CACHE_TIMEOUT_MINUTES
|
||||
)
|
||||
self.cache_exp_utc = cache_exp_utc
|
||||
self.token_gen_id += 1
|
||||
|
||||
token_result = self.idp_resp
|
||||
|
||||
# Validate callback return value.
|
||||
if not isinstance(token_result, dict):
|
||||
raise ValueError("OIDC callback returned invalid result")
|
||||
|
||||
if "access_token" not in token_result:
|
||||
raise ValueError("OIDC callback did not return an access_token")
|
||||
|
||||
expected = ["access_token", "expires_in_seconds", "refesh_token"]
|
||||
for key in token_result:
|
||||
if key not in expected:
|
||||
raise ValueError(f'Unexpected field in callback result "{key}"')
|
||||
|
||||
token = token_result["access_token"]
|
||||
|
||||
if "expires_in_seconds" in token_result:
|
||||
expires_in = int(token_result["expires_in_seconds"])
|
||||
buffer_seconds = TOKEN_BUFFER_MINUTES * 60
|
||||
if expires_in >= buffer_seconds:
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
exp_utc = now_utc + timedelta(seconds=expires_in)
|
||||
self.token_exp_utc = exp_utc
|
||||
|
||||
return token
|
||||
|
||||
def auth_start_cmd(self, use_callbacks=True):
|
||||
properties = self.properties
|
||||
|
||||
# Handle aws provider credentials.
|
||||
if properties.provider_name == "aws":
|
||||
aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"]
|
||||
with open(aws_identity_file) as fid:
|
||||
token = fid.read().strip()
|
||||
payload = dict(jwt=token)
|
||||
cmd = SON(
|
||||
[
|
||||
("saslStart", 1),
|
||||
("mechanism", "MONGODB-OIDC"),
|
||||
("payload", Binary(bson.encode(payload))),
|
||||
]
|
||||
)
|
||||
return cmd
|
||||
|
||||
principal_name = self.username
|
||||
|
||||
if self.idp_info is not None:
|
||||
self.cache_exp_utc = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=CACHE_TIMEOUT_MINUTES
|
||||
)
|
||||
|
||||
if self.idp_info is None:
|
||||
self.cache_exp_utc = _get_cache_exp()
|
||||
|
||||
if self.idp_info is None:
|
||||
# Send the SASL start with the optional principal name.
|
||||
payload = dict()
|
||||
|
||||
if principal_name:
|
||||
payload["n"] = principal_name
|
||||
|
||||
cmd = SON(
|
||||
[
|
||||
("saslStart", 1),
|
||||
("mechanism", "MONGODB-OIDC"),
|
||||
("payload", Binary(bson.encode(payload))),
|
||||
("autoAuthorize", 1),
|
||||
]
|
||||
)
|
||||
return cmd
|
||||
|
||||
token = self.get_current_token(use_callbacks)
|
||||
if not token:
|
||||
return None
|
||||
bin_payload = Binary(bson.encode(dict(jwt=token)))
|
||||
return SON(
|
||||
[
|
||||
("saslStart", 1),
|
||||
("mechanism", "MONGODB-OIDC"),
|
||||
("payload", bin_payload),
|
||||
]
|
||||
)
|
||||
|
||||
def clear(self):
|
||||
self.idp_info = None
|
||||
self.idp_resp = None
|
||||
self.token_exp_utc = None
|
||||
|
||||
def run_command(self, sock_info, cmd):
|
||||
try:
|
||||
return sock_info.command("$external", cmd, no_reauth=True)
|
||||
except OperationFailure as exc:
|
||||
self.clear()
|
||||
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
|
||||
if "jwt" in bson.decode(cmd["payload"]): # type:ignore[attr-defined]
|
||||
if self.idp_info_gen_id > self.reauth_gen_id:
|
||||
raise
|
||||
return self.authenticate(sock_info, reauthenticate=True)
|
||||
raise
|
||||
|
||||
def authenticate(self, sock_info, reauthenticate=False):
|
||||
if reauthenticate:
|
||||
prev_id = getattr(sock_info, "oidc_token_gen_id", None)
|
||||
# Check if we've already changed tokens.
|
||||
if prev_id == self.token_gen_id:
|
||||
self.reauth_gen_id = self.idp_info_gen_id
|
||||
self.token_exp_utc = None
|
||||
if not self.properties.refresh_token_callback:
|
||||
self.clear()
|
||||
|
||||
ctx = sock_info.auth_ctx
|
||||
cmd = None
|
||||
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
resp = ctx.speculative_authenticate
|
||||
else:
|
||||
cmd = self.auth_start_cmd()
|
||||
resp = self.run_command(sock_info, cmd)
|
||||
|
||||
if resp["done"]:
|
||||
sock_info.oidc_token_gen_id = self.token_gen_id
|
||||
return
|
||||
|
||||
server_resp: Dict = bson.decode(resp["payload"])
|
||||
if "issuer" in server_resp:
|
||||
self.idp_info = server_resp
|
||||
self.idp_info_gen_id += 1
|
||||
|
||||
conversation_id = resp["conversationId"]
|
||||
token = self.get_current_token()
|
||||
sock_info.oidc_token_gen_id = self.token_gen_id
|
||||
bin_payload = Binary(bson.encode(dict(jwt=token)))
|
||||
cmd = SON(
|
||||
[
|
||||
("saslContinue", 1),
|
||||
("conversationId", conversation_id),
|
||||
("payload", bin_payload),
|
||||
]
|
||||
)
|
||||
resp = self.run_command(sock_info, cmd)
|
||||
if not resp["done"]:
|
||||
self.clear()
|
||||
raise OperationFailure("SASL conversation failed to complete.")
|
||||
return resp
|
||||
|
||||
|
||||
def _authenticate_oidc(credentials, sock_info, reauthenticate):
|
||||
"""Authenticate using MONGODB-OIDC."""
|
||||
authenticator = _get_authenticator(credentials, sock_info.address)
|
||||
return authenticator.authenticate(sock_info, reauthenticate=reauthenticate)
|
||||
@ -16,6 +16,7 @@
|
||||
"""Functions and classes common to multiple pymongo modules."""
|
||||
|
||||
import datetime
|
||||
import inspect
|
||||
import warnings
|
||||
from collections import OrderedDict, abc
|
||||
from typing import (
|
||||
@ -416,14 +417,48 @@ def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]]
|
||||
|
||||
|
||||
_MECHANISM_PROPS = frozenset(
|
||||
["SERVICE_NAME", "CANONICALIZE_HOST_NAME", "SERVICE_REALM", "AWS_SESSION_TOKEN"]
|
||||
[
|
||||
"SERVICE_NAME",
|
||||
"CANONICALIZE_HOST_NAME",
|
||||
"SERVICE_REALM",
|
||||
"AWS_SESSION_TOKEN",
|
||||
"PROVIDER_NAME",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Union[bool, str]]:
|
||||
"""Validate authMechanismProperties."""
|
||||
value = validate_string(option, value)
|
||||
props: Dict[str, Any] = {}
|
||||
if not isinstance(value, str):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("Auth mechanism properties must be given as a string or a dictionary")
|
||||
for key, value in value.items():
|
||||
if isinstance(value, str):
|
||||
props[key] = value
|
||||
elif isinstance(value, bool):
|
||||
props[key] = str(value).lower()
|
||||
elif key in ["allowed_hosts"] and isinstance(value, list):
|
||||
props[key] = value
|
||||
elif inspect.isfunction(value):
|
||||
signature = inspect.signature(value)
|
||||
if key == "request_token_callback":
|
||||
expected_params = 2
|
||||
elif key == "refresh_token_callback":
|
||||
expected_params = 2
|
||||
else:
|
||||
raise ValueError(f"Unrecognized Auth mechanism function {key}")
|
||||
if len(signature.parameters) != expected_params:
|
||||
msg = f"{key} must accept {expected_params} parameters"
|
||||
raise ValueError(msg)
|
||||
props[key] = value
|
||||
else:
|
||||
raise ValueError(
|
||||
"Auth mechanism property values must be strings or callback functions"
|
||||
)
|
||||
return props
|
||||
|
||||
value = validate_string(option, value)
|
||||
for opt in value.split(","):
|
||||
try:
|
||||
key, val = opt.split(":")
|
||||
@ -715,6 +750,7 @@ KW_VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = {
|
||||
"password": validate_string_or_none,
|
||||
"server_selector": validate_is_callable_or_none,
|
||||
"auto_encryption_opts": validate_auto_encryption_opts_or_none,
|
||||
"authoidcallowedhosts": validate_list,
|
||||
}
|
||||
|
||||
# Dictionary where keys are any URI option name, and values are the
|
||||
|
||||
@ -68,6 +68,9 @@ _RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
# Server code raised when re-authentication is required
|
||||
_REAUTHENTICATION_REQUIRED_CODE = 391
|
||||
|
||||
|
||||
def _gen_index_name(keys):
|
||||
"""Generate an index name from the set of fields it is over."""
|
||||
@ -267,3 +270,35 @@ def _handle_exception():
|
||||
pass
|
||||
finally:
|
||||
del einfo
|
||||
|
||||
|
||||
def _handle_reauth(func):
|
||||
def inner(*args, **kwargs):
|
||||
no_reauth = kwargs.pop("no_reauth", False)
|
||||
from pymongo.pool import SocketInfo
|
||||
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except OperationFailure as exc:
|
||||
if no_reauth:
|
||||
raise
|
||||
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
|
||||
# Look for an argument that either is a SocketInfo
|
||||
# or has a socket_info attribute, so we can trigger
|
||||
# a reauth.
|
||||
sock_info = None
|
||||
for arg in args:
|
||||
if isinstance(arg, SocketInfo):
|
||||
sock_info = arg
|
||||
break
|
||||
if hasattr(arg, "sock_info"):
|
||||
sock_info = arg.sock_info
|
||||
break
|
||||
if sock_info:
|
||||
sock_info.authenticate(reauthenticate=True)
|
||||
else:
|
||||
raise
|
||||
return func(*args, **kwargs)
|
||||
raise
|
||||
|
||||
return inner
|
||||
|
||||
@ -54,6 +54,7 @@ from pymongo.errors import (
|
||||
ProtocolError,
|
||||
)
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.helpers import _handle_reauth
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
@ -909,6 +910,7 @@ class _BulkWriteContext(object):
|
||||
self.start_time = datetime.datetime.now()
|
||||
return result
|
||||
|
||||
@_handle_reauth
|
||||
def write_command(self, cmd, request_id, msg, docs):
|
||||
"""A proxy for SocketInfo.write_command that handles event publishing."""
|
||||
if self.publish:
|
||||
|
||||
@ -57,6 +57,7 @@ from pymongo.errors import (
|
||||
_CertificateError,
|
||||
)
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.helpers import _handle_reauth
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
|
||||
from pymongo.network import command, receive_message
|
||||
@ -756,7 +757,7 @@ class SocketInfo(object):
|
||||
if creds:
|
||||
if creds.mechanism == "DEFAULT" and creds.username:
|
||||
cmd["saslSupportedMechs"] = creds.source + "." + creds.username
|
||||
auth_ctx = auth._AuthContext.from_credentials(creds)
|
||||
auth_ctx = auth._AuthContext.from_credentials(creds, self.address)
|
||||
if auth_ctx:
|
||||
cmd["speculativeAuthenticate"] = auth_ctx.speculate_command()
|
||||
else:
|
||||
@ -813,6 +814,7 @@ class SocketInfo(object):
|
||||
helpers._check_command_response(response_doc, self.max_wire_version)
|
||||
return response_doc
|
||||
|
||||
@_handle_reauth
|
||||
def command(
|
||||
self,
|
||||
dbname,
|
||||
@ -966,17 +968,22 @@ class SocketInfo(object):
|
||||
helpers._check_command_response(result, self.max_wire_version)
|
||||
return result
|
||||
|
||||
def authenticate(self):
|
||||
def authenticate(self, reauthenticate=False):
|
||||
"""Authenticate to the server if needed.
|
||||
|
||||
Can raise ConnectionFailure or OperationFailure.
|
||||
"""
|
||||
# CMAP spec says to publish the ready event only after authenticating
|
||||
# the connection.
|
||||
if reauthenticate:
|
||||
if self.performed_handshake:
|
||||
# Existing auth_ctx is stale, remove it.
|
||||
self.auth_ctx = None
|
||||
self.ready = False
|
||||
if not self.ready:
|
||||
creds = self.opts._credentials
|
||||
if creds:
|
||||
auth.authenticate(creds, self)
|
||||
auth.authenticate(creds, self, reauthenticate=reauthenticate)
|
||||
self.ready = True
|
||||
if self.enabled_for_cmap:
|
||||
self.listeners.publish_connection_ready(self.address, self.id)
|
||||
|
||||
@ -18,7 +18,7 @@ from datetime import datetime
|
||||
|
||||
from bson import _decode_all_selective
|
||||
from pymongo.errors import NotPrimaryError, OperationFailure
|
||||
from pymongo.helpers import _check_command_response
|
||||
from pymongo.helpers import _check_command_response, _handle_reauth
|
||||
from pymongo.message import _convert_exception, _OpMsg
|
||||
from pymongo.response import PinnedResponse, Response
|
||||
|
||||
@ -73,6 +73,7 @@ class Server(object):
|
||||
"""Check the server's state soon."""
|
||||
self._monitor.request_check()
|
||||
|
||||
@_handle_reauth
|
||||
def run_operation(self, sock_info, operation, read_preference, listeners, unpack_res):
|
||||
"""Run a _Query or _GetMore operation and return a Response object.
|
||||
|
||||
|
||||
@ -444,6 +444,133 @@
|
||||
"AWS_SESSION_TOKEN": "token!@#$%^&*()_+"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "should recognise the mechanism and request callback (MONGODB-OIDC)",
|
||||
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
|
||||
"callback": ["oidcRequest"],
|
||||
"valid": true,
|
||||
"credential": {
|
||||
"username": null,
|
||||
"password": null,
|
||||
"source": "$external",
|
||||
"mechanism": "MONGODB-OIDC",
|
||||
"mechanism_properties": {
|
||||
"REQUEST_TOKEN_CALLBACK": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "should recognise the mechanism when auth source is explicitly specified and with request callback (MONGODB-OIDC)",
|
||||
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external",
|
||||
"callback": ["oidcRequest"],
|
||||
"valid": true,
|
||||
"credential": {
|
||||
"username": null,
|
||||
"password": null,
|
||||
"source": "$external",
|
||||
"mechanism": "MONGODB-OIDC",
|
||||
"mechanism_properties": {
|
||||
"REQUEST_TOKEN_CALLBACK": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "should recognise the mechanism with request and refresh callback (MONGODB-OIDC)",
|
||||
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
|
||||
"callback": ["oidcRequest", "oidcRefresh"],
|
||||
"valid": true,
|
||||
"credential": {
|
||||
"username": null,
|
||||
"password": null,
|
||||
"source": "$external",
|
||||
"mechanism": "MONGODB-OIDC",
|
||||
"mechanism_properties": {
|
||||
"REQUEST_TOKEN_CALLBACK": true,
|
||||
"REFRESH_TOKEN_CALLBACK": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "should recognise the mechanism and username with request callback (MONGODB-OIDC)",
|
||||
"uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC",
|
||||
"callback": ["oidcRequest"],
|
||||
"valid": true,
|
||||
"credential": {
|
||||
"username": "principalName",
|
||||
"password": null,
|
||||
"source": "$external",
|
||||
"mechanism": "MONGODB-OIDC",
|
||||
"mechanism_properties": {
|
||||
"REQUEST_TOKEN_CALLBACK": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "should recognise the mechanism with aws device (MONGODB-OIDC)",
|
||||
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws",
|
||||
"valid": true,
|
||||
"credential": {
|
||||
"username": null,
|
||||
"password": null,
|
||||
"source": "$external",
|
||||
"mechanism": "MONGODB-OIDC",
|
||||
"mechanism_properties": {
|
||||
"PROVIDER_NAME": "aws"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "should recognise the mechanism when auth source is explicitly specified and with aws device (MONGODB-OIDC)",
|
||||
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=PROVIDER_NAME:aws",
|
||||
"valid": true,
|
||||
"credential": {
|
||||
"username": null,
|
||||
"password": null,
|
||||
"source": "$external",
|
||||
"mechanism": "MONGODB-OIDC",
|
||||
"mechanism_properties": {
|
||||
"PROVIDER_NAME": "aws"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "should throw an exception if username and password are specified (MONGODB-OIDC)",
|
||||
"uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC",
|
||||
"callback": ["oidcRequest"],
|
||||
"valid": false,
|
||||
"credential": null
|
||||
},
|
||||
{
|
||||
"description": "should throw an exception if username and deviceName are specified (MONGODB-OIDC)",
|
||||
"uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&PROVIDER_NAME:gcp",
|
||||
"valid": false,
|
||||
"credential": null
|
||||
},
|
||||
{
|
||||
"description": "should throw an exception if specified deviceName is not supported (MONGODB-OIDC)",
|
||||
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:unexisted",
|
||||
"valid": false,
|
||||
"credential": null
|
||||
},
|
||||
{
|
||||
"description": "should throw an exception if neither deviceName nor callbacks specified (MONGODB-OIDC)",
|
||||
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
|
||||
"valid": false,
|
||||
"credential": null
|
||||
},
|
||||
{
|
||||
"description": "should throw an exception when only refresh callback is specified (MONGODB-OIDC)",
|
||||
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
|
||||
"callback": ["oidcRefresh"],
|
||||
"valid": false,
|
||||
"credential": null
|
||||
},
|
||||
{
|
||||
"description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)",
|
||||
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted",
|
||||
"valid": false,
|
||||
"credential": null
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
191
test/auth/unified/reauthenticate_with_retry.json
Normal file
191
test/auth/unified/reauthenticate_with_retry.json
Normal file
@ -0,0 +1,191 @@
|
||||
{
|
||||
"description": "reauthenticate_with_retry",
|
||||
"schemaVersion": "1.12",
|
||||
"runOnRequirements": [
|
||||
{
|
||||
"minServerVersion": "6.3",
|
||||
"auth": true
|
||||
}
|
||||
],
|
||||
"createEntities": [
|
||||
{
|
||||
"client": {
|
||||
"id": "client0",
|
||||
"uriOptions": {
|
||||
"retryReads": true,
|
||||
"retryWrites": true
|
||||
},
|
||||
"observeEvents": [
|
||||
"commandStartedEvent",
|
||||
"commandSucceededEvent",
|
||||
"commandFailedEvent"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"database": {
|
||||
"id": "database0",
|
||||
"client": "client0",
|
||||
"databaseName": "db"
|
||||
}
|
||||
},
|
||||
{
|
||||
"collection": {
|
||||
"id": "collection0",
|
||||
"database": "database0",
|
||||
"collectionName": "collName"
|
||||
}
|
||||
}
|
||||
],
|
||||
"initialData": [
|
||||
{
|
||||
"collectionName": "collName",
|
||||
"databaseName": "db",
|
||||
"documents": []
|
||||
}
|
||||
],
|
||||
"tests": [
|
||||
{
|
||||
"description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=true",
|
||||
"operations": [
|
||||
{
|
||||
"name": "failPoint",
|
||||
"object": "testRunner",
|
||||
"arguments": {
|
||||
"client": "client0",
|
||||
"failPoint": {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {
|
||||
"times": 1
|
||||
},
|
||||
"data": {
|
||||
"failCommands": [
|
||||
"find"
|
||||
],
|
||||
"errorCode": 391
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "find",
|
||||
"arguments": {
|
||||
"filter": {}
|
||||
},
|
||||
"object": "collection0",
|
||||
"expectResult": []
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"find": "collName",
|
||||
"filter": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"commandFailedEvent": {
|
||||
"commandName": "find"
|
||||
}
|
||||
},
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"find": "collName",
|
||||
"filter": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"commandSucceededEvent": {
|
||||
"commandName": "find"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true",
|
||||
"operations": [
|
||||
{
|
||||
"name": "failPoint",
|
||||
"object": "testRunner",
|
||||
"arguments": {
|
||||
"client": "client0",
|
||||
"failPoint": {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {
|
||||
"times": 1
|
||||
},
|
||||
"data": {
|
||||
"failCommands": [
|
||||
"insert"
|
||||
],
|
||||
"errorCode": 391
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "insertOne",
|
||||
"object": "collection0",
|
||||
"arguments": {
|
||||
"document": {
|
||||
"_id": 1,
|
||||
"x": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"insert": "collName",
|
||||
"documents": [
|
||||
{
|
||||
"_id": 1,
|
||||
"x": 1
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"commandFailedEvent": {
|
||||
"commandName": "insert"
|
||||
}
|
||||
},
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"insert": "collName",
|
||||
"documents": [
|
||||
{
|
||||
"_id": 1,
|
||||
"x": 1
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"commandSucceededEvent": {
|
||||
"commandName": "insert"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
191
test/auth/unified/reauthenticate_without_retry.json
Normal file
191
test/auth/unified/reauthenticate_without_retry.json
Normal file
@ -0,0 +1,191 @@
|
||||
{
|
||||
"description": "reauthenticate_without_retry",
|
||||
"schemaVersion": "1.12",
|
||||
"runOnRequirements": [
|
||||
{
|
||||
"minServerVersion": "6.3",
|
||||
"auth": true
|
||||
}
|
||||
],
|
||||
"createEntities": [
|
||||
{
|
||||
"client": {
|
||||
"id": "client0",
|
||||
"uriOptions": {
|
||||
"retryReads": false,
|
||||
"retryWrites": false
|
||||
},
|
||||
"observeEvents": [
|
||||
"commandStartedEvent",
|
||||
"commandSucceededEvent",
|
||||
"commandFailedEvent"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"database": {
|
||||
"id": "database0",
|
||||
"client": "client0",
|
||||
"databaseName": "db"
|
||||
}
|
||||
},
|
||||
{
|
||||
"collection": {
|
||||
"id": "collection0",
|
||||
"database": "database0",
|
||||
"collectionName": "collName"
|
||||
}
|
||||
}
|
||||
],
|
||||
"initialData": [
|
||||
{
|
||||
"collectionName": "collName",
|
||||
"databaseName": "db",
|
||||
"documents": []
|
||||
}
|
||||
],
|
||||
"tests": [
|
||||
{
|
||||
"description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=false",
|
||||
"operations": [
|
||||
{
|
||||
"name": "failPoint",
|
||||
"object": "testRunner",
|
||||
"arguments": {
|
||||
"client": "client0",
|
||||
"failPoint": {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {
|
||||
"times": 1
|
||||
},
|
||||
"data": {
|
||||
"failCommands": [
|
||||
"find"
|
||||
],
|
||||
"errorCode": 391
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "find",
|
||||
"arguments": {
|
||||
"filter": {}
|
||||
},
|
||||
"object": "collection0",
|
||||
"expectResult": []
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"find": "collName",
|
||||
"filter": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"commandFailedEvent": {
|
||||
"commandName": "find"
|
||||
}
|
||||
},
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"find": "collName",
|
||||
"filter": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"commandSucceededEvent": {
|
||||
"commandName": "find"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=false",
|
||||
"operations": [
|
||||
{
|
||||
"name": "failPoint",
|
||||
"object": "testRunner",
|
||||
"arguments": {
|
||||
"client": "client0",
|
||||
"failPoint": {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {
|
||||
"times": 1
|
||||
},
|
||||
"data": {
|
||||
"failCommands": [
|
||||
"insert"
|
||||
],
|
||||
"errorCode": 391
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "insertOne",
|
||||
"object": "collection0",
|
||||
"arguments": {
|
||||
"document": {
|
||||
"_id": 1,
|
||||
"x": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"insert": "collName",
|
||||
"documents": [
|
||||
{
|
||||
"_id": 1,
|
||||
"x": 1
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"commandFailedEvent": {
|
||||
"commandName": "insert"
|
||||
}
|
||||
},
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"insert": "collName",
|
||||
"documents": [
|
||||
{
|
||||
"_id": 1,
|
||||
"x": 1
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"commandSucceededEvent": {
|
||||
"commandName": "insert"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
821
test/auth_aws/test_auth_oidc.py
Normal file
821
test/auth_aws/test_auth_oidc.py
Normal file
@ -0,0 +1,821 @@
|
||||
# Copyright 2023-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test MONGODB-OIDC Authentication."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.utils import EventListener
|
||||
|
||||
from bson import SON
|
||||
from pymongo import MongoClient
|
||||
from pymongo.auth_oidc import _CACHE as _oidc_cache
|
||||
from pymongo.cursor import CursorType
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.operations import InsertOne
|
||||
|
||||
|
||||
class TestAuthOIDC(unittest.TestCase):
|
||||
uri: str
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.uri_single = os.environ["MONGODB_URI_SINGLE"]
|
||||
cls.uri_multiple = os.environ["MONGODB_URI_MULTIPLE"]
|
||||
cls.uri_admin = os.environ["MONGODB_URI"]
|
||||
cls.token_dir = os.environ["OIDC_TOKEN_DIR"]
|
||||
|
||||
def setUp(self):
|
||||
self.request_called = 0
|
||||
self.refresh_called = 0
|
||||
_oidc_cache.clear()
|
||||
os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1")
|
||||
|
||||
def create_request_cb(self, username="test_user1", expires_in_seconds=None, sleep=0):
|
||||
|
||||
token_file = os.path.join(self.token_dir, username)
|
||||
|
||||
def request_token(server_info, context):
|
||||
# Validate the info.
|
||||
self.assertIn("issuer", server_info)
|
||||
self.assertIn("clientId", server_info)
|
||||
|
||||
# Validate the timeout.
|
||||
timeout_seconds = context["timeout_seconds"]
|
||||
self.assertEqual(timeout_seconds, 60 * 5)
|
||||
with open(token_file) as fid:
|
||||
token = fid.read()
|
||||
resp = dict(access_token=token)
|
||||
|
||||
time.sleep(sleep)
|
||||
|
||||
if expires_in_seconds is not None:
|
||||
resp["expires_in_seconds"] = expires_in_seconds
|
||||
self.request_called += 1
|
||||
return resp
|
||||
|
||||
return request_token
|
||||
|
||||
def create_refresh_cb(self, username="test_user1", expires_in_seconds=None):
|
||||
|
||||
token_file = os.path.join(self.token_dir, username)
|
||||
|
||||
def refresh_token(server_info, context):
|
||||
with open(token_file) as fid:
|
||||
token = fid.read()
|
||||
|
||||
# Validate the info.
|
||||
self.assertIn("issuer", server_info)
|
||||
self.assertIn("clientId", server_info)
|
||||
|
||||
# Validate the creds
|
||||
self.assertIsNotNone(context["refresh_token"])
|
||||
|
||||
# Validate the timeout.
|
||||
self.assertEqual(context["timeout_seconds"], 60 * 5)
|
||||
|
||||
resp = dict(access_token=token)
|
||||
if expires_in_seconds is not None:
|
||||
resp["expires_in_seconds"] = expires_in_seconds
|
||||
self.refresh_called += 1
|
||||
return resp
|
||||
|
||||
return refresh_token
|
||||
|
||||
@contextmanager
|
||||
def fail_point(self, command_args):
|
||||
cmd_on = SON([("configureFailPoint", "failCommand")])
|
||||
cmd_on.update(command_args)
|
||||
client = MongoClient(self.uri_admin)
|
||||
client.admin.command(cmd_on)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off")
|
||||
|
||||
def test_connect_callbacks_single_implicit_username(self):
|
||||
request_token = self.create_request_cb()
|
||||
props: Dict = dict(request_token_callback=request_token)
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_callbacks_single_explicit_username(self):
|
||||
request_token = self.create_request_cb()
|
||||
props: Dict = dict(request_token_callback=request_token)
|
||||
client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_callbacks_multiple_principal_user1(self):
|
||||
request_token = self.create_request_cb()
|
||||
props: Dict = dict(request_token_callback=request_token)
|
||||
client = MongoClient(
|
||||
self.uri_multiple, username="test_user1", authmechanismproperties=props
|
||||
)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_callbacks_multiple_principal_user2(self):
|
||||
request_token = self.create_request_cb("test_user2")
|
||||
props: Dict = dict(request_token_callback=request_token)
|
||||
client = MongoClient(
|
||||
self.uri_multiple, username="test_user2", authmechanismproperties=props
|
||||
)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_callbacks_multiple_no_username(self):
|
||||
request_token = self.create_request_cb()
|
||||
props: Dict = dict(request_token_callback=request_token)
|
||||
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
|
||||
with self.assertRaises(OperationFailure):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_allowed_hosts_blocked(self):
|
||||
request_token = self.create_request_cb()
|
||||
props: Dict = dict(request_token_callback=request_token, allowed_hosts=[])
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
with self.assertRaises(ConfigurationError):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
props: Dict = dict(request_token_callback=request_token, allowed_hosts=["example.com"])
|
||||
client = MongoClient(
|
||||
self.uri_single + "&ignored=example.com", authmechanismproperties=props, connect=False
|
||||
)
|
||||
with self.assertRaises(ConfigurationError):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_aws_single_principal(self):
|
||||
props = dict(PROVIDER_NAME="aws")
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_aws_multiple_principal_user1(self):
|
||||
props = dict(PROVIDER_NAME="aws")
|
||||
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_aws_multiple_principal_user2(self):
|
||||
os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2")
|
||||
props = dict(PROVIDER_NAME="aws")
|
||||
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_connect_aws_allowed_hosts_ignored(self):
|
||||
props = dict(PROVIDER_NAME="aws", allowed_hosts=[])
|
||||
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_valid_callbacks(self):
|
||||
request_cb = self.create_request_cb(expires_in_seconds=60)
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
props: Dict = dict(
|
||||
request_token_callback=request_cb,
|
||||
refresh_token_callback=refresh_cb,
|
||||
)
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_lock_avoids_extra_callbacks(self):
|
||||
request_cb = self.create_request_cb(sleep=0.5)
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
|
||||
|
||||
def run_test():
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
t1 = threading.Thread(target=run_test)
|
||||
t2 = threading.Thread(target=run_test)
|
||||
t1.start()
|
||||
t2.start()
|
||||
t1.join()
|
||||
t2.join()
|
||||
|
||||
self.assertEqual(self.request_called, 1)
|
||||
self.assertEqual(self.refresh_called, 2)
|
||||
|
||||
def test_request_callback_returns_null(self):
|
||||
def request_token_null(a, b):
|
||||
return None
|
||||
|
||||
props: Dict = dict(request_token_callback=request_token_null)
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
with self.assertRaises(ValueError):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_refresh_callback_returns_null(self):
|
||||
request_cb = self.create_request_cb(expires_in_seconds=60)
|
||||
|
||||
def refresh_token_null(a, b):
|
||||
return None
|
||||
|
||||
props: Dict = dict(
|
||||
request_token_callback=request_cb, refresh_token_callback=refresh_token_null
|
||||
)
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
with self.assertRaises(ValueError):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_request_callback_invalid_result(self):
|
||||
def request_token_invalid(a, b):
|
||||
return dict()
|
||||
|
||||
props: Dict = dict(request_token_callback=request_token_invalid)
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
with self.assertRaises(ValueError):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def request_cb_extra_value(server_info, context):
|
||||
result = self.create_request_cb()(server_info, context)
|
||||
result["foo"] = "bar"
|
||||
return result
|
||||
|
||||
props: Dict = dict(request_token_callback=request_cb_extra_value)
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
with self.assertRaises(ValueError):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_refresh_callback_missing_data(self):
|
||||
request_cb = self.create_request_cb(expires_in_seconds=60)
|
||||
|
||||
def refresh_cb_no_token(a, b):
|
||||
return dict()
|
||||
|
||||
props: Dict = dict(
|
||||
request_token_callback=request_cb, refresh_token_callback=refresh_cb_no_token
|
||||
)
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
with self.assertRaises(ValueError):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_refresh_callback_extra_data(self):
|
||||
request_cb = self.create_request_cb(expires_in_seconds=60)
|
||||
|
||||
def refresh_cb_extra_value(server_info, context):
|
||||
result = self.create_refresh_cb()(server_info, context)
|
||||
result["foo"] = "bar"
|
||||
return result
|
||||
|
||||
props: Dict = dict(
|
||||
request_token_callback=request_cb, refresh_token_callback=refresh_cb_extra_value
|
||||
)
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
with self.assertRaises(ValueError):
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
def test_cache_with_refresh(self):
|
||||
# Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token.
|
||||
|
||||
# Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute.
|
||||
request_cb = self.create_request_cb(expires_in_seconds=60)
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
|
||||
|
||||
# Ensure that a ``find`` operation adds credentials to the cache.
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
self.assertEqual(len(_oidc_cache), 1)
|
||||
|
||||
# Create a new client with the same request callback and a refresh callback.
|
||||
# Ensure that a ``find`` operation results in a call to the refresh callback.
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
self.assertEqual(len(_oidc_cache), 1)
|
||||
|
||||
def test_cache_with_no_refresh(self):
|
||||
# Create a new client with a request callback callback.
|
||||
# Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute.
|
||||
request_cb = self.create_request_cb()
|
||||
|
||||
props = dict(request_token_callback=request_cb)
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
|
||||
# Ensure that a ``find`` operation adds credentials to the cache.
|
||||
self.request_called = 0
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
self.assertEqual(self.request_called, 1)
|
||||
self.assertEqual(len(_oidc_cache), 1)
|
||||
|
||||
# Create a new client with the same request callback.
|
||||
# Ensure that a ``find`` operation results in a call to the request callback.
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
self.assertEqual(self.request_called, 2)
|
||||
self.assertEqual(len(_oidc_cache), 1)
|
||||
|
||||
def test_cache_key_includes_callback(self):
|
||||
request_cb = self.create_request_cb()
|
||||
|
||||
props: Dict = dict(request_token_callback=request_cb)
|
||||
|
||||
# Ensure that a ``find`` operation adds a new entry to the cache.
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
# Create a new client with a different request callback.
|
||||
def request_token_2(a, b):
|
||||
return request_cb(a, b)
|
||||
|
||||
props["request_token_callback"] = request_token_2
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
|
||||
# Ensure that a ``find`` operation adds a new entry to the cache.
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
self.assertEqual(len(_oidc_cache), 2)
|
||||
|
||||
def test_cache_clears_on_error(self):
|
||||
request_cb = self.create_request_cb()
|
||||
|
||||
# Create a new client with a valid request callback that gives credentials that expire within 5 minutes and a refresh callback that gives invalid credentials.
|
||||
def refresh_cb(a, b):
|
||||
return dict(access_token="bad")
|
||||
|
||||
# Add a token to the cache that will expire soon.
|
||||
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
# Create a new client with the same callbacks.
|
||||
client = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
|
||||
# Ensure that another ``find`` operation results in an error.
|
||||
with self.assertRaises(OperationFailure):
|
||||
client.test.test.find_one()
|
||||
|
||||
client.close()
|
||||
|
||||
# Ensure that the cache has been cleared.
|
||||
authenticator = list(_oidc_cache.values())[0]
|
||||
self.assertIsNone(authenticator.idp_info)
|
||||
|
||||
def test_cache_is_not_used_in_aws_automatic_workflow(self):
|
||||
# Create a new client using the AWS device workflow.
|
||||
# Ensure that a ``find`` operation does not add credentials to the cache.
|
||||
props = dict(PROVIDER_NAME="aws")
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
client.test.test.find_one()
|
||||
client.close()
|
||||
|
||||
# Ensure that the cache has been cleared.
|
||||
authenticator = list(_oidc_cache.values())[0]
|
||||
self.assertIsNone(authenticator.idp_info)
|
||||
|
||||
def test_speculative_auth_success(self):
|
||||
# Clear the cache
|
||||
_oidc_cache.clear()
|
||||
token_file = os.path.join(self.token_dir, "test_user1")
|
||||
|
||||
def request_token(a, b):
|
||||
with open(token_file) as fid:
|
||||
token = fid.read()
|
||||
return dict(access_token=token, expires_in_seconds=1000)
|
||||
|
||||
# Create a client with a request callback that returns a valid token
|
||||
# that will not expire soon.
|
||||
props: Dict = dict(request_token_callback=request_token)
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Set a fail point for saslStart commands.
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 2},
|
||||
"data": {"failCommands": ["saslStart"], "errorCode": 18},
|
||||
}
|
||||
):
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Close the client.
|
||||
client.close()
|
||||
|
||||
# Create a new client.
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Set a fail point for saslStart commands.
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 2},
|
||||
"data": {"failCommands": ["saslStart"], "errorCode": 18},
|
||||
}
|
||||
):
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Close the client.
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds(self):
|
||||
listener = EventListener()
|
||||
|
||||
# Create request and refresh callbacks that return valid credentials
|
||||
# that will not expire soon.
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
|
||||
client = MongoClient(
|
||||
self.uri_single, event_listeners=[listener], authmechanismproperties=props
|
||||
)
|
||||
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
|
||||
listener.reset()
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["find"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
started_events = [
|
||||
i.command_name for i in listener.started_events if not i.command_name.startswith("sasl")
|
||||
]
|
||||
succeeded_events = [
|
||||
i.command_name
|
||||
for i in listener.succeeded_events
|
||||
if not i.command_name.startswith("sasl")
|
||||
]
|
||||
failed_events = [
|
||||
i.command_name for i in listener.failed_events if not i.command_name.startswith("sasl")
|
||||
]
|
||||
|
||||
self.assertEqual(
|
||||
started_events,
|
||||
[
|
||||
"find",
|
||||
"find",
|
||||
],
|
||||
)
|
||||
self.assertEqual(succeeded_events, ["find"])
|
||||
self.assertEqual(failed_events, ["find"])
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_bulk_write(self):
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["insert"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a bulk write operation.
|
||||
client.test.test.bulk_write([InsertOne({})])
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_bulk_read(self):
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Perform a bulk write operation.
|
||||
client.test.test.bulk_write([InsertOne({})])
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["find"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a bulk read operation.
|
||||
cursor = client.test.test.find_raw_batches({})
|
||||
list(cursor)
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_cursor(self):
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform an insert operation.
|
||||
client.test.test.insert_one({"a": 1})
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["find"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation.
|
||||
cursor = client.test.test.find({"a": 1})
|
||||
self.assertGreaterEqual(len(list(cursor)), 1)
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_get_more(self):
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform an insert operation.
|
||||
client.test.test.insert_many([{"a": 1}, {"a": 1}])
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["getMore"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation.
|
||||
cursor = client.test.test.find({"a": 1}, batch_size=1)
|
||||
self.assertGreaterEqual(len(list(cursor)), 1)
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_get_more_exhaust(self):
|
||||
# Ensure no mongos
|
||||
props = dict(PROVIDER_NAME="aws")
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
hello = client.admin.command(HelloCompat.LEGACY_CMD)
|
||||
if hello.get("msg") != "isdbgrid":
|
||||
raise unittest.SkipTest("Must not be a mongos")
|
||||
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform an insert operation.
|
||||
client.test.test.insert_many([{"a": 1}, {"a": 1}])
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["getMore"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation.
|
||||
cursor = client.test.test.find({"a": 1}, batch_size=1, cursor_type=CursorType.EXHAUST)
|
||||
self.assertGreaterEqual(len(list(cursor)), 1)
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_succeeds_command(self):
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
|
||||
|
||||
print("start of test")
|
||||
client = MongoClient(self.uri_single, authmechanismproperties=props)
|
||||
|
||||
# Perform an insert operation.
|
||||
client.test.test.insert_one({"a": 1})
|
||||
|
||||
# Assert that the refresh callback has not been called.
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["count"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a count operation.
|
||||
cursor = client.test.command(dict(count="test"))
|
||||
|
||||
self.assertGreaterEqual(len(list(cursor)), 1)
|
||||
|
||||
# Assert that the refresh callback has been called.
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_retries_and_succeeds_with_cache(self):
|
||||
listener = EventListener()
|
||||
|
||||
# Create request and refresh callbacks that return valid credentials
|
||||
# that will not expire soon.
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
|
||||
client = MongoClient(
|
||||
self.uri_single, event_listeners=[listener], authmechanismproperties=props
|
||||
)
|
||||
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Set a fail point for ``saslStart`` commands of the form
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 2},
|
||||
"data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation that succeeds.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Close the client.
|
||||
client.close()
|
||||
|
||||
def test_reauthenticate_fails_with_no_cache(self):
|
||||
listener = EventListener()
|
||||
|
||||
# Create request and refresh callbacks that return valid credentials
|
||||
# that will not expire soon.
|
||||
request_cb = self.create_request_cb()
|
||||
refresh_cb = self.create_refresh_cb()
|
||||
|
||||
# Create a client with the callbacks.
|
||||
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
|
||||
client = MongoClient(
|
||||
self.uri_single, event_listeners=[listener], authmechanismproperties=props
|
||||
)
|
||||
|
||||
# Perform a find operation.
|
||||
client.test.test.find_one()
|
||||
|
||||
# Clear the cache.
|
||||
_oidc_cache.clear()
|
||||
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 2},
|
||||
"data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation that fails.
|
||||
with self.assertRaises(OperationFailure):
|
||||
client.test.test.find_one()
|
||||
|
||||
client.close()
|
||||
|
||||
def test_late_reauth_avoids_callback(self):
|
||||
# Step 1: connect with both clients
|
||||
request_cb = self.create_request_cb(expires_in_seconds=1e6)
|
||||
refresh_cb = self.create_refresh_cb(expires_in_seconds=1e6)
|
||||
|
||||
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
|
||||
client1 = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client1.test.test.find_one()
|
||||
client2 = MongoClient(self.uri_single, authMechanismProperties=props)
|
||||
client2.test.test.find_one()
|
||||
|
||||
self.assertEqual(self.refresh_called, 0)
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
# Step 2: cause a find 391 on the first client
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["find"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation that succeeds.
|
||||
client1.test.test.find_one()
|
||||
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
# Step 3: cause a find 391 on the second client
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["find"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Perform a find operation that succeeds.
|
||||
client2.test.test.find_one()
|
||||
|
||||
self.assertEqual(self.refresh_called, 1)
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
client1.close()
|
||||
client2.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -22,6 +22,7 @@ import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
from pymongo import MongoClient
|
||||
|
||||
@ -41,7 +42,16 @@ def create_test(test_case):
|
||||
if not valid:
|
||||
self.assertRaises(Exception, MongoClient, uri, connect=False)
|
||||
else:
|
||||
client = MongoClient(uri, connect=False)
|
||||
props = {}
|
||||
if credential:
|
||||
props = credential["mechanism_properties"] or {}
|
||||
if props.get("REQUEST_TOKEN_CALLBACK"):
|
||||
props["request_token_callback"] = lambda x, y: 1
|
||||
del props["REQUEST_TOKEN_CALLBACK"]
|
||||
if props.get("REFRESH_TOKEN_CALLBACK"):
|
||||
props["refresh_token_callback"] = lambda a, b: 1
|
||||
del props["REFRESH_TOKEN_CALLBACK"]
|
||||
client = MongoClient(uri, connect=False, authmechanismproperties=props)
|
||||
credentials = client.options.pool_options._credentials
|
||||
if credential is None:
|
||||
self.assertIsNone(credentials)
|
||||
@ -70,6 +80,16 @@ def create_test(test_case):
|
||||
self.assertEqual(
|
||||
actual.aws_session_token, expected["AWS_SESSION_TOKEN"]
|
||||
)
|
||||
elif "PROVIDER_NAME" in expected:
|
||||
self.assertEqual(actual.provider_name, expected["PROVIDER_NAME"])
|
||||
elif "request_token_callback" in expected:
|
||||
self.assertEqual(
|
||||
actual.request_token_callback, expected["request_token_callback"]
|
||||
)
|
||||
elif "refresh_token_callback" in expected:
|
||||
self.assertEqual(
|
||||
actual.refresh_token_callback, expected["refresh_token_callback"]
|
||||
)
|
||||
else:
|
||||
self.fail("Unhandled property: %s" % (key,))
|
||||
else:
|
||||
@ -82,7 +102,7 @@ def create_test(test_case):
|
||||
|
||||
|
||||
def create_tests():
|
||||
for filename in glob.glob(os.path.join(_TEST_PATH, "*.json")):
|
||||
for filename in glob.glob(os.path.join(_TEST_PATH, "legacy", "*.json")):
|
||||
test_suffix, _ = os.path.splitext(os.path.basename(filename))
|
||||
with open(filename) as auth_tests:
|
||||
test_cases = json.load(auth_tests)["tests"]
|
||||
@ -97,5 +117,12 @@ def create_tests():
|
||||
create_tests()
|
||||
|
||||
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(_TEST_PATH, "unified"),
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user