Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2025-06-09 10:32:56 -05:00
commit bf34fa0feb
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
9 changed files with 96 additions and 3 deletions

View File

@ -46,7 +46,7 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3
uses: github/codeql-action/init@fca7ace96b7d713c7035871441bd52efbe39e27e # v3
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
@ -63,6 +63,6 @@ jobs:
pip install -e .
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3
uses: github/codeql-action/analyze@fca7ace96b7d713c7035871441bd52efbe39e27e # v3
with:
category: "/language:${{matrix.language}}"

View File

@ -43,6 +43,8 @@ jobs:
aws_region_name: ${{ vars.AWS_REGION_NAME }}
aws_secret_id: ${{ secrets.AWS_SECRET_ID }}
artifactory_username: ${{ vars.ARTIFACTORY_USERNAME }}
- name: Get hatch
run: pip install hatch
- uses: mongodb-labs/drivers-github-tools/create-branch@v2
id: create-branch
with:

View File

@ -26,7 +26,7 @@ jobs:
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload SARIF file
uses: github/codeql-action/upload-sarif@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3
uses: github/codeql-action/upload-sarif@fca7ace96b7d713c7035871441bd52efbe39e27e # v3
with:
sarif_file: results.sarif
category: zizmor

View File

@ -32,6 +32,12 @@ RTT: ContextVar[float] = ContextVar("RTT", default=0.0)
DEADLINE: ContextVar[float] = ContextVar("DEADLINE", default=float("inf"))
def reset_all() -> None:
TIMEOUT.set(None)
RTT.set(0.0)
DEADLINE.set(float("inf"))
def get_timeout() -> Optional[float]:
return TIMEOUT.get(None)

View File

@ -23,6 +23,7 @@ import time
import weakref
from typing import Any, Optional
from pymongo import _csot
from pymongo._asyncio_task import create_task
from pymongo.lock import _create_lock
@ -93,6 +94,8 @@ class AsyncPeriodicExecutor:
self._skip_sleep = True
async def _run(self) -> None:
# The CSOT contextvars must be cleared inside the executor task before execution begins
_csot.reset_all()
while not self._stopped:
if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined]
raise asyncio.CancelledError

View File

@ -0,0 +1,43 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test that AsyncPeriodicExecutors do not copy ContextVars from their parents."""
from __future__ import annotations
import asyncio
import sys
from test.asynchronous.utils import async_get_pool
from test.utils_shared import delay, one
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest
class TestAsyncContextVarsReset(AsyncIntegrationTest):
async def test_context_vars_are_reset_in_executor(self):
if sys.version_info < (3, 11):
self.skipTest("Test requires asyncio.Task.get_context (added in Python 3.11)")
client = self.simple_client()
await client.db.test.insert_one({"x": 1})
for server in client._topology._servers.values():
for context in [
c
for c in server._monitor._executor._task.get_context()
if c.name in ["TIMEOUT", "RTT", "DEADLINE"]
]:
self.assertIn(context.get(), [None, float("inf"), 0.0])
await client.db.test.delete_many({})

View File

@ -1085,6 +1085,25 @@ class TestAuthOIDCMachine(OIDCTestBase):
# Assert there were `SaslStart` commands executed.
assert any(event.command_name.lower() == "saslstart" for event in listener.started_events)
async def test_4_5_reauthentication_succeeds_when_a_session_is_involved(self):
# Create an OIDC configured client.
client = await self.create_client()
# Set a fail point for `find` commands of the form:
async with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Start a new session.
async with client.start_session() as session:
# In the started session perform a `find` operation that succeeds.
await client.test.test.find_one({}, session=session)
# Assert that the callback was called 2 times (once during the connection handshake, and again during reauthentication).
self.assertEqual(self.request_called, 2)
async def test_5_1_azure_with_no_username(self):
if ENVIRON != "azure":
raise unittest.SkipTest("Test is only supported on Azure")

View File

@ -1083,6 +1083,25 @@ class TestAuthOIDCMachine(OIDCTestBase):
# Assert there were `SaslStart` commands executed.
assert any(event.command_name.lower() == "saslstart" for event in listener.started_events)
def test_4_5_reauthentication_succeeds_when_a_session_is_involved(self):
# Create an OIDC configured client.
client = self.create_client()
# Set a fail point for `find` commands of the form:
with self.fail_point(
{
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 391},
}
):
# Start a new session.
with client.start_session() as session:
# In the started session perform a `find` operation that succeeds.
client.test.test.find_one({}, session=session)
# Assert that the callback was called 2 times (once during the connection handshake, and again during reauthentication).
self.assertEqual(self.request_called, 2)
def test_5_1_azure_with_no_username(self):
if ENVIRON != "azure":
raise unittest.SkipTest("Test is only supported on Azure")

View File

@ -185,6 +185,7 @@ def async_only_test(f: str) -> bool:
"test_concurrency.py",
"test_async_cancellation.py",
"test_async_loop_safety.py",
"test_async_contextvars_reset.py",
]