PYTHON-3460 Implement OIDC SASL mechanism (#1138)

This commit is contained in:
Steven Silvester 2023-05-11 14:35:30 -05:00 committed by GitHub
parent d504322a74
commit afd7e1c2cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1970 additions and 18 deletions

View File

@ -749,6 +749,68 @@ functions:
fi
PYTHON_BINARY=${PYTHON_BINARY} ASSERT_NO_URI_CREDS=true .evergreen/run-mongodb-aws-test.sh
"bootstrap oidc":
- command: ec2.assume_role
params:
role_arn: ${aws_test_secrets_role}
- command: shell.exec
type: test
params:
working_dir: "src"
shell: bash
script: |
${PREPARE_SHELL}
if [ "${skip_EC2_auth_test}" = "true" ]; then
echo "This platform does not support the oidc auth test, skipping..."
exit 0
fi
cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc
export AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
export AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
export AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN}
export OIDC_TOKEN_DIR=/tmp/tokens
. ./activate-authoidcvenv.sh
python oidc_write_orchestration.py
python oidc_get_tokens.py
"run oidc auth test with aws credentials":
- command: shell.exec
type: test
params:
working_dir: "src"
shell: bash
script: |
${PREPARE_SHELL}
if [ "${skip_EC2_auth_test}" = "true" ]; then
echo "This platform does not support the oidc auth test, skipping..."
exit 0
fi
cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc
mongosh setup_oidc.js
- command: shell.exec
type: test
params:
working_dir: "src"
silent: true
script: |
# DO NOT ECHO WITH XTRACE (which PREPARE_SHELL does)
cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh"
export OIDC_TOKEN_DIR=/tmp/tokens
EOF
- command: shell.exec
type: test
params:
working_dir: "src"
script: |
${PREPARE_SHELL}
if [ "${skip_web_identity_auth_test}" = "true" ]; then
echo "This platform does not support the oidc auth test, skipping..."
exit 0
fi
PYTHON_BINARY=${PYTHON_BINARY} ASSERT_NO_URI_CREDS=true .evergreen/run-mongodb-oidc-test.sh
"run aws auth test with aws credentials as environment variables":
- command: shell.exec
type: test
@ -2034,6 +2096,19 @@ tasks:
- func: "run aws auth test with aws web identity credentials"
- func: "run aws ECS auth test"
- name: "oidc-auth-test-latest"
commands:
- func: "bootstrap oidc"
- func: "bootstrap mongo-orchestration"
vars:
AUTH: "auth"
ORCHESTRATION_FILE: "auth-oidc.json"
TOPOLOGY: "replica_set"
VERSION: "latest"
- func: "run oidc auth test with aws credentials"
vars:
AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test1
- name: load-balancer-test
commands:
- func: "bootstrap mongo-orchestration"
@ -3103,6 +3178,14 @@ buildvariants:
# macOS MongoDB servers do not staple OCSP responses and only support RSA.
- name: ".ocsp-rsa !.ocsp-staple"
- matrix_name: "oidc-auth-test"
matrix_spec:
platform: [ ubuntu-20.04 ]
python-version: ["3.9"]
display_name: "MONGODB-OIDC Auth ${platform} ${python-version}"
tasks:
- name: "oidc-auth-test-latest"
- matrix_name: "aws-auth-test"
matrix_spec:
platform: [ubuntu-20.04]

View File

@ -70,6 +70,9 @@ for spec in "$@"
do
# Match the spec dir name, the python test dir name, and/or common abbreviations.
case "$spec" in
auth)
cpjson auth/tests/ auth
;;
atlas-data-lake-testing|data_lake)
cpjson atlas-data-lake-testing/tests/ data_lake
;;

View File

@ -0,0 +1,85 @@
#!/bin/bash
set -o xtrace
set -o errexit # Exit the script with error if any of the commands fail
############################################
# Main Program #
############################################
# Supported/used environment variables:
# MONGODB_URI Set the URI, including an optional username/password to use
# to connect to the server via MONGODB-OIDC authentication
# mechanism.
# PYTHON_BINARY The Python version to use.
echo "Running MONGODB-OIDC authentication tests"
# ensure no secrets are printed in log files
set +x
# load the script
shopt -s expand_aliases # needed for `urlencode` alias
[ -s "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh"
MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"}
MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC"
MONGODB_URI_MULTIPLE="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true"
if [ -z "${OIDC_TOKEN_DIR}" ]; then
echo "Must specify OIDC_TOKEN_DIR"
exit 1
fi
export MONGODB_URI_SINGLE="$MONGODB_URI_SINGLE"
export MONGODB_URI_MULTIPLE="$MONGODB_URI_MULTIPLE"
export MONGODB_URI="$MONGODB_URI"
echo $MONGODB_URI_SINGLE
echo $MONGODB_URI_MULTIPLE
echo $MONGODB_URI
if [ "$ASSERT_NO_URI_CREDS" = "true" ]; then
if echo "$MONGODB_URI" | grep -q "@"; then
echo "MONGODB_URI unexpectedly contains user credentials!";
exit 1
fi
fi
# show test output
set -x
# Workaround macOS python 3.9 incompatibility with system virtualenv.
if [ "$(uname -s)" = "Darwin" ]; then
VIRTUALENV="/Library/Frameworks/Python.framework/Versions/3.9/bin/python3 -m virtualenv"
else
VIRTUALENV=$(command -v virtualenv)
fi
authtest () {
if [ "Windows_NT" = "$OS" ]; then
PYTHON=$(cygpath -m $PYTHON)
fi
echo "Running MONGODB-OIDC authentication tests with $PYTHON"
$PYTHON --version
$VIRTUALENV -p $PYTHON --never-download venvoidc
if [ "Windows_NT" = "$OS" ]; then
. venvoidc/Scripts/activate
else
. venvoidc/bin/activate
fi
python -m pip install -U pip setuptools
python -m pip install '.[aws]'
python test/auth_aws/test_auth_oidc.py -v
deactivate
rm -rf venvoidc
}
PYTHON=${PYTHON_BINARY:-}
if [ -z "$PYTHON" ]; then
echo "Cannot test without specifying PYTHON_BINARY"
exit 1
fi
authtest

View File

@ -27,6 +27,7 @@ from urllib.parse import quote
from bson.binary import Binary
from bson.son import SON
from pymongo.auth_aws import _authenticate_aws
from pymongo.auth_oidc import _authenticate_oidc, _get_authenticator, _OIDCProperties
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.saslprep import saslprep
@ -48,6 +49,7 @@ MECHANISMS = frozenset(
[
"GSSAPI",
"MONGODB-CR",
"MONGODB-OIDC",
"MONGODB-X509",
"MONGODB-AWS",
"PLAIN",
@ -101,7 +103,7 @@ _AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
def _build_credentials_tuple(mech, source, user, passwd, extra, database):
"""Build and return a mechanism specific credentials tuple."""
if mech not in ("MONGODB-X509", "MONGODB-AWS") and user is None:
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
raise ConfigurationError("%s requires a username." % (mech,))
if mech == "GSSAPI":
if source is not None and source != "$external":
@ -137,6 +139,32 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database):
aws_props = _AWSProperties(aws_session_token=aws_session_token)
# user can be None for temporary link-local EC2 credentials.
return MongoCredential(mech, "$external", user, passwd, aws_props, None)
elif mech == "MONGODB-OIDC":
properties = extra.get("authmechanismproperties", {})
request_token_callback = properties.get("request_token_callback")
refresh_token_callback = properties.get("refresh_token_callback", None)
provider_name = properties.get("PROVIDER_NAME", "")
default_allowed = [
"*.mongodb.net",
"*.mongodb-dev.net",
"*.mongodbgov.net",
"localhost",
"127.0.0.1",
"::1",
]
allowed_hosts = properties.get("allowed_hosts", default_allowed)
if not request_token_callback and provider_name != "aws":
raise ConfigurationError(
"authentication with MONGODB-OIDC requires providing an request_token_callback or a provider_name of 'aws'"
)
oidc_props = _OIDCProperties(
request_token_callback=request_token_callback,
refresh_token_callback=refresh_token_callback,
provider_name=provider_name,
allowed_hosts=allowed_hosts,
)
return MongoCredential(mech, "$external", user, passwd, oidc_props, None)
elif mech == "PLAIN":
source_database = source or database or "$external"
return MongoCredential(mech, source_database, user, passwd, None, None)
@ -439,7 +467,7 @@ def _authenticate_x509(credentials, sock_info):
# MONGODB-X509 is done after the speculative auth step.
return
cmd = _X509Context(credentials).speculate_command()
cmd = _X509Context(credentials, sock_info.address).speculate_command()
sock_info.command("$external", cmd)
@ -482,6 +510,7 @@ _AUTH_MAP: Mapping[str, Callable] = {
"MONGODB-CR": _authenticate_mongo_cr,
"MONGODB-X509": _authenticate_x509,
"MONGODB-AWS": _authenticate_aws,
"MONGODB-OIDC": _authenticate_oidc,
"PLAIN": _authenticate_plain,
"SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"),
"SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"),
@ -490,15 +519,16 @@ _AUTH_MAP: Mapping[str, Callable] = {
class _AuthContext(object):
def __init__(self, credentials):
def __init__(self, credentials, address):
self.credentials = credentials
self.speculative_authenticate = None
self.address = address
@staticmethod
def from_credentials(creds):
def from_credentials(creds, address):
spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
if spec_cls:
return spec_cls(creds)
return spec_cls(creds, address)
return None
def speculate_command(self):
@ -512,8 +542,8 @@ class _AuthContext(object):
class _ScramContext(_AuthContext):
def __init__(self, credentials, mechanism):
super(_ScramContext, self).__init__(credentials)
def __init__(self, credentials, address, mechanism):
super(_ScramContext, self).__init__(credentials, address)
self.scram_data = None
self.mechanism = mechanism
@ -534,16 +564,30 @@ class _X509Context(_AuthContext):
return cmd
class _OIDCContext(_AuthContext):
def speculate_command(self):
authenticator = _get_authenticator(self.credentials, self.address)
cmd = authenticator.auth_start_cmd(False)
if cmd is None:
return
cmd["db"] = self.credentials.source
return cmd
_SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = {
"MONGODB-X509": _X509Context,
"SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"),
"SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
"MONGODB-OIDC": _OIDCContext,
"DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
}
def authenticate(credentials, sock_info):
def authenticate(credentials, sock_info, reauthenticate=False):
"""Authenticate sock_info."""
mechanism = credentials.mechanism
auth_func = _AUTH_MAP[mechanism]
auth_func(credentials, sock_info)
if mechanism == "MONGODB-OIDC":
_authenticate_oidc(credentials, sock_info, reauthenticate)
else:
auth_func(credentials, sock_info)

299
pymongo/auth_oidc.py Normal file
View File

@ -0,0 +1,299 @@
# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MONGODB-OIDC Authentication helpers."""
import os
import threading
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Callable, Dict, List, Optional
import bson
from bson.binary import Binary
from bson.son import SON
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE
@dataclass
class _OIDCProperties:
request_token_callback: Optional[Callable[..., Dict]]
refresh_token_callback: Optional[Callable[..., Dict]]
provider_name: Optional[str]
allowed_hosts: List[str]
"""Mechanism properties for MONGODB-OIDC authentication."""
TOKEN_BUFFER_MINUTES = 5
CALLBACK_TIMEOUT_SECONDS = 5 * 60
CACHE_TIMEOUT_MINUTES = 60 * 5
CALLBACK_VERSION = 0
_CACHE: Dict[str, "_OIDCAuthenticator"] = {}
def _get_authenticator(credentials, address):
# Clear out old items in the cache.
now_utc = datetime.now(timezone.utc)
to_remove = []
for key, value in _CACHE.items():
if value.cache_exp_utc is not None and value.cache_exp_utc < now_utc:
to_remove.append(key)
for key in to_remove:
del _CACHE[key]
# Extract values.
principal_name = credentials.username
properties = credentials.mechanism_properties
request_cb = properties.request_token_callback
refresh_cb = properties.refresh_token_callback
# Validate that the address is allowed.
if not properties.provider_name:
found = False
allowed_hosts = properties.allowed_hosts
for patt in allowed_hosts:
if patt == address[0]:
found = True
elif patt.startswith("*.") and address[0].endswith(patt[1:]):
found = True
if not found:
raise ConfigurationError(
f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
)
# Get or create the cache item.
cache_key = f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}"
_CACHE.setdefault(cache_key, _OIDCAuthenticator(username=principal_name, properties=properties))
return _CACHE[cache_key]
def _get_cache_exp():
return datetime.now(timezone.utc) + timedelta(minutes=CACHE_TIMEOUT_MINUTES)
@dataclass
class _OIDCAuthenticator:
username: str
properties: _OIDCProperties
idp_info: Optional[Dict] = field(default=None)
idp_resp: Optional[Dict] = field(default=None)
reauth_gen_id: int = field(default=0)
idp_info_gen_id: int = field(default=0)
token_gen_id: int = field(default=0)
token_exp_utc: Optional[datetime] = field(default=None)
cache_exp_utc: datetime = field(default_factory=_get_cache_exp)
lock: threading.Lock = field(default_factory=threading.Lock)
def get_current_token(self, use_callbacks=True):
properties = self.properties
request_cb = properties.request_token_callback
refresh_cb = properties.refresh_token_callback
if not use_callbacks:
request_cb = None
refresh_cb = None
current_valid_token = False
if self.token_exp_utc is not None:
now_utc = datetime.now(timezone.utc)
exp_utc = self.token_exp_utc
buffer_seconds = TOKEN_BUFFER_MINUTES * 60
if (exp_utc - now_utc).total_seconds() >= buffer_seconds:
current_valid_token = True
timeout = CALLBACK_TIMEOUT_SECONDS
if not use_callbacks and not current_valid_token:
return None
if not current_valid_token and request_cb is not None:
prev_token = self.idp_resp and self.idp_resp["access_token"]
with self.lock:
# See if the token was changed while we were waiting for the
# lock.
new_token = self.idp_resp and self.idp_resp["access_token"]
if new_token != prev_token:
return new_token
refresh_token = self.idp_resp and self.idp_resp.get("refresh_token")
refresh_token = refresh_token or ""
context = dict(
timeout_seconds=timeout,
version=CALLBACK_VERSION,
refresh_token=refresh_token,
)
if self.idp_resp is None or refresh_cb is None:
self.idp_resp = request_cb(self.idp_info, context)
elif request_cb is not None:
self.idp_resp = refresh_cb(self.idp_info, context)
cache_exp_utc = datetime.now(timezone.utc) + timedelta(
minutes=CACHE_TIMEOUT_MINUTES
)
self.cache_exp_utc = cache_exp_utc
self.token_gen_id += 1
token_result = self.idp_resp
# Validate callback return value.
if not isinstance(token_result, dict):
raise ValueError("OIDC callback returned invalid result")
if "access_token" not in token_result:
raise ValueError("OIDC callback did not return an access_token")
expected = ["access_token", "expires_in_seconds", "refesh_token"]
for key in token_result:
if key not in expected:
raise ValueError(f'Unexpected field in callback result "{key}"')
token = token_result["access_token"]
if "expires_in_seconds" in token_result:
expires_in = int(token_result["expires_in_seconds"])
buffer_seconds = TOKEN_BUFFER_MINUTES * 60
if expires_in >= buffer_seconds:
now_utc = datetime.now(timezone.utc)
exp_utc = now_utc + timedelta(seconds=expires_in)
self.token_exp_utc = exp_utc
return token
def auth_start_cmd(self, use_callbacks=True):
properties = self.properties
# Handle aws provider credentials.
if properties.provider_name == "aws":
aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"]
with open(aws_identity_file) as fid:
token = fid.read().strip()
payload = dict(jwt=token)
cmd = SON(
[
("saslStart", 1),
("mechanism", "MONGODB-OIDC"),
("payload", Binary(bson.encode(payload))),
]
)
return cmd
principal_name = self.username
if self.idp_info is not None:
self.cache_exp_utc = datetime.now(timezone.utc) + timedelta(
minutes=CACHE_TIMEOUT_MINUTES
)
if self.idp_info is None:
self.cache_exp_utc = _get_cache_exp()
if self.idp_info is None:
# Send the SASL start with the optional principal name.
payload = dict()
if principal_name:
payload["n"] = principal_name
cmd = SON(
[
("saslStart", 1),
("mechanism", "MONGODB-OIDC"),
("payload", Binary(bson.encode(payload))),
("autoAuthorize", 1),
]
)
return cmd
token = self.get_current_token(use_callbacks)
if not token:
return None
bin_payload = Binary(bson.encode(dict(jwt=token)))
return SON(
[
("saslStart", 1),
("mechanism", "MONGODB-OIDC"),
("payload", bin_payload),
]
)
def clear(self):
self.idp_info = None
self.idp_resp = None
self.token_exp_utc = None
def run_command(self, sock_info, cmd):
try:
return sock_info.command("$external", cmd, no_reauth=True)
except OperationFailure as exc:
self.clear()
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
if "jwt" in bson.decode(cmd["payload"]): # type:ignore[attr-defined]
if self.idp_info_gen_id > self.reauth_gen_id:
raise
return self.authenticate(sock_info, reauthenticate=True)
raise
def authenticate(self, sock_info, reauthenticate=False):
if reauthenticate:
prev_id = getattr(sock_info, "oidc_token_gen_id", None)
# Check if we've already changed tokens.
if prev_id == self.token_gen_id:
self.reauth_gen_id = self.idp_info_gen_id
self.token_exp_utc = None
if not self.properties.refresh_token_callback:
self.clear()
ctx = sock_info.auth_ctx
cmd = None
if ctx and ctx.speculate_succeeded():
resp = ctx.speculative_authenticate
else:
cmd = self.auth_start_cmd()
resp = self.run_command(sock_info, cmd)
if resp["done"]:
sock_info.oidc_token_gen_id = self.token_gen_id
return
server_resp: Dict = bson.decode(resp["payload"])
if "issuer" in server_resp:
self.idp_info = server_resp
self.idp_info_gen_id += 1
conversation_id = resp["conversationId"]
token = self.get_current_token()
sock_info.oidc_token_gen_id = self.token_gen_id
bin_payload = Binary(bson.encode(dict(jwt=token)))
cmd = SON(
[
("saslContinue", 1),
("conversationId", conversation_id),
("payload", bin_payload),
]
)
resp = self.run_command(sock_info, cmd)
if not resp["done"]:
self.clear()
raise OperationFailure("SASL conversation failed to complete.")
return resp
def _authenticate_oidc(credentials, sock_info, reauthenticate):
"""Authenticate using MONGODB-OIDC."""
authenticator = _get_authenticator(credentials, sock_info.address)
return authenticator.authenticate(sock_info, reauthenticate=reauthenticate)

View File

@ -16,6 +16,7 @@
"""Functions and classes common to multiple pymongo modules."""
import datetime
import inspect
import warnings
from collections import OrderedDict, abc
from typing import (
@ -416,14 +417,48 @@ def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]]
_MECHANISM_PROPS = frozenset(
["SERVICE_NAME", "CANONICALIZE_HOST_NAME", "SERVICE_REALM", "AWS_SESSION_TOKEN"]
[
"SERVICE_NAME",
"CANONICALIZE_HOST_NAME",
"SERVICE_REALM",
"AWS_SESSION_TOKEN",
"PROVIDER_NAME",
]
)
def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Union[bool, str]]:
"""Validate authMechanismProperties."""
value = validate_string(option, value)
props: Dict[str, Any] = {}
if not isinstance(value, str):
if not isinstance(value, dict):
raise ValueError("Auth mechanism properties must be given as a string or a dictionary")
for key, value in value.items():
if isinstance(value, str):
props[key] = value
elif isinstance(value, bool):
props[key] = str(value).lower()
elif key in ["allowed_hosts"] and isinstance(value, list):
props[key] = value
elif inspect.isfunction(value):
signature = inspect.signature(value)
if key == "request_token_callback":
expected_params = 2
elif key == "refresh_token_callback":
expected_params = 2
else:
raise ValueError(f"Unrecognized Auth mechanism function {key}")
if len(signature.parameters) != expected_params:
msg = f"{key} must accept {expected_params} parameters"
raise ValueError(msg)
props[key] = value
else:
raise ValueError(
"Auth mechanism property values must be strings or callback functions"
)
return props
value = validate_string(option, value)
for opt in value.split(","):
try:
key, val = opt.split(":")
@ -715,6 +750,7 @@ KW_VALIDATORS: Dict[str, Callable[[Any, Any], Any]] = {
"password": validate_string_or_none,
"server_selector": validate_is_callable_or_none,
"auto_encryption_opts": validate_auto_encryption_opts_or_none,
"authoidcallowedhosts": validate_list,
}
# Dictionary where keys are any URI option name, and values are the

View File

@ -68,6 +68,9 @@ _RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset(
]
)
# Server code raised when re-authentication is required
_REAUTHENTICATION_REQUIRED_CODE = 391
def _gen_index_name(keys):
"""Generate an index name from the set of fields it is over."""
@ -267,3 +270,35 @@ def _handle_exception():
pass
finally:
del einfo
def _handle_reauth(func):
def inner(*args, **kwargs):
no_reauth = kwargs.pop("no_reauth", False)
from pymongo.pool import SocketInfo
try:
return func(*args, **kwargs)
except OperationFailure as exc:
if no_reauth:
raise
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
# Look for an argument that either is a SocketInfo
# or has a socket_info attribute, so we can trigger
# a reauth.
sock_info = None
for arg in args:
if isinstance(arg, SocketInfo):
sock_info = arg
break
if hasattr(arg, "sock_info"):
sock_info = arg.sock_info
break
if sock_info:
sock_info.authenticate(reauthenticate=True)
else:
raise
return func(*args, **kwargs)
raise
return inner

View File

@ -54,6 +54,7 @@ from pymongo.errors import (
ProtocolError,
)
from pymongo.hello import HelloCompat
from pymongo.helpers import _handle_reauth
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
@ -909,6 +910,7 @@ class _BulkWriteContext(object):
self.start_time = datetime.datetime.now()
return result
@_handle_reauth
def write_command(self, cmd, request_id, msg, docs):
"""A proxy for SocketInfo.write_command that handles event publishing."""
if self.publish:

View File

@ -57,6 +57,7 @@ from pymongo.errors import (
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.helpers import _handle_reauth
from pymongo.lock import _create_lock
from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
from pymongo.network import command, receive_message
@ -756,7 +757,7 @@ class SocketInfo(object):
if creds:
if creds.mechanism == "DEFAULT" and creds.username:
cmd["saslSupportedMechs"] = creds.source + "." + creds.username
auth_ctx = auth._AuthContext.from_credentials(creds)
auth_ctx = auth._AuthContext.from_credentials(creds, self.address)
if auth_ctx:
cmd["speculativeAuthenticate"] = auth_ctx.speculate_command()
else:
@ -813,6 +814,7 @@ class SocketInfo(object):
helpers._check_command_response(response_doc, self.max_wire_version)
return response_doc
@_handle_reauth
def command(
self,
dbname,
@ -966,17 +968,22 @@ class SocketInfo(object):
helpers._check_command_response(result, self.max_wire_version)
return result
def authenticate(self):
def authenticate(self, reauthenticate=False):
"""Authenticate to the server if needed.
Can raise ConnectionFailure or OperationFailure.
"""
# CMAP spec says to publish the ready event only after authenticating
# the connection.
if reauthenticate:
if self.performed_handshake:
# Existing auth_ctx is stale, remove it.
self.auth_ctx = None
self.ready = False
if not self.ready:
creds = self.opts._credentials
if creds:
auth.authenticate(creds, self)
auth.authenticate(creds, self, reauthenticate=reauthenticate)
self.ready = True
if self.enabled_for_cmap:
self.listeners.publish_connection_ready(self.address, self.id)

View File

@ -18,7 +18,7 @@ from datetime import datetime
from bson import _decode_all_selective
from pymongo.errors import NotPrimaryError, OperationFailure
from pymongo.helpers import _check_command_response
from pymongo.helpers import _check_command_response, _handle_reauth
from pymongo.message import _convert_exception, _OpMsg
from pymongo.response import PinnedResponse, Response
@ -73,6 +73,7 @@ class Server(object):
"""Check the server's state soon."""
self._monitor.request_check()
@_handle_reauth
def run_operation(self, sock_info, operation, read_preference, listeners, unpack_res):
"""Run a _Query or _GetMore operation and return a Response object.

View File

@ -444,6 +444,133 @@
"AWS_SESSION_TOKEN": "token!@#$%^&*()_+"
}
}
},
{
"description": "should recognise the mechanism and request callback (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
"callback": ["oidcRequest"],
"valid": true,
"credential": {
"username": null,
"password": null,
"source": "$external",
"mechanism": "MONGODB-OIDC",
"mechanism_properties": {
"REQUEST_TOKEN_CALLBACK": true
}
}
},
{
"description": "should recognise the mechanism when auth source is explicitly specified and with request callback (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external",
"callback": ["oidcRequest"],
"valid": true,
"credential": {
"username": null,
"password": null,
"source": "$external",
"mechanism": "MONGODB-OIDC",
"mechanism_properties": {
"REQUEST_TOKEN_CALLBACK": true
}
}
},
{
"description": "should recognise the mechanism with request and refresh callback (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
"callback": ["oidcRequest", "oidcRefresh"],
"valid": true,
"credential": {
"username": null,
"password": null,
"source": "$external",
"mechanism": "MONGODB-OIDC",
"mechanism_properties": {
"REQUEST_TOKEN_CALLBACK": true,
"REFRESH_TOKEN_CALLBACK": true
}
}
},
{
"description": "should recognise the mechanism and username with request callback (MONGODB-OIDC)",
"uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC",
"callback": ["oidcRequest"],
"valid": true,
"credential": {
"username": "principalName",
"password": null,
"source": "$external",
"mechanism": "MONGODB-OIDC",
"mechanism_properties": {
"REQUEST_TOKEN_CALLBACK": true
}
}
},
{
"description": "should recognise the mechanism with aws device (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws",
"valid": true,
"credential": {
"username": null,
"password": null,
"source": "$external",
"mechanism": "MONGODB-OIDC",
"mechanism_properties": {
"PROVIDER_NAME": "aws"
}
}
},
{
"description": "should recognise the mechanism when auth source is explicitly specified and with aws device (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=PROVIDER_NAME:aws",
"valid": true,
"credential": {
"username": null,
"password": null,
"source": "$external",
"mechanism": "MONGODB-OIDC",
"mechanism_properties": {
"PROVIDER_NAME": "aws"
}
}
},
{
"description": "should throw an exception if username and password are specified (MONGODB-OIDC)",
"uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC",
"callback": ["oidcRequest"],
"valid": false,
"credential": null
},
{
"description": "should throw an exception if username and deviceName are specified (MONGODB-OIDC)",
"uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&PROVIDER_NAME:gcp",
"valid": false,
"credential": null
},
{
"description": "should throw an exception if specified deviceName is not supported (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:unexisted",
"valid": false,
"credential": null
},
{
"description": "should throw an exception if neither deviceName nor callbacks specified (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
"valid": false,
"credential": null
},
{
"description": "should throw an exception when only refresh callback is specified (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
"callback": ["oidcRefresh"],
"valid": false,
"credential": null
},
{
"description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted",
"valid": false,
"credential": null
}
]
}
}

View File

@ -0,0 +1,191 @@
{
"description": "reauthenticate_with_retry",
"schemaVersion": "1.12",
"runOnRequirements": [
{
"minServerVersion": "6.3",
"auth": true
}
],
"createEntities": [
{
"client": {
"id": "client0",
"uriOptions": {
"retryReads": true,
"retryWrites": true
},
"observeEvents": [
"commandStartedEvent",
"commandSucceededEvent",
"commandFailedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "db"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "collName"
}
}
],
"initialData": [
{
"collectionName": "collName",
"databaseName": "db",
"documents": []
}
],
"tests": [
{
"description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=true",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"find"
],
"errorCode": 391
}
}
}
},
{
"name": "find",
"arguments": {
"filter": {}
},
"object": "collection0",
"expectResult": []
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"find": "collName",
"filter": {}
}
}
},
{
"commandFailedEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"command": {
"find": "collName",
"filter": {}
}
}
},
{
"commandSucceededEvent": {
"commandName": "find"
}
}
]
}
]
},
{
"description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"insert"
],
"errorCode": 391
}
}
}
},
{
"name": "insertOne",
"object": "collection0",
"arguments": {
"document": {
"_id": 1,
"x": 1
}
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "collName",
"documents": [
{
"_id": 1,
"x": 1
}
]
}
}
},
{
"commandFailedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"command": {
"insert": "collName",
"documents": [
{
"_id": 1,
"x": 1
}
]
}
}
},
{
"commandSucceededEvent": {
"commandName": "insert"
}
}
]
}
]
}
]
}

View File

@ -0,0 +1,191 @@
{
"description": "reauthenticate_without_retry",
"schemaVersion": "1.12",
"runOnRequirements": [
{
"minServerVersion": "6.3",
"auth": true
}
],
"createEntities": [
{
"client": {
"id": "client0",
"uriOptions": {
"retryReads": false,
"retryWrites": false
},
"observeEvents": [
"commandStartedEvent",
"commandSucceededEvent",
"commandFailedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "db"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "collName"
}
}
],
"initialData": [
{
"collectionName": "collName",
"databaseName": "db",
"documents": []
}
],
"tests": [
{
"description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=false",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"find"
],
"errorCode": 391
}
}
}
},
{
"name": "find",
"arguments": {
"filter": {}
},
"object": "collection0",
"expectResult": []
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"find": "collName",
"filter": {}
}
}
},
{
"commandFailedEvent": {
"commandName": "find"
}
},
{
"commandStartedEvent": {
"command": {
"find": "collName",
"filter": {}
}
}
},
{
"commandSucceededEvent": {
"commandName": "find"
}
}
]
}
]
},
{
"description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=false",
"operations": [
{
"name": "failPoint",
"object": "testRunner",
"arguments": {
"client": "client0",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"insert"
],
"errorCode": 391
}
}
}
},
{
"name": "insertOne",
"object": "collection0",
"arguments": {
"document": {
"_id": 1,
"x": 1
}
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "collName",
"documents": [
{
"_id": 1,
"x": 1
}
]
}
}
},
{
"commandFailedEvent": {
"commandName": "insert"
}
},
{
"commandStartedEvent": {
"command": {
"insert": "collName",
"documents": [
{
"_id": 1,
"x": 1
}
]
}
}
},
{
"commandSucceededEvent": {
"commandName": "insert"
}
}
]
}
]
}
]
}

View File

@ -0,0 +1,821 @@
# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test MONGODB-OIDC Authentication."""
import os
import sys
import threading
import time
import unittest
from contextlib import contextmanager
from typing import Dict
sys.path[0:0] = [""]
from test.utils import EventListener
from bson import SON
from pymongo import MongoClient
from pymongo.auth_oidc import _CACHE as _oidc_cache
from pymongo.cursor import CursorType
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.hello import HelloCompat
from pymongo.operations import InsertOne
class TestAuthOIDC(unittest.TestCase):
uri: str
@classmethod
def setUpClass(cls):
cls.uri_single = os.environ["MONGODB_URI_SINGLE"]
cls.uri_multiple = os.environ["MONGODB_URI_MULTIPLE"]
cls.uri_admin = os.environ["MONGODB_URI"]
cls.token_dir = os.environ["OIDC_TOKEN_DIR"]
def setUp(self):
self.request_called = 0
self.refresh_called = 0
_oidc_cache.clear()
os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1")
def create_request_cb(self, username="test_user1", expires_in_seconds=None, sleep=0):
token_file = os.path.join(self.token_dir, username)
def request_token(server_info, context):
# Validate the info.
self.assertIn("issuer", server_info)
self.assertIn("clientId", server_info)
# Validate the timeout.
timeout_seconds = context["timeout_seconds"]
self.assertEqual(timeout_seconds, 60 * 5)
with open(token_file) as fid:
token = fid.read()
resp = dict(access_token=token)
time.sleep(sleep)
if expires_in_seconds is not None:
resp["expires_in_seconds"] = expires_in_seconds
self.request_called += 1
return resp
return request_token
def create_refresh_cb(self, username="test_user1", expires_in_seconds=None):
token_file = os.path.join(self.token_dir, username)
def refresh_token(server_info, context):
with open(token_file) as fid:
token = fid.read()
# Validate the info.
self.assertIn("issuer", server_info)
self.assertIn("clientId", server_info)
# Validate the creds
self.assertIsNotNone(context["refresh_token"])
# Validate the timeout.
self.assertEqual(context["timeout_seconds"], 60 * 5)
resp = dict(access_token=token)
if expires_in_seconds is not None:
resp["expires_in_seconds"] = expires_in_seconds
self.refresh_called += 1
return resp
return refresh_token
@contextmanager
def fail_point(self, command_args):
cmd_on = SON([("configureFailPoint", "failCommand")])
cmd_on.update(command_args)
client = MongoClient(self.uri_admin)
client.admin.command(cmd_on)
try:
yield
finally:
client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off")
def test_connect_callbacks_single_implicit_username(self):
request_token = self.create_request_cb()
props: Dict = dict(request_token_callback=request_token)
client = MongoClient(self.uri_single, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_callbacks_single_explicit_username(self):
request_token = self.create_request_cb()
props: Dict = dict(request_token_callback=request_token)
client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_callbacks_multiple_principal_user1(self):
request_token = self.create_request_cb()
props: Dict = dict(request_token_callback=request_token)
client = MongoClient(
self.uri_multiple, username="test_user1", authmechanismproperties=props
)
client.test.test.find_one()
client.close()
def test_connect_callbacks_multiple_principal_user2(self):
request_token = self.create_request_cb("test_user2")
props: Dict = dict(request_token_callback=request_token)
client = MongoClient(
self.uri_multiple, username="test_user2", authmechanismproperties=props
)
client.test.test.find_one()
client.close()
def test_connect_callbacks_multiple_no_username(self):
request_token = self.create_request_cb()
props: Dict = dict(request_token_callback=request_token)
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
with self.assertRaises(OperationFailure):
client.test.test.find_one()
client.close()
def test_allowed_hosts_blocked(self):
request_token = self.create_request_cb()
props: Dict = dict(request_token_callback=request_token, allowed_hosts=[])
client = MongoClient(self.uri_single, authmechanismproperties=props)
with self.assertRaises(ConfigurationError):
client.test.test.find_one()
client.close()
props: Dict = dict(request_token_callback=request_token, allowed_hosts=["example.com"])
client = MongoClient(
self.uri_single + "&ignored=example.com", authmechanismproperties=props, connect=False
)
with self.assertRaises(ConfigurationError):
client.test.test.find_one()
client.close()
def test_connect_aws_single_principal(self):
props = dict(PROVIDER_NAME="aws")
client = MongoClient(self.uri_single, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_aws_multiple_principal_user1(self):
props = dict(PROVIDER_NAME="aws")
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_aws_multiple_principal_user2(self):
os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2")
props = dict(PROVIDER_NAME="aws")
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_connect_aws_allowed_hosts_ignored(self):
props = dict(PROVIDER_NAME="aws", allowed_hosts=[])
client = MongoClient(self.uri_multiple, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_valid_callbacks(self):
request_cb = self.create_request_cb(expires_in_seconds=60)
refresh_cb = self.create_refresh_cb()
props: Dict = dict(
request_token_callback=request_cb,
refresh_token_callback=refresh_cb,
)
client = MongoClient(self.uri_single, authmechanismproperties=props)
client.test.test.find_one()
client.close()
client = MongoClient(self.uri_single, authmechanismproperties=props)
client.test.test.find_one()
client.close()
def test_lock_avoids_extra_callbacks(self):
request_cb = self.create_request_cb(sleep=0.5)
refresh_cb = self.create_refresh_cb()
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
def run_test():
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
t1 = threading.Thread(target=run_test)
t2 = threading.Thread(target=run_test)
t1.start()
t2.start()
t1.join()
t2.join()
self.assertEqual(self.request_called, 1)
self.assertEqual(self.refresh_called, 2)
def test_request_callback_returns_null(self):
def request_token_null(a, b):
return None
props: Dict = dict(request_token_callback=request_token_null)
client = MongoClient(self.uri_single, authMechanismProperties=props)
with self.assertRaises(ValueError):
client.test.test.find_one()
client.close()
def test_refresh_callback_returns_null(self):
request_cb = self.create_request_cb(expires_in_seconds=60)
def refresh_token_null(a, b):
return None
props: Dict = dict(
request_token_callback=request_cb, refresh_token_callback=refresh_token_null
)
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
client = MongoClient(self.uri_single, authMechanismProperties=props)
with self.assertRaises(ValueError):
client.test.test.find_one()
client.close()
def test_request_callback_invalid_result(self):
def request_token_invalid(a, b):
return dict()
props: Dict = dict(request_token_callback=request_token_invalid)
client = MongoClient(self.uri_single, authMechanismProperties=props)
with self.assertRaises(ValueError):
client.test.test.find_one()
client.close()
def request_cb_extra_value(server_info, context):
result = self.create_request_cb()(server_info, context)
result["foo"] = "bar"
return result
props: Dict = dict(request_token_callback=request_cb_extra_value)
client = MongoClient(self.uri_single, authMechanismProperties=props)
with self.assertRaises(ValueError):
client.test.test.find_one()
client.close()
def test_refresh_callback_missing_data(self):
request_cb = self.create_request_cb(expires_in_seconds=60)
def refresh_cb_no_token(a, b):
return dict()
props: Dict = dict(
request_token_callback=request_cb, refresh_token_callback=refresh_cb_no_token
)
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
client = MongoClient(self.uri_single, authMechanismProperties=props)
with self.assertRaises(ValueError):
client.test.test.find_one()
client.close()
def test_refresh_callback_extra_data(self):
request_cb = self.create_request_cb(expires_in_seconds=60)
def refresh_cb_extra_value(server_info, context):
result = self.create_refresh_cb()(server_info, context)
result["foo"] = "bar"
return result
props: Dict = dict(
request_token_callback=request_cb, refresh_token_callback=refresh_cb_extra_value
)
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
client = MongoClient(self.uri_single, authMechanismProperties=props)
with self.assertRaises(ValueError):
client.test.test.find_one()
client.close()
def test_cache_with_refresh(self):
# Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token.
# Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute.
request_cb = self.create_request_cb(expires_in_seconds=60)
refresh_cb = self.create_refresh_cb()
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
# Ensure that a ``find`` operation adds credentials to the cache.
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
self.assertEqual(len(_oidc_cache), 1)
# Create a new client with the same request callback and a refresh callback.
# Ensure that a ``find`` operation results in a call to the refresh callback.
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
self.assertEqual(self.refresh_called, 1)
self.assertEqual(len(_oidc_cache), 1)
def test_cache_with_no_refresh(self):
# Create a new client with a request callback callback.
# Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute.
request_cb = self.create_request_cb()
props = dict(request_token_callback=request_cb)
client = MongoClient(self.uri_single, authMechanismProperties=props)
# Ensure that a ``find`` operation adds credentials to the cache.
self.request_called = 0
client.test.test.find_one()
client.close()
self.assertEqual(self.request_called, 1)
self.assertEqual(len(_oidc_cache), 1)
# Create a new client with the same request callback.
# Ensure that a ``find`` operation results in a call to the request callback.
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
self.assertEqual(self.request_called, 2)
self.assertEqual(len(_oidc_cache), 1)
def test_cache_key_includes_callback(self):
request_cb = self.create_request_cb()
props: Dict = dict(request_token_callback=request_cb)
# Ensure that a ``find`` operation adds a new entry to the cache.
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
# Create a new client with a different request callback.
def request_token_2(a, b):
return request_cb(a, b)
props["request_token_callback"] = request_token_2
client = MongoClient(self.uri_single, authMechanismProperties=props)
# Ensure that a ``find`` operation adds a new entry to the cache.
client.test.test.find_one()
client.close()
self.assertEqual(len(_oidc_cache), 2)
def test_cache_clears_on_error(self):
request_cb = self.create_request_cb()
# Create a new client with a valid request callback that gives credentials that expire within 5 minutes and a refresh callback that gives invalid credentials.
def refresh_cb(a, b):
return dict(access_token="bad")
# Add a token to the cache that will expire soon.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
client = MongoClient(self.uri_single, authMechanismProperties=props)
client.test.test.find_one()
client.close()
# Create a new client with the same callbacks.
client = MongoClient(self.uri_single, authMechanismProperties=props)
# Ensure that another ``find`` operation results in an error.
with self.assertRaises(OperationFailure):
client.test.test.find_one()
client.close()
# Ensure that the cache has been cleared.
authenticator = list(_oidc_cache.values())[0]
self.assertIsNone(authenticator.idp_info)
def test_cache_is_not_used_in_aws_automatic_workflow(self):
# Create a new client using the AWS device workflow.
# Ensure that a ``find`` operation does not add credentials to the cache.
props = dict(PROVIDER_NAME="aws")
client = MongoClient(self.uri_single, authmechanismproperties=props)
client.test.test.find_one()
client.close()
# Ensure that the cache has been cleared.
authenticator = list(_oidc_cache.values())[0]
self.assertIsNone(authenticator.idp_info)
def test_speculative_auth_success(self):
# Clear the cache
_oidc_cache.clear()
token_file = os.path.join(self.token_dir, "test_user1")
def request_token(a, b):
with open(token_file) as fid:
token = fid.read()
return dict(access_token=token, expires_in_seconds=1000)
# Create a client with a request callback that returns a valid token
# that will not expire soon.
props: Dict = dict(request_token_callback=request_token)
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Set a fail point for saslStart commands.
with self.fail_point(
{
"mode": {"times": 2},
"data": {"failCommands": ["saslStart"], "errorCode": 18},
}
):
# Perform a find operation.
client.test.test.find_one()
# Close the client.
client.close()
# Create a new client.
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Set a fail point for saslStart commands.
with self.fail_point(
{
"mode": {"times": 2},
"data": {"failCommands": ["saslStart"], "errorCode": 18},
}
):
# Perform a find operation.
client.test.test.find_one()
# Close the client.
client.close()
def test_reauthenticate_succeeds(self):
listener = EventListener()
# Create request and refresh callbacks that return valid credentials
# that will not expire soon.
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
client = MongoClient(
self.uri_single, event_listeners=[listener], authmechanismproperties=props
)
# Perform a find operation.
client.test.test.find_one()
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
listener.reset()
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Perform a find operation.
client.test.test.find_one()
started_events = [
i.command_name for i in listener.started_events if not i.command_name.startswith("sasl")
]
succeeded_events = [
i.command_name
for i in listener.succeeded_events
if not i.command_name.startswith("sasl")
]
failed_events = [
i.command_name for i in listener.failed_events if not i.command_name.startswith("sasl")
]
self.assertEqual(
started_events,
[
"find",
"find",
],
)
self.assertEqual(succeeded_events, ["find"])
self.assertEqual(failed_events, ["find"])
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
client.close()
def test_reauthenticate_succeeds_bulk_write(self):
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform a find operation.
client.test.test.find_one()
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["insert"], "errorCode": 391},
}
):
# Perform a bulk write operation.
client.test.test.bulk_write([InsertOne({})])
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
client.close()
def test_reauthenticate_succeeds_bulk_read(self):
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform a find operation.
client.test.test.find_one()
# Perform a bulk write operation.
client.test.test.bulk_write([InsertOne({})])
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Perform a bulk read operation.
cursor = client.test.test.find_raw_batches({})
list(cursor)
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
client.close()
def test_reauthenticate_succeeds_cursor(self):
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform an insert operation.
client.test.test.insert_one({"a": 1})
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Perform a find operation.
cursor = client.test.test.find({"a": 1})
self.assertGreaterEqual(len(list(cursor)), 1)
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
client.close()
def test_reauthenticate_succeeds_get_more(self):
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform an insert operation.
client.test.test.insert_many([{"a": 1}, {"a": 1}])
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["getMore"], "errorCode": 391},
}
):
# Perform a find operation.
cursor = client.test.test.find({"a": 1}, batch_size=1)
self.assertGreaterEqual(len(list(cursor)), 1)
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
client.close()
def test_reauthenticate_succeeds_get_more_exhaust(self):
# Ensure no mongos
props = dict(PROVIDER_NAME="aws")
client = MongoClient(self.uri_single, authmechanismproperties=props)
hello = client.admin.command(HelloCompat.LEGACY_CMD)
if hello.get("msg") != "isdbgrid":
raise unittest.SkipTest("Must not be a mongos")
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform an insert operation.
client.test.test.insert_many([{"a": 1}, {"a": 1}])
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["getMore"], "errorCode": 391},
}
):
# Perform a find operation.
cursor = client.test.test.find({"a": 1}, batch_size=1, cursor_type=CursorType.EXHAUST)
self.assertGreaterEqual(len(list(cursor)), 1)
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
client.close()
def test_reauthenticate_succeeds_command(self):
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
print("start of test")
client = MongoClient(self.uri_single, authmechanismproperties=props)
# Perform an insert operation.
client.test.test.insert_one({"a": 1})
# Assert that the refresh callback has not been called.
self.assertEqual(self.refresh_called, 0)
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["count"], "errorCode": 391},
}
):
# Perform a count operation.
cursor = client.test.command(dict(count="test"))
self.assertGreaterEqual(len(list(cursor)), 1)
# Assert that the refresh callback has been called.
self.assertEqual(self.refresh_called, 1)
client.close()
def test_reauthenticate_retries_and_succeeds_with_cache(self):
listener = EventListener()
# Create request and refresh callbacks that return valid credentials
# that will not expire soon.
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
client = MongoClient(
self.uri_single, event_listeners=[listener], authmechanismproperties=props
)
# Perform a find operation.
client.test.test.find_one()
# Set a fail point for ``saslStart`` commands of the form
with self.fail_point(
{
"mode": {"times": 2},
"data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
}
):
# Perform a find operation that succeeds.
client.test.test.find_one()
# Close the client.
client.close()
def test_reauthenticate_fails_with_no_cache(self):
listener = EventListener()
# Create request and refresh callbacks that return valid credentials
# that will not expire soon.
request_cb = self.create_request_cb()
refresh_cb = self.create_refresh_cb()
# Create a client with the callbacks.
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
client = MongoClient(
self.uri_single, event_listeners=[listener], authmechanismproperties=props
)
# Perform a find operation.
client.test.test.find_one()
# Clear the cache.
_oidc_cache.clear()
with self.fail_point(
{
"mode": {"times": 2},
"data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
}
):
# Perform a find operation that fails.
with self.assertRaises(OperationFailure):
client.test.test.find_one()
client.close()
def test_late_reauth_avoids_callback(self):
# Step 1: connect with both clients
request_cb = self.create_request_cb(expires_in_seconds=1e6)
refresh_cb = self.create_refresh_cb(expires_in_seconds=1e6)
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
client1 = MongoClient(self.uri_single, authMechanismProperties=props)
client1.test.test.find_one()
client2 = MongoClient(self.uri_single, authMechanismProperties=props)
client2.test.test.find_one()
self.assertEqual(self.refresh_called, 0)
self.assertEqual(self.request_called, 1)
# Step 2: cause a find 391 on the first client
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Perform a find operation that succeeds.
client1.test.test.find_one()
self.assertEqual(self.refresh_called, 1)
self.assertEqual(self.request_called, 1)
# Step 3: cause a find 391 on the second client
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Perform a find operation that succeeds.
client2.test.test.find_one()
self.assertEqual(self.refresh_called, 1)
self.assertEqual(self.request_called, 1)
client1.close()
client2.close()
if __name__ == "__main__":
unittest.main()

View File

@ -22,6 +22,7 @@ import sys
sys.path[0:0] = [""]
from test import unittest
from test.unified_format import generate_test_classes
from pymongo import MongoClient
@ -41,7 +42,16 @@ def create_test(test_case):
if not valid:
self.assertRaises(Exception, MongoClient, uri, connect=False)
else:
client = MongoClient(uri, connect=False)
props = {}
if credential:
props = credential["mechanism_properties"] or {}
if props.get("REQUEST_TOKEN_CALLBACK"):
props["request_token_callback"] = lambda x, y: 1
del props["REQUEST_TOKEN_CALLBACK"]
if props.get("REFRESH_TOKEN_CALLBACK"):
props["refresh_token_callback"] = lambda a, b: 1
del props["REFRESH_TOKEN_CALLBACK"]
client = MongoClient(uri, connect=False, authmechanismproperties=props)
credentials = client.options.pool_options._credentials
if credential is None:
self.assertIsNone(credentials)
@ -70,6 +80,16 @@ def create_test(test_case):
self.assertEqual(
actual.aws_session_token, expected["AWS_SESSION_TOKEN"]
)
elif "PROVIDER_NAME" in expected:
self.assertEqual(actual.provider_name, expected["PROVIDER_NAME"])
elif "request_token_callback" in expected:
self.assertEqual(
actual.request_token_callback, expected["request_token_callback"]
)
elif "refresh_token_callback" in expected:
self.assertEqual(
actual.refresh_token_callback, expected["refresh_token_callback"]
)
else:
self.fail("Unhandled property: %s" % (key,))
else:
@ -82,7 +102,7 @@ def create_test(test_case):
def create_tests():
for filename in glob.glob(os.path.join(_TEST_PATH, "*.json")):
for filename in glob.glob(os.path.join(_TEST_PATH, "legacy", "*.json")):
test_suffix, _ = os.path.splitext(os.path.basename(filename))
with open(filename) as auth_tests:
test_cases = json.load(auth_tests)["tests"]
@ -97,5 +117,12 @@ def create_tests():
create_tests()
globals().update(
generate_test_classes(
os.path.join(_TEST_PATH, "unified"),
module=__name__,
)
)
if __name__ == "__main__":
unittest.main()