mongo-python-driver/test/auth_aws/test_auth_oidc.py
2023-06-13 11:30:50 -05:00

829 lines
30 KiB
Python

# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test MONGODB-OIDC Authentication."""
import os
import sys
import threading
import time
import unittest
from contextlib import contextmanager
from typing import Dict
sys.path[0:0] = [""]
from test.utils import EventListener
from bson import SON
from pymongo import MongoClient
from pymongo.auth import _AUTH_MAP, _authenticate_oidc
from pymongo.auth_oidc import _CACHE as _oidc_cache
from pymongo.cursor import CursorType
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.hello import HelloCompat
from pymongo.operations import InsertOne
# Force MONGODB-OIDC to be enabled: map the mechanism name to the OIDC
# authenticate function before any client in this module is created.
_AUTH_MAP.update({"MONGODB-OIDC": _authenticate_oidc})  # type:ignore
class TestAuthOIDC(unittest.TestCase):
    """Integration tests for MONGODB-OIDC authentication.

    Requires an OIDC-enabled deployment and the environment variables read in
    ``setUpClass``. Access tokens are read from files in ``OIDC_TOKEN_DIR``,
    one file per user name.
    """

    # Populated from the environment by setUpClass().
    uri_single: str
    uri_multiple: str
    uri_admin: str
    token_dir: str

    @classmethod
    def setUpClass(cls):
        """Load the connection strings and the token directory from the environment.

        Raises:
            KeyError: if ``MONGODB_URI_SINGLE``, ``MONGODB_URI_MULTIPLE``,
                ``MONGODB_URI`` or ``OIDC_TOKEN_DIR`` is not set.
        """
        cls.uri_single = os.environ["MONGODB_URI_SINGLE"]
        cls.uri_multiple = os.environ["MONGODB_URI_MULTIPLE"]
        cls.uri_admin = os.environ["MONGODB_URI"]
        cls.token_dir = os.environ["OIDC_TOKEN_DIR"]

    def setUp(self):
        """Reset callback counters, the OIDC cache and the default AWS token file."""
        self.request_called = 0
        self.refresh_called = 0
        _oidc_cache.clear()
        # Some tests point this at another user; restore test_user1 before each test.
        os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1")

    def create_request_cb(self, username="test_user1", expires_in_seconds=None, sleep=0):
        """Build a request-token callback that returns the stored token for ``username``.

        The returned callback asserts that ``server_info`` has ``issuer`` and
        ``clientId`` and that the context timeout is five minutes. It then
        sleeps ``sleep`` seconds, adds ``expires_in_seconds`` to the response
        when given, and increments ``self.request_called``.
        """
        token_file = os.path.join(self.token_dir, username)

        def request_token(server_info, context):
            # Validate the info.
            self.assertIn("issuer", server_info)
            self.assertIn("clientId", server_info)

            # Validate the timeout.
            timeout_seconds = context["timeout_seconds"]
            self.assertEqual(timeout_seconds, 60 * 5)
            with open(token_file) as fid:
                token = fid.read()
            resp: Dict = {"access_token": token}

            # Optional delay; widens the window in which concurrent callers overlap.
            time.sleep(sleep)

            if expires_in_seconds is not None:
                resp["expires_in_seconds"] = expires_in_seconds
            self.request_called += 1
            return resp

        return request_token

    def create_refresh_cb(self, username="test_user1", expires_in_seconds=None):
        """Build a refresh-token callback that returns the stored token for ``username``.

        The returned callback asserts that ``server_info`` has ``issuer`` and
        ``clientId``, that a ``refresh_token`` is supplied, and that the context
        timeout is five minutes. It adds ``expires_in_seconds`` when given and
        increments ``self.refresh_called``.
        """
        token_file = os.path.join(self.token_dir, username)

        def refresh_token(server_info, context):
            with open(token_file) as fid:
                token = fid.read()

            # Validate the info.
            self.assertIn("issuer", server_info)
            self.assertIn("clientId", server_info)

            # Validate the creds.
            self.assertIsNotNone(context["refresh_token"])

            # Validate the timeout.
            self.assertEqual(context["timeout_seconds"], 60 * 5)

            resp: Dict = {"access_token": token}
            if expires_in_seconds is not None:
                resp["expires_in_seconds"] = expires_in_seconds
            self.refresh_called += 1
            return resp

        return refresh_token

    @contextmanager
    def fail_point(self, command_args):
        """Enable the ``failCommand`` fail point for the duration of a ``with`` block.

        Args:
            command_args: Extra fields (e.g. ``mode``, ``data``) merged into the
                ``configureFailPoint`` command.

        On exit the fail point is switched off and the admin client is closed,
        even if the body raised.
        """
        cmd_on = SON([("configureFailPoint", "failCommand")])
        cmd_on.update(command_args)
        client = MongoClient(self.uri_admin)
        try:
            client.admin.command(cmd_on)
            try:
                yield
            finally:
                client.admin.command(
                    "configureFailPoint", cmd_on["configureFailPoint"], mode="off"
                )
        finally:
            # Close the admin client so each fail point does not leak a client.
            client.close()

    def test_connect_callbacks_single_implicit_username(self):
        """A request callback authenticates against the single-principal URI."""
        request_token = self.create_request_cb()
        props: Dict = {"request_token_callback": request_token}
        client = MongoClient(self.uri_single, authmechanismproperties=props)
        client.test.test.find_one()
        client.close()

    def test_connect_callbacks_single_explicit_username(self):
        """A request callback plus explicit username authenticates (single principal)."""
        request_token = self.create_request_cb()
        props: Dict = {"request_token_callback": request_token}
        client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props)
        client.test.test.find_one()
        client.close()

    def test_connect_callbacks_multiple_principal_user1(self):
        """test_user1 authenticates with a request callback (multiple principals)."""
        request_token = self.create_request_cb()
        props: Dict = {"request_token_callback": request_token}
        client = MongoClient(
            self.uri_multiple, username="test_user1", authmechanismproperties=props
        )
        client.test.test.find_one()
        client.close()

    def test_connect_callbacks_multiple_principal_user2(self):
        """test_user2 authenticates with a request callback (multiple principals)."""
        request_token = self.create_request_cb("test_user2")
        props: Dict = {"request_token_callback": request_token}
        client = MongoClient(
            self.uri_multiple, username="test_user2", authmechanismproperties=props
        )
        client.test.test.find_one()
        client.close()

    def test_connect_callbacks_multiple_no_username(self):
        """Without a username, authenticating to the multi-principal URI fails."""
        request_token = self.create_request_cb()
        props: Dict = {"request_token_callback": request_token}
        client = MongoClient(self.uri_multiple, authmechanismproperties=props)
        with self.assertRaises(OperationFailure):
            client.test.test.find_one()
        client.close()

    def test_allowed_hosts_blocked(self):
        """Request callbacks are refused for hosts outside ``allowed_hosts``."""
        request_token = self.create_request_cb()
        props: Dict = {"request_token_callback": request_token, "allowed_hosts": []}
        client = MongoClient(self.uri_single, authmechanismproperties=props)
        with self.assertRaises(ConfigurationError):
            client.test.test.find_one()
        client.close()

        # An allowed host that only appears in an unrelated URI option must not match.
        props = {"request_token_callback": request_token, "allowed_hosts": ["example.com"]}
        client = MongoClient(
            self.uri_single + "&ignored=example.com", authmechanismproperties=props, connect=False
        )
        with self.assertRaises(ConfigurationError):
            client.test.test.find_one()
        client.close()

    def test_connect_aws_single_principal(self):
        """The AWS provider authenticates against the single-principal URI."""
        props = {"PROVIDER_NAME": "aws"}
        client = MongoClient(self.uri_single, authmechanismproperties=props)
        client.test.test.find_one()
        client.close()

    def test_connect_aws_multiple_principal_user1(self):
        """The AWS provider authenticates as test_user1 (multiple principals)."""
        props = {"PROVIDER_NAME": "aws"}
        client = MongoClient(self.uri_multiple, authmechanismproperties=props)
        client.test.test.find_one()
        client.close()

    def test_connect_aws_multiple_principal_user2(self):
        """The AWS provider authenticates as test_user2 (multiple principals)."""
        os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2")
        props = {"PROVIDER_NAME": "aws"}
        client = MongoClient(self.uri_multiple, authmechanismproperties=props)
        client.test.test.find_one()
        client.close()

    def test_connect_aws_allowed_hosts_ignored(self):
        """An empty ``allowed_hosts`` does not block the AWS provider."""
        props = {"PROVIDER_NAME": "aws", "allowed_hosts": []}
        client = MongoClient(self.uri_multiple, authmechanismproperties=props)
        client.test.test.find_one()
        client.close()

    def test_valid_callbacks(self):
        """Valid request/refresh callbacks authenticate two successive clients."""
        request_cb = self.create_request_cb(expires_in_seconds=60)
        refresh_cb = self.create_refresh_cb()

        props: Dict = {
            "request_token_callback": request_cb,
            "refresh_token_callback": refresh_cb,
        }
        client = MongoClient(self.uri_single, authmechanismproperties=props)
        client.test.test.find_one()
        client.close()

        client = MongoClient(self.uri_single, authmechanismproperties=props)
        client.test.test.find_one()
        client.close()

    def test_lock_avoids_extra_callbacks(self):
        """Two threads authenticating concurrently call the request callback once."""
        request_cb = self.create_request_cb(sleep=0.5)
        refresh_cb = self.create_refresh_cb()

        props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
        errors: list = []

        def run_test():
            try:
                client = MongoClient(self.uri_single, authMechanismProperties=props)
                client.test.test.find_one()
                client.close()

                client = MongoClient(self.uri_single, authMechanismProperties=props)
                client.test.test.find_one()
                client.close()
            except Exception as exc:
                # An exception in a worker thread does not fail the test by
                # itself; record it so the main thread can report it.
                errors.append(exc)

        threads = [threading.Thread(target=run_test) for _ in range(2)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        self.assertEqual(errors, [])
        self.assertEqual(self.request_called, 1)
        self.assertEqual(self.refresh_called, 2)

    def test_request_callback_returns_null(self):
        """A request callback returning None raises ValueError."""

        def request_token_null(a, b):
            return None

        props: Dict = {"request_token_callback": request_token_null}
        client = MongoClient(self.uri_single, authMechanismProperties=props)
        with self.assertRaises(ValueError):
            client.test.test.find_one()
        client.close()

    def test_refresh_callback_returns_null(self):
        """A refresh callback returning None raises ValueError on the next client."""
        request_cb = self.create_request_cb(expires_in_seconds=60)

        def refresh_token_null(a, b):
            return None

        props: Dict = {
            "request_token_callback": request_cb,
            "refresh_token_callback": refresh_token_null,
        }
        client = MongoClient(self.uri_single, authMechanismProperties=props)
        client.test.test.find_one()
        client.close()

        client = MongoClient(self.uri_single, authMechanismProperties=props)
        with self.assertRaises(ValueError):
            client.test.test.find_one()
        client.close()

    def test_request_callback_invalid_result(self):
        """Request callback results that are empty or carry extra keys raise ValueError."""

        def request_token_invalid(a, b):
            return {}

        props: Dict = {"request_token_callback": request_token_invalid}
        client = MongoClient(self.uri_single, authMechanismProperties=props)
        with self.assertRaises(ValueError):
            client.test.test.find_one()
        client.close()

        def request_cb_extra_value(server_info, context):
            result = self.create_request_cb()(server_info, context)
            result["foo"] = "bar"
            return result

        props = {"request_token_callback": request_cb_extra_value}
        client = MongoClient(self.uri_single, authMechanismProperties=props)
        with self.assertRaises(ValueError):
            client.test.test.find_one()
        client.close()

    def test_refresh_callback_missing_data(self):
        """A refresh callback returning an empty dict raises ValueError."""
        request_cb = self.create_request_cb(expires_in_seconds=60)

        def refresh_cb_no_token(a, b):
            return {}

        props: Dict = {
            "request_token_callback": request_cb,
            "refresh_token_callback": refresh_cb_no_token,
        }
        client = MongoClient(self.uri_single, authMechanismProperties=props)
        client.test.test.find_one()
        client.close()

        client = MongoClient(self.uri_single, authMechanismProperties=props)
        with self.assertRaises(ValueError):
            client.test.test.find_one()
        client.close()

    def test_refresh_callback_extra_data(self):
        """A refresh callback result with an unexpected key raises ValueError."""
        request_cb = self.create_request_cb(expires_in_seconds=60)

        def refresh_cb_extra_value(server_info, context):
            result = self.create_refresh_cb()(server_info, context)
            result["foo"] = "bar"
            return result

        props: Dict = {
            "request_token_callback": request_cb,
            "refresh_token_callback": refresh_cb_extra_value,
        }
        client = MongoClient(self.uri_single, authMechanismProperties=props)
        client.test.test.find_one()
        client.close()

        client = MongoClient(self.uri_single, authMechanismProperties=props)
        with self.assertRaises(ValueError):
            client.test.test.find_one()
        client.close()

    def test_cache_with_refresh(self):
        """A second client reuses the cached credentials via the refresh callback."""
        # Both callbacks read the token file for test_user1 to obtain a valid
        # access token. The request callback reports an expiry of 60 seconds.
        request_cb = self.create_request_cb(expires_in_seconds=60)
        refresh_cb = self.create_refresh_cb()

        props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}

        # Ensure that a ``find`` operation adds credentials to the cache.
        client = MongoClient(self.uri_single, authMechanismProperties=props)
        client.test.test.find_one()
        client.close()

        self.assertEqual(len(_oidc_cache), 1)

        # Create a new client with the same request and refresh callbacks.
        # Ensure that a ``find`` operation results in a call to the refresh callback.
        client = MongoClient(self.uri_single, authMechanismProperties=props)
        client.test.test.find_one()
        client.close()

        self.assertEqual(self.refresh_called, 1)
        self.assertEqual(len(_oidc_cache), 1)

    def test_cache_with_no_refresh(self):
        """Without a refresh callback, a second client calls the request callback again."""
        # Create a new client with a request callback only (no expiry reported).
        request_cb = self.create_request_cb()

        props: Dict = {"request_token_callback": request_cb}
        client = MongoClient(self.uri_single, authMechanismProperties=props)

        # Ensure that a ``find`` operation adds credentials to the cache.
        self.request_called = 0
        client.test.test.find_one()
        client.close()
        self.assertEqual(self.request_called, 1)
        self.assertEqual(len(_oidc_cache), 1)

        # Create a new client with the same request callback.
        # Ensure that a ``find`` operation results in a call to the request callback.
        client = MongoClient(self.uri_single, authMechanismProperties=props)
        client.test.test.find_one()
        client.close()
        self.assertEqual(self.request_called, 2)
        self.assertEqual(len(_oidc_cache), 1)

    def test_cache_key_includes_callback(self):
        """Clients with different request callbacks get separate cache entries."""
        request_cb = self.create_request_cb()

        props: Dict = {"request_token_callback": request_cb}

        # Ensure that a ``find`` operation adds a new entry to the cache.
        client = MongoClient(self.uri_single, authMechanismProperties=props)
        client.test.test.find_one()
        client.close()

        # Create a new client with a different request callback.
        def request_token_2(a, b):
            return request_cb(a, b)

        props["request_token_callback"] = request_token_2
        client = MongoClient(self.uri_single, authMechanismProperties=props)

        # Ensure that a ``find`` operation adds a new entry to the cache.
        client.test.test.find_one()
        client.close()
        self.assertEqual(len(_oidc_cache), 2)

    def test_cache_clears_on_error(self):
        """An authentication failure clears the cached IdP info."""
        # The request callback returns a valid token; the refresh callback
        # returns an invalid one.
        request_cb = self.create_request_cb()

        def refresh_cb(a, b):
            return {"access_token": "bad"}

        # Populate the cache using the valid request callback.
        props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
        client = MongoClient(self.uri_single, authMechanismProperties=props)
        client.test.test.find_one()
        client.close()

        # Create a new client with the same callbacks.
        client = MongoClient(self.uri_single, authMechanismProperties=props)

        # Ensure that another ``find`` operation results in an error.
        with self.assertRaises(OperationFailure):
            client.test.test.find_one()

        client.close()

        # Ensure that the cached authenticator no longer holds IdP info.
        authenticator = list(_oidc_cache.values())[0]
        self.assertIsNone(authenticator.idp_info)

    def test_cache_is_not_used_in_aws_automatic_workflow(self):
        """The AWS automatic workflow stores no IdP info in the cache."""
        # Create a new client using the AWS automatic workflow and run a ``find``.
        props = {"PROVIDER_NAME": "aws"}
        client = MongoClient(self.uri_single, authmechanismproperties=props)
        client.test.test.find_one()
        client.close()

        # The cached authenticator must not hold any IdP info.
        authenticator = list(_oidc_cache.values())[0]
        self.assertIsNone(authenticator.idp_info)

    def test_speculative_auth_success(self):
        """Authentication succeeds while ``saslStart`` is forced to fail (error 18)."""
        # Clear the cache
        _oidc_cache.clear()
        token_file = os.path.join(self.token_dir, "test_user1")

        def request_token(a, b):
            with open(token_file) as fid:
                token = fid.read()
            return {"access_token": token, "expires_in_seconds": 1000}

        # Create a client with a request callback that returns a valid token
        # that will not expire soon.
        props: Dict = {"request_token_callback": request_token}
        client = MongoClient(self.uri_single, authmechanismproperties=props)

        # Set a fail point for saslStart commands.
        with self.fail_point(
            {
                "mode": {"times": 2},
                "data": {"failCommands": ["saslStart"], "errorCode": 18},
            }
        ):
            # Perform a find operation.
            client.test.test.find_one()

        # Close the client.
        client.close()

        # Create a new client.
        client = MongoClient(self.uri_single, authmechanismproperties=props)

        # Set a fail point for saslStart commands.
        with self.fail_point(
            {
                "mode": {"times": 2},
                "data": {"failCommands": ["saslStart"], "errorCode": 18},
            }
        ):
            # Perform a find operation.
            client.test.test.find_one()

        # Close the client.
        client.close()

    def test_reauthenticate_succeeds(self):
        """A ``find`` failing with error 391 is retried once after a token refresh."""
        listener = EventListener()

        # Create request and refresh callbacks that return valid credentials
        # that will not expire soon.
        request_cb = self.create_request_cb()
        refresh_cb = self.create_refresh_cb()

        # Create a client with the callbacks.
        props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
        client = MongoClient(
            self.uri_single, event_listeners=[listener], authmechanismproperties=props
        )

        # Perform a find operation.
        client.test.test.find_one()

        # Assert that the refresh callback has not been called.
        self.assertEqual(self.refresh_called, 0)

        listener.reset()

        with self.fail_point(
            {
                "mode": {"times": 1},
                "data": {"failCommands": ["find"], "errorCode": 391},
            }
        ):
            # Perform a find operation.
            client.test.test.find_one()

        # Ignore the sasl* authentication commands; only the find traffic matters.
        started_events = [
            i.command_name for i in listener.started_events if not i.command_name.startswith("sasl")
        ]
        succeeded_events = [
            i.command_name
            for i in listener.succeeded_events
            if not i.command_name.startswith("sasl")
        ]
        failed_events = [
            i.command_name for i in listener.failed_events if not i.command_name.startswith("sasl")
        ]

        self.assertEqual(
            started_events,
            [
                "find",
                "find",
            ],
        )
        self.assertEqual(succeeded_events, ["find"])
        self.assertEqual(failed_events, ["find"])

        # Assert that the refresh callback has been called.
        self.assertEqual(self.refresh_called, 1)
        client.close()

    def test_reauthenticate_succeeds_bulk_write(self):
        """A bulk write whose ``insert`` fails with error 391 succeeds after a refresh."""
        request_cb = self.create_request_cb()
        refresh_cb = self.create_refresh_cb()

        # Create a client with the callbacks.
        props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
        client = MongoClient(self.uri_single, authmechanismproperties=props)

        # Perform a find operation.
        client.test.test.find_one()

        # Assert that the refresh callback has not been called.
        self.assertEqual(self.refresh_called, 0)

        with self.fail_point(
            {
                "mode": {"times": 1},
                "data": {"failCommands": ["insert"], "errorCode": 391},
            }
        ):
            # Perform a bulk write operation.
            client.test.test.bulk_write([InsertOne({})])

        # Assert that the refresh callback has been called.
        self.assertEqual(self.refresh_called, 1)
        client.close()

    def test_reauthenticate_succeeds_bulk_read(self):
        """A raw-batch ``find`` failing with error 391 succeeds after a refresh."""
        request_cb = self.create_request_cb()
        refresh_cb = self.create_refresh_cb()

        # Create a client with the callbacks.
        props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
        client = MongoClient(self.uri_single, authmechanismproperties=props)

        # Perform a find operation.
        client.test.test.find_one()

        # Perform a bulk write operation.
        client.test.test.bulk_write([InsertOne({})])

        # Assert that the refresh callback has not been called.
        self.assertEqual(self.refresh_called, 0)

        with self.fail_point(
            {
                "mode": {"times": 1},
                "data": {"failCommands": ["find"], "errorCode": 391},
            }
        ):
            # Perform a bulk read operation.
            cursor = client.test.test.find_raw_batches({})
            list(cursor)

        # Assert that the refresh callback has been called.
        self.assertEqual(self.refresh_called, 1)
        client.close()

    def test_reauthenticate_succeeds_cursor(self):
        """A cursor whose initial ``find`` fails with error 391 succeeds after a refresh."""
        request_cb = self.create_request_cb()
        refresh_cb = self.create_refresh_cb()

        # Create a client with the callbacks.
        props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
        client = MongoClient(self.uri_single, authmechanismproperties=props)

        # Perform an insert operation.
        client.test.test.insert_one({"a": 1})

        # Assert that the refresh callback has not been called.
        self.assertEqual(self.refresh_called, 0)

        with self.fail_point(
            {
                "mode": {"times": 1},
                "data": {"failCommands": ["find"], "errorCode": 391},
            }
        ):
            # Perform a find operation.
            cursor = client.test.test.find({"a": 1})
            self.assertGreaterEqual(len(list(cursor)), 1)

        # Assert that the refresh callback has been called.
        self.assertEqual(self.refresh_called, 1)
        client.close()

    def test_reauthenticate_succeeds_get_more(self):
        """A ``getMore`` failing with error 391 succeeds after a refresh."""
        request_cb = self.create_request_cb()
        refresh_cb = self.create_refresh_cb()

        # Create a client with the callbacks.
        props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
        client = MongoClient(self.uri_single, authmechanismproperties=props)

        # Perform an insert operation.
        client.test.test.insert_many([{"a": 1}, {"a": 1}])

        # Assert that the refresh callback has not been called.
        self.assertEqual(self.refresh_called, 0)

        with self.fail_point(
            {
                "mode": {"times": 1},
                "data": {"failCommands": ["getMore"], "errorCode": 391},
            }
        ):
            # batch_size=1 forces at least one getMore for the two documents.
            cursor = client.test.test.find({"a": 1}, batch_size=1)
            self.assertGreaterEqual(len(list(cursor)), 1)

        # Assert that the refresh callback has been called.
        self.assertEqual(self.refresh_called, 1)
        client.close()

    def test_reauthenticate_succeeds_get_more_exhaust(self):
        """An exhaust-cursor ``getMore`` failing with error 391 succeeds after a refresh."""
        # This test must not run against a mongos; a mongos reports
        # msg == "isdbgrid" in its legacy hello reply.
        aws_props = {"PROVIDER_NAME": "aws"}
        probe = MongoClient(self.uri_single, authmechanismproperties=aws_props)
        try:
            hello = probe.admin.command(HelloCompat.LEGACY_CMD)
        finally:
            probe.close()
        if hello.get("msg") == "isdbgrid":
            raise unittest.SkipTest("Must not be a mongos")

        request_cb = self.create_request_cb()
        refresh_cb = self.create_refresh_cb()

        # Create a client with the callbacks.
        props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
        client = MongoClient(self.uri_single, authmechanismproperties=props)

        # Perform an insert operation.
        client.test.test.insert_many([{"a": 1}, {"a": 1}])

        # Assert that the refresh callback has not been called.
        self.assertEqual(self.refresh_called, 0)

        with self.fail_point(
            {
                "mode": {"times": 1},
                "data": {"failCommands": ["getMore"], "errorCode": 391},
            }
        ):
            # Perform a find operation.
            cursor = client.test.test.find({"a": 1}, batch_size=1, cursor_type=CursorType.EXHAUST)
            self.assertGreaterEqual(len(list(cursor)), 1)

        # Assert that the refresh callback has been called.
        self.assertEqual(self.refresh_called, 1)
        client.close()

    def test_reauthenticate_succeeds_command(self):
        """A ``count`` command failing with error 391 succeeds after a refresh."""
        request_cb = self.create_request_cb()
        refresh_cb = self.create_refresh_cb()

        # Create a client with the callbacks.
        props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
        client = MongoClient(self.uri_single, authmechanismproperties=props)

        # Perform an insert operation.
        client.test.test.insert_one({"a": 1})

        # Assert that the refresh callback has not been called.
        self.assertEqual(self.refresh_called, 0)

        with self.fail_point(
            {
                "mode": {"times": 1},
                "data": {"failCommands": ["count"], "errorCode": 391},
            }
        ):
            # Perform a count operation; the reply is a document, not a cursor.
            result = client.test.command({"count": "test"})
            # At least the document inserted above must be counted.
            self.assertGreaterEqual(result["n"], 1)

        # Assert that the refresh callback has been called.
        self.assertEqual(self.refresh_called, 1)
        client.close()

    def test_reauthenticate_retries_and_succeeds_with_cache(self):
        """A ``find`` succeeds when both ``find`` and ``saslStart`` fail once with 391."""
        listener = EventListener()

        # Create request and refresh callbacks that return valid credentials
        # that will not expire soon.
        request_cb = self.create_request_cb()
        refresh_cb = self.create_refresh_cb()

        # Create a client with the callbacks.
        props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
        client = MongoClient(
            self.uri_single, event_listeners=[listener], authmechanismproperties=props
        )

        # Perform a find operation.
        client.test.test.find_one()

        # Fail the next two ``find``/``saslStart`` commands with error code 391
        # (ReauthenticationRequired).
        with self.fail_point(
            {
                "mode": {"times": 2},
                "data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
            }
        ):
            # Perform a find operation that succeeds.
            client.test.test.find_one()

        # Close the client.
        client.close()

    def test_reauthenticate_fails_with_no_cache(self):
        """With the cache cleared, the same 391 failures surface as OperationFailure."""
        listener = EventListener()

        # Create request and refresh callbacks that return valid credentials
        # that will not expire soon.
        request_cb = self.create_request_cb()
        refresh_cb = self.create_refresh_cb()

        # Create a client with the callbacks.
        props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
        client = MongoClient(
            self.uri_single, event_listeners=[listener], authmechanismproperties=props
        )

        # Perform a find operation.
        client.test.test.find_one()

        # Clear the cache.
        _oidc_cache.clear()

        with self.fail_point(
            {
                "mode": {"times": 2},
                "data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
            }
        ):
            # Perform a find operation that fails.
            with self.assertRaises(OperationFailure):
                client.test.test.find_one()

        client.close()

    def test_late_reauth_avoids_callback(self):
        """A second client reauthenticating late reuses the already-refreshed token."""
        # Step 1: connect with both clients
        request_cb = self.create_request_cb(expires_in_seconds=1e6)
        refresh_cb = self.create_refresh_cb(expires_in_seconds=1e6)

        props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb}
        client1 = MongoClient(self.uri_single, authMechanismProperties=props)
        client1.test.test.find_one()
        client2 = MongoClient(self.uri_single, authMechanismProperties=props)
        client2.test.test.find_one()

        self.assertEqual(self.refresh_called, 0)
        self.assertEqual(self.request_called, 1)

        # Step 2: cause a find 391 on the first client
        with self.fail_point(
            {
                "mode": {"times": 1},
                "data": {"failCommands": ["find"], "errorCode": 391},
            }
        ):
            # Perform a find operation that succeeds.
            client1.test.test.find_one()

        self.assertEqual(self.refresh_called, 1)
        self.assertEqual(self.request_called, 1)

        # Step 3: cause a find 391 on the second client
        with self.fail_point(
            {
                "mode": {"times": 1},
                "data": {"failCommands": ["find"], "errorCode": 391},
            }
        ):
            # Perform a find operation that succeeds.
            client2.test.test.find_one()

        self.assertEqual(self.refresh_called, 1)
        self.assertEqual(self.request_called, 1)

        client1.close()
        client2.close()
if __name__ == "__main__":
    # Executing this file directly runs the test cases defined above.
    unittest.main()