Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2025-05-29 20:09:40 -05:00
commit beee2ec3ef
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
10 changed files with 1232 additions and 33 deletions

View File

@ -15,6 +15,7 @@
"""MONGODB-OIDC Authentication helpers."""
from __future__ import annotations
import asyncio
import threading
import time
from dataclasses import dataclass, field
@ -36,6 +37,7 @@ from pymongo.auth_oidc_shared import (
)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
from pymongo.lock import Lock, _async_create_lock
if TYPE_CHECKING:
from pymongo.asynchronous.pool import AsyncConnection
@ -81,7 +83,11 @@ class _OIDCAuthenticator:
access_token: Optional[str] = field(default=None)
idp_info: Optional[OIDCIdPInfo] = field(default=None)
token_gen_id: int = field(default=0)
lock: threading.Lock = field(default_factory=threading.Lock)
if not _IS_SYNC:
lock: Lock = field(default_factory=_async_create_lock) # type: ignore[assignment]
else:
lock: threading.Lock = field(default_factory=_async_create_lock) # type: ignore[assignment, no-redef]
last_call_time: float = field(default=0)
async def reauthenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
@ -164,7 +170,7 @@ class _OIDCAuthenticator:
# Attempt to authenticate with a JwtStepRequest.
return await self._sasl_continue_jwt(conn, start_resp)
def _get_access_token(self) -> Optional[str]:
async def _get_access_token(self) -> Optional[str]:
properties = self.properties
cb: Union[None, OIDCCallback]
resp: OIDCCallbackResult
@ -186,7 +192,7 @@ class _OIDCAuthenticator:
return None
if not prev_token and cb is not None:
with self.lock:
async with self.lock: # type: ignore[attr-defined]
# See if the token was changed while we were waiting for the
# lock.
new_token = self.access_token
@ -196,7 +202,7 @@ class _OIDCAuthenticator:
# Ensure that we are waiting a min time between callback invocations.
delta = time.time() - self.last_call_time
if delta < TIME_BETWEEN_CALLS_SECONDS:
time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
await asyncio.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
self.last_call_time = time.time()
if is_human:
@ -211,7 +217,10 @@ class _OIDCAuthenticator:
idp_info=self.idp_info,
username=self.properties.username,
)
resp = cb.fetch(context)
if not _IS_SYNC:
resp = await asyncio.get_running_loop().run_in_executor(None, cb.fetch, context) # type: ignore[assignment]
else:
resp = cb.fetch(context)
if not isinstance(resp, OIDCCallbackResult):
raise ValueError(
f"Callback result must be of type OIDCCallbackResult, not {type(resp)}"
@ -253,13 +262,13 @@ class _OIDCAuthenticator:
start_payload: dict = bson.decode(start_resp["payload"])
if "issuer" in start_payload:
self.idp_info = OIDCIdPInfo(**start_payload)
access_token = self._get_access_token()
access_token = await self._get_access_token()
conn.oidc_token_gen_id = self.token_gen_id
cmd = self._get_continue_command({"jwt": access_token}, start_resp)
return await self._run_command(conn, cmd)
async def _sasl_start_jwt(self, conn: AsyncConnection) -> Mapping[str, Any]:
access_token = self._get_access_token()
access_token = await self._get_access_token()
conn.oidc_token_gen_id = self.token_gen_id
cmd = self._get_start_command({"jwt": access_token})
return await self._run_command(conn, cmd)

View File

@ -1130,7 +1130,6 @@ class AsyncCursor(Generic[_DocumentType]):
except BaseException:
await self.close()
raise
self._address = response.address
if isinstance(response, PinnedResponse):
if not self._sock_mgr:

View File

@ -64,7 +64,7 @@ def _handle_reauth(func: F) -> F:
await conn.authenticate(reauthenticate=True)
else:
raise
return func(*args, **kwargs)
return await func(*args, **kwargs)
raise
return cast(F, inner)

View File

@ -15,6 +15,7 @@
"""MONGODB-OIDC Authentication helpers."""
from __future__ import annotations
import asyncio
import threading
import time
from dataclasses import dataclass, field
@ -36,6 +37,7 @@ from pymongo.auth_oidc_shared import (
)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
from pymongo.lock import Lock, _create_lock
if TYPE_CHECKING:
from pymongo.auth_shared import MongoCredential
@ -81,7 +83,11 @@ class _OIDCAuthenticator:
access_token: Optional[str] = field(default=None)
idp_info: Optional[OIDCIdPInfo] = field(default=None)
token_gen_id: int = field(default=0)
lock: threading.Lock = field(default_factory=threading.Lock)
if not _IS_SYNC:
lock: Lock = field(default_factory=_create_lock) # type: ignore[assignment]
else:
lock: threading.Lock = field(default_factory=_create_lock) # type: ignore[assignment, no-redef]
last_call_time: float = field(default=0)
def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
@ -186,7 +192,7 @@ class _OIDCAuthenticator:
return None
if not prev_token and cb is not None:
with self.lock:
with self.lock: # type: ignore[attr-defined]
# See if the token was changed while we were waiting for the
# lock.
new_token = self.access_token
@ -211,7 +217,10 @@ class _OIDCAuthenticator:
idp_info=self.idp_info,
username=self.properties.username,
)
resp = cb.fetch(context)
if not _IS_SYNC:
resp = asyncio.get_running_loop().run_in_executor(None, cb.fetch, context) # type: ignore[assignment]
else:
resp = cb.fetch(context)
if not isinstance(resp, OIDCCallbackResult):
raise ValueError(
f"Callback result must be of type OIDCCallbackResult, not {type(resp)}"

View File

@ -1128,7 +1128,6 @@ class Cursor(Generic[_DocumentType]):
except BaseException:
self.close()
raise
self._address = response.address
if isinstance(response, PinnedResponse):
if not self._sock_mgr:

File diff suppressed because it is too large Load Diff

View File

@ -30,7 +30,7 @@ from test import unittest
from test.asynchronous.unified_format import generate_test_classes
from pymongo import AsyncMongoClient
from pymongo.asynchronous.auth_oidc import OIDCCallback
from pymongo.auth_oidc_shared import OIDCCallback
pytestmark = pytest.mark.auth

View File

@ -17,13 +17,13 @@ from __future__ import annotations
import os
import sys
import threading
import time
import unittest
import warnings
from contextlib import contextmanager
from pathlib import Path
from test import PyMongoTestCase
from test.helpers import ConcurrentRunner
from typing import Dict
import pytest
@ -51,6 +51,8 @@ from pymongo.synchronous.auth_oidc import (
)
from pymongo.synchronous.uri_parser import parse_uri
_IS_SYNC = True
ROOT = Path(__file__).parent.parent.resolve()
TEST_PATH = ROOT / "auth" / "unified"
ENVIRON = os.environ.get("OIDC_ENV", "test")
@ -86,7 +88,7 @@ class OIDCTestBase(PyMongoTestCase):
token_file = TOKEN_FILE
else:
token_file = os.path.join(TOKEN_DIR, username)
with open(token_file) as fid:
with open(token_file) as fid: # noqa: ASYNC101,RUF100
return fid.read()
elif ENVIRON == "azure":
opts = parse_uri(self.uri_single)["options"]
@ -183,7 +185,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
client = self.create_client(username="test_user1")
# Perform a find operation that succeeds.
client.test.test.find_one()
# Close the client..
# Close the client.
client.close()
def test_1_3_multiple_principal_user_1(self):
@ -254,9 +256,11 @@ class TestAuthOIDCHuman(OIDCTestBase):
uri = "mongodb+srv://example.com?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:%5B%22example.com%22%5D"
with self.assertRaises(ConfigurationError), warnings.catch_warnings():
warnings.simplefilter("ignore")
_ = MongoClient(
uri, authmechanismproperties=dict(OIDC_HUMAN_CALLBACK=self.create_request_cb())
c = MongoClient(
uri,
authmechanismproperties=dict(OIDC_HUMAN_CALLBACK=self.create_request_cb()),
)
c._connect()
def test_1_8_machine_idp_human_callback(self):
if not os.environ.get("OIDC_IS_LOCAL"):
@ -634,7 +638,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
):
# Perform a bulk read operation.
cursor = client.test.test.find_raw_batches({})
list(cursor)
cursor.to_list()
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
@ -658,7 +662,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
):
# Perform a find operation.
cursor = client.test.test.find({"a": 1})
self.assertGreaterEqual(len(list(cursor)), 1)
self.assertGreaterEqual(len(cursor.to_list()), 1)
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
@ -682,7 +686,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
):
# Perform a find operation.
cursor = client.test.test.find({"a": 1}, batch_size=1)
self.assertGreaterEqual(len(list(cursor)), 1)
self.assertGreaterEqual(len(cursor.to_list()), 1)
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
@ -712,7 +716,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
):
# Perform a find operation.
cursor = client.test.test.find({"a": 1}, batch_size=1, cursor_type=CursorType.EXHAUST)
self.assertGreaterEqual(len(list(cursor)), 1)
self.assertGreaterEqual(len(cursor.to_list()), 1)
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
@ -737,7 +741,7 @@ class TestAuthOIDCHuman(OIDCTestBase):
# Perform a count operation.
cursor = client.test.command({"count": "test"})
self.assertGreaterEqual(len(list(cursor)), 1)
self.assertGreaterEqual(len(cursor), 1)
# Assert that the request callback has been called twice.
self.assertEqual(self.request_called, 2)
@ -790,19 +794,20 @@ class TestAuthOIDCMachine(OIDCTestBase):
# Create a ``MongoClient`` configured with a custom OIDC callback that
# implements the provider logic.
client = self.create_client()
client._connect()
# Start 10 threads and run 100 find operations in each thread that all succeed.
# Start 10 tasks and run 100 find operations that all succeed in each task.
def target():
for _ in range(100):
client.test.test.find_one()
threads = []
for _ in range(10):
thread = threading.Thread(target=target)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
tasks = []
for i in range(10):
tasks.append(ConcurrentRunner(target=target))
for t in tasks:
t.start()
for t in tasks:
t.join()
# Assert that the callback was called 1 time.
self.assertEqual(self.request_called, 1)
@ -880,6 +885,7 @@ class TestAuthOIDCMachine(OIDCTestBase):
def test_3_1_authentication_failure_with_cached_tokens_fetch_a_new_token_and_retry(self):
# Create a MongoClient and an OIDC callback that implements the provider logic.
client = self.create_client()
client._connect()
# Poison the cache with an invalid access token.
# Set a fail point for ``find`` command.
with self.fail_point(
@ -946,6 +952,7 @@ class TestAuthOIDCMachine(OIDCTestBase):
# Create a ``MongoClient`` configured with a custom OIDC callback that
# implements the provider logic.
client = self.create_client()
client._connect()
# Set a fail point for the find command.
with self.fail_point(
@ -1037,6 +1044,7 @@ class TestAuthOIDCMachine(OIDCTestBase):
# Create an OIDC configured client that can listen for `SaslStart` commands.
listener = EventListener()
client = self.create_client(event_listeners=[listener])
client._connect()
# Preload the *Client Cache* with a valid access token to enforce Speculative Authentication.
client2 = self.create_client()
@ -1101,6 +1109,7 @@ class TestAuthOIDCMachine(OIDCTestBase):
client1 = self.create_client()
client1.test.test.find_one()
client2 = self.create_client()
client2._connect()
# Prime the cache of the second client.
client2.options.pool_options._credentials.cache.data = (

View File

@ -30,7 +30,7 @@ from test import unittest
from test.unified_format import generate_test_classes
from pymongo import MongoClient
from pymongo.synchronous.auth_oidc import OIDCCallback
from pymongo.auth_oidc_shared import OIDCCallback
pytestmark = pytest.mark.auth

View File

@ -203,6 +203,7 @@ converted_tests = [
"utils_spec_runner.py",
"qcheck.py",
"test_auth.py",
"test_auth_oidc.py",
"test_auth_spec.py",
"test_bulk.py",
"test_change_stream.py",