Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
beee2ec3ef
@ -15,6 +15,7 @@
|
||||
"""MONGODB-OIDC Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
@ -36,6 +37,7 @@ from pymongo.auth_oidc_shared import (
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
|
||||
from pymongo.lock import Lock, _async_create_lock
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
@ -81,7 +83,11 @@ class _OIDCAuthenticator:
|
||||
access_token: Optional[str] = field(default=None)
|
||||
idp_info: Optional[OIDCIdPInfo] = field(default=None)
|
||||
token_gen_id: int = field(default=0)
|
||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
if not _IS_SYNC:
|
||||
lock: Lock = field(default_factory=_async_create_lock) # type: ignore[assignment]
|
||||
else:
|
||||
lock: threading.Lock = field(default_factory=_async_create_lock) # type: ignore[assignment, no-redef]
|
||||
|
||||
last_call_time: float = field(default=0)
|
||||
|
||||
async def reauthenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
|
||||
@ -164,7 +170,7 @@ class _OIDCAuthenticator:
|
||||
# Attempt to authenticate with a JwtStepRequest.
|
||||
return await self._sasl_continue_jwt(conn, start_resp)
|
||||
|
||||
def _get_access_token(self) -> Optional[str]:
|
||||
async def _get_access_token(self) -> Optional[str]:
|
||||
properties = self.properties
|
||||
cb: Union[None, OIDCCallback]
|
||||
resp: OIDCCallbackResult
|
||||
@ -186,7 +192,7 @@ class _OIDCAuthenticator:
|
||||
return None
|
||||
|
||||
if not prev_token and cb is not None:
|
||||
with self.lock:
|
||||
async with self.lock: # type: ignore[attr-defined]
|
||||
# See if the token was changed while we were waiting for the
|
||||
# lock.
|
||||
new_token = self.access_token
|
||||
@ -196,7 +202,7 @@ class _OIDCAuthenticator:
|
||||
# Ensure that we are waiting a min time between callback invocations.
|
||||
delta = time.time() - self.last_call_time
|
||||
if delta < TIME_BETWEEN_CALLS_SECONDS:
|
||||
time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
|
||||
await asyncio.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
|
||||
self.last_call_time = time.time()
|
||||
|
||||
if is_human:
|
||||
@ -211,7 +217,10 @@ class _OIDCAuthenticator:
|
||||
idp_info=self.idp_info,
|
||||
username=self.properties.username,
|
||||
)
|
||||
resp = cb.fetch(context)
|
||||
if not _IS_SYNC:
|
||||
resp = await asyncio.get_running_loop().run_in_executor(None, cb.fetch, context) # type: ignore[assignment]
|
||||
else:
|
||||
resp = cb.fetch(context)
|
||||
if not isinstance(resp, OIDCCallbackResult):
|
||||
raise ValueError(
|
||||
f"Callback result must be of type OIDCCallbackResult, not {type(resp)}"
|
||||
@ -253,13 +262,13 @@ class _OIDCAuthenticator:
|
||||
start_payload: dict = bson.decode(start_resp["payload"])
|
||||
if "issuer" in start_payload:
|
||||
self.idp_info = OIDCIdPInfo(**start_payload)
|
||||
access_token = self._get_access_token()
|
||||
access_token = await self._get_access_token()
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
cmd = self._get_continue_command({"jwt": access_token}, start_resp)
|
||||
return await self._run_command(conn, cmd)
|
||||
|
||||
async def _sasl_start_jwt(self, conn: AsyncConnection) -> Mapping[str, Any]:
|
||||
access_token = self._get_access_token()
|
||||
access_token = await self._get_access_token()
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
cmd = self._get_start_command({"jwt": access_token})
|
||||
return await self._run_command(conn, cmd)
|
||||
|
||||
@ -1130,7 +1130,6 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
except BaseException:
|
||||
await self.close()
|
||||
raise
|
||||
|
||||
self._address = response.address
|
||||
if isinstance(response, PinnedResponse):
|
||||
if not self._sock_mgr:
|
||||
|
||||
@ -64,7 +64,7 @@ def _handle_reauth(func: F) -> F:
|
||||
await conn.authenticate(reauthenticate=True)
|
||||
else:
|
||||
raise
|
||||
return func(*args, **kwargs)
|
||||
return await func(*args, **kwargs)
|
||||
raise
|
||||
|
||||
return cast(F, inner)
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""MONGODB-OIDC Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
@ -36,6 +37,7 @@ from pymongo.auth_oidc_shared import (
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
|
||||
from pymongo.lock import Lock, _create_lock
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.auth_shared import MongoCredential
|
||||
@ -81,7 +83,11 @@ class _OIDCAuthenticator:
|
||||
access_token: Optional[str] = field(default=None)
|
||||
idp_info: Optional[OIDCIdPInfo] = field(default=None)
|
||||
token_gen_id: int = field(default=0)
|
||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
if not _IS_SYNC:
|
||||
lock: Lock = field(default_factory=_create_lock) # type: ignore[assignment]
|
||||
else:
|
||||
lock: threading.Lock = field(default_factory=_create_lock) # type: ignore[assignment, no-redef]
|
||||
|
||||
last_call_time: float = field(default=0)
|
||||
|
||||
def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
|
||||
@ -186,7 +192,7 @@ class _OIDCAuthenticator:
|
||||
return None
|
||||
|
||||
if not prev_token and cb is not None:
|
||||
with self.lock:
|
||||
with self.lock: # type: ignore[attr-defined]
|
||||
# See if the token was changed while we were waiting for the
|
||||
# lock.
|
||||
new_token = self.access_token
|
||||
@ -211,7 +217,10 @@ class _OIDCAuthenticator:
|
||||
idp_info=self.idp_info,
|
||||
username=self.properties.username,
|
||||
)
|
||||
resp = cb.fetch(context)
|
||||
if not _IS_SYNC:
|
||||
resp = asyncio.get_running_loop().run_in_executor(None, cb.fetch, context) # type: ignore[assignment]
|
||||
else:
|
||||
resp = cb.fetch(context)
|
||||
if not isinstance(resp, OIDCCallbackResult):
|
||||
raise ValueError(
|
||||
f"Callback result must be of type OIDCCallbackResult, not {type(resp)}"
|
||||
|
||||
@ -1128,7 +1128,6 @@ class Cursor(Generic[_DocumentType]):
|
||||
except BaseException:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
self._address = response.address
|
||||
if isinstance(response, PinnedResponse):
|
||||
if not self._sock_mgr:
|
||||
|
||||
1173
test/asynchronous/test_auth_oidc.py
Normal file
1173
test/asynchronous/test_auth_oidc.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -30,7 +30,7 @@ from test import unittest
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
|
||||
from pymongo import AsyncMongoClient
|
||||
from pymongo.asynchronous.auth_oidc import OIDCCallback
|
||||
from pymongo.auth_oidc_shared import OIDCCallback
|
||||
|
||||
pytestmark = pytest.mark.auth
|
||||
|
||||
|
||||
@ -17,13 +17,13 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from test import PyMongoTestCase
|
||||
from test.helpers import ConcurrentRunner
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
@ -51,6 +51,8 @@ from pymongo.synchronous.auth_oidc import (
|
||||
)
|
||||
from pymongo.synchronous.uri_parser import parse_uri
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
ROOT = Path(__file__).parent.parent.resolve()
|
||||
TEST_PATH = ROOT / "auth" / "unified"
|
||||
ENVIRON = os.environ.get("OIDC_ENV", "test")
|
||||
@ -86,7 +88,7 @@ class OIDCTestBase(PyMongoTestCase):
|
||||
token_file = TOKEN_FILE
|
||||
else:
|
||||
token_file = os.path.join(TOKEN_DIR, username)
|
||||
with open(token_file) as fid:
|
||||
with open(token_file) as fid: # noqa: ASYNC101,RUF100
|
||||
return fid.read()
|
||||
elif ENVIRON == "azure":
|
||||
opts = parse_uri(self.uri_single)["options"]
|
||||
@ -183,7 +185,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
|
||||
client = self.create_client(username="test_user1")
|
||||
# Perform a find operation that succeeds.
|
||||
client.test.test.find_one()
|
||||
# Close the client..
|
||||
# Close the client.
|
||||
client.close()
|
||||
|
||||
def test_1_3_multiple_principal_user_1(self):
|
||||
@ -254,9 +256,11 @@ class TestAuthOIDCHuman(OIDCTestBase):
|
||||
uri = "mongodb+srv://example.com?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:%5B%22example.com%22%5D"
|
||||
with self.assertRaises(ConfigurationError), warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
_ = MongoClient(
|
||||
uri, authmechanismproperties=dict(OIDC_HUMAN_CALLBACK=self.create_request_cb())
|
||||
c = MongoClient(
|
||||
uri,
|
||||
authmechanismproperties=dict(OIDC_HUMAN_CALLBACK=self.create_request_cb()),
|
||||
)
|
||||
c._connect()
|
||||
|
||||
def test_1_8_machine_idp_human_callback(self):
|
||||
if not os.environ.get("OIDC_IS_LOCAL"):
|
||||
@ -634,7 +638,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
|
||||
):
|
||||
# Perform a bulk read operation.
|
||||
cursor = client.test.test.find_raw_batches({})
|
||||
list(cursor)
|
||||
cursor.to_list()
|
||||
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
@ -658,7 +662,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
|
||||
):
|
||||
# Perform a find operation.
|
||||
cursor = client.test.test.find({"a": 1})
|
||||
self.assertGreaterEqual(len(list(cursor)), 1)
|
||||
self.assertGreaterEqual(len(cursor.to_list()), 1)
|
||||
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
@ -682,7 +686,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
|
||||
):
|
||||
# Perform a find operation.
|
||||
cursor = client.test.test.find({"a": 1}, batch_size=1)
|
||||
self.assertGreaterEqual(len(list(cursor)), 1)
|
||||
self.assertGreaterEqual(len(cursor.to_list()), 1)
|
||||
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
@ -712,7 +716,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
|
||||
):
|
||||
# Perform a find operation.
|
||||
cursor = client.test.test.find({"a": 1}, batch_size=1, cursor_type=CursorType.EXHAUST)
|
||||
self.assertGreaterEqual(len(list(cursor)), 1)
|
||||
self.assertGreaterEqual(len(cursor.to_list()), 1)
|
||||
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
@ -737,7 +741,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
|
||||
# Perform a count operation.
|
||||
cursor = client.test.command({"count": "test"})
|
||||
|
||||
self.assertGreaterEqual(len(list(cursor)), 1)
|
||||
self.assertGreaterEqual(len(cursor), 1)
|
||||
|
||||
# Assert that the request callback has been called twice.
|
||||
self.assertEqual(self.request_called, 2)
|
||||
@ -790,19 +794,20 @@ class TestAuthOIDCMachine(OIDCTestBase):
|
||||
# Create a ``MongoClient`` configured with a custom OIDC callback that
|
||||
# implements the provider logic.
|
||||
client = self.create_client()
|
||||
client._connect()
|
||||
|
||||
# Start 10 threads and run 100 find operations in each thread that all succeed.
|
||||
# Start 10 tasks and run 100 find operations that all succeed in each task.
|
||||
def target():
|
||||
for _ in range(100):
|
||||
client.test.test.find_one()
|
||||
|
||||
threads = []
|
||||
for _ in range(10):
|
||||
thread = threading.Thread(target=target)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
tasks = []
|
||||
for i in range(10):
|
||||
tasks.append(ConcurrentRunner(target=target))
|
||||
for t in tasks:
|
||||
t.start()
|
||||
for t in tasks:
|
||||
t.join()
|
||||
# Assert that the callback was called 1 time.
|
||||
self.assertEqual(self.request_called, 1)
|
||||
|
||||
@ -880,6 +885,7 @@ class TestAuthOIDCMachine(OIDCTestBase):
|
||||
def test_3_1_authentication_failure_with_cached_tokens_fetch_a_new_token_and_retry(self):
|
||||
# Create a MongoClient and an OIDC callback that implements the provider logic.
|
||||
client = self.create_client()
|
||||
client._connect()
|
||||
# Poison the cache with an invalid access token.
|
||||
# Set a fail point for ``find`` command.
|
||||
with self.fail_point(
|
||||
@ -946,6 +952,7 @@ class TestAuthOIDCMachine(OIDCTestBase):
|
||||
# Create a ``MongoClient`` configured with a custom OIDC callback that
|
||||
# implements the provider logic.
|
||||
client = self.create_client()
|
||||
client._connect()
|
||||
|
||||
# Set a fail point for the find command.
|
||||
with self.fail_point(
|
||||
@ -1037,6 +1044,7 @@ class TestAuthOIDCMachine(OIDCTestBase):
|
||||
# Create an OIDC configured client that can listen for `SaslStart` commands.
|
||||
listener = EventListener()
|
||||
client = self.create_client(event_listeners=[listener])
|
||||
client._connect()
|
||||
|
||||
# Preload the *Client Cache* with a valid access token to enforce Speculative Authentication.
|
||||
client2 = self.create_client()
|
||||
@ -1101,6 +1109,7 @@ class TestAuthOIDCMachine(OIDCTestBase):
|
||||
client1 = self.create_client()
|
||||
client1.test.test.find_one()
|
||||
client2 = self.create_client()
|
||||
client2._connect()
|
||||
|
||||
# Prime the cache of the second client.
|
||||
client2.options.pool_options._credentials.cache.data = (
|
||||
@ -30,7 +30,7 @@ from test import unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
from pymongo import MongoClient
|
||||
from pymongo.synchronous.auth_oidc import OIDCCallback
|
||||
from pymongo.auth_oidc_shared import OIDCCallback
|
||||
|
||||
pytestmark = pytest.mark.auth
|
||||
|
||||
|
||||
@ -203,6 +203,7 @@ converted_tests = [
|
||||
"utils_spec_runner.py",
|
||||
"qcheck.py",
|
||||
"test_auth.py",
|
||||
"test_auth_oidc.py",
|
||||
"test_auth_spec.py",
|
||||
"test_bulk.py",
|
||||
"test_change_stream.py",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user