Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
bf34fa0feb
4
.github/workflows/codeql.yml
vendored
4
.github/workflows/codeql.yml
vendored
@ -46,7 +46,7 @@ jobs:
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3
|
||||
uses: github/codeql-action/init@fca7ace96b7d713c7035871441bd52efbe39e27e # v3
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
build-mode: ${{ matrix.build-mode }}
|
||||
@ -63,6 +63,6 @@ jobs:
|
||||
pip install -e .
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3
|
||||
uses: github/codeql-action/analyze@fca7ace96b7d713c7035871441bd52efbe39e27e # v3
|
||||
with:
|
||||
category: "/language:${{matrix.language}}"
|
||||
|
||||
2
.github/workflows/create-release-branch.yml
vendored
2
.github/workflows/create-release-branch.yml
vendored
@ -43,6 +43,8 @@ jobs:
|
||||
aws_region_name: ${{ vars.AWS_REGION_NAME }}
|
||||
aws_secret_id: ${{ secrets.AWS_SECRET_ID }}
|
||||
artifactory_username: ${{ vars.ARTIFACTORY_USERNAME }}
|
||||
- name: Get hatch
|
||||
run: pip install hatch
|
||||
- uses: mongodb-labs/drivers-github-tools/create-branch@v2
|
||||
id: create-branch
|
||||
with:
|
||||
|
||||
2
.github/workflows/zizmor.yml
vendored
2
.github/workflows/zizmor.yml
vendored
@ -26,7 +26,7 @@ jobs:
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Upload SARIF file
|
||||
uses: github/codeql-action/upload-sarif@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3
|
||||
uses: github/codeql-action/upload-sarif@fca7ace96b7d713c7035871441bd52efbe39e27e # v3
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
category: zizmor
|
||||
|
||||
@ -32,6 +32,12 @@ RTT: ContextVar[float] = ContextVar("RTT", default=0.0)
|
||||
DEADLINE: ContextVar[float] = ContextVar("DEADLINE", default=float("inf"))
|
||||
|
||||
|
||||
def reset_all() -> None:
|
||||
TIMEOUT.set(None)
|
||||
RTT.set(0.0)
|
||||
DEADLINE.set(float("inf"))
|
||||
|
||||
|
||||
def get_timeout() -> Optional[float]:
|
||||
return TIMEOUT.get(None)
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ import time
|
||||
import weakref
|
||||
from typing import Any, Optional
|
||||
|
||||
from pymongo import _csot
|
||||
from pymongo._asyncio_task import create_task
|
||||
from pymongo.lock import _create_lock
|
||||
|
||||
@ -93,6 +94,8 @@ class AsyncPeriodicExecutor:
|
||||
self._skip_sleep = True
|
||||
|
||||
async def _run(self) -> None:
|
||||
# The CSOT contextvars must be cleared inside the executor task before execution begins
|
||||
_csot.reset_all()
|
||||
while not self._stopped:
|
||||
if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined]
|
||||
raise asyncio.CancelledError
|
||||
|
||||
43
test/asynchronous/test_async_contextvars_reset.py
Normal file
43
test/asynchronous/test_async_contextvars_reset.py
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright 2025-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test that AsyncPeriodicExecutors do not copy ContextVars from their parents."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from test.asynchronous.utils import async_get_pool
|
||||
from test.utils_shared import delay, one
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest
|
||||
|
||||
|
||||
class TestAsyncContextVarsReset(AsyncIntegrationTest):
|
||||
async def test_context_vars_are_reset_in_executor(self):
|
||||
if sys.version_info < (3, 11):
|
||||
self.skipTest("Test requires asyncio.Task.get_context (added in Python 3.11)")
|
||||
|
||||
client = self.simple_client()
|
||||
|
||||
await client.db.test.insert_one({"x": 1})
|
||||
for server in client._topology._servers.values():
|
||||
for context in [
|
||||
c
|
||||
for c in server._monitor._executor._task.get_context()
|
||||
if c.name in ["TIMEOUT", "RTT", "DEADLINE"]
|
||||
]:
|
||||
self.assertIn(context.get(), [None, float("inf"), 0.0])
|
||||
await client.db.test.delete_many({})
|
||||
@ -1085,6 +1085,25 @@ class TestAuthOIDCMachine(OIDCTestBase):
|
||||
# Assert there were `SaslStart` commands executed.
|
||||
assert any(event.command_name.lower() == "saslstart" for event in listener.started_events)
|
||||
|
||||
async def test_4_5_reauthentication_succeeds_when_a_session_is_involved(self):
|
||||
# Create an OIDC configured client.
|
||||
client = await self.create_client()
|
||||
|
||||
# Set a fail point for `find` commands of the form:
|
||||
async with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["find"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Start a new session.
|
||||
async with client.start_session() as session:
|
||||
# In the started session perform a `find` operation that succeeds.
|
||||
await client.test.test.find_one({}, session=session)
|
||||
|
||||
# Assert that the callback was called 2 times (once during the connection handshake, and again during reauthentication).
|
||||
self.assertEqual(self.request_called, 2)
|
||||
|
||||
async def test_5_1_azure_with_no_username(self):
|
||||
if ENVIRON != "azure":
|
||||
raise unittest.SkipTest("Test is only supported on Azure")
|
||||
|
||||
@ -1083,6 +1083,25 @@ class TestAuthOIDCMachine(OIDCTestBase):
|
||||
# Assert there were `SaslStart` commands executed.
|
||||
assert any(event.command_name.lower() == "saslstart" for event in listener.started_events)
|
||||
|
||||
def test_4_5_reauthentication_succeeds_when_a_session_is_involved(self):
|
||||
# Create an OIDC configured client.
|
||||
client = self.create_client()
|
||||
|
||||
# Set a fail point for `find` commands of the form:
|
||||
with self.fail_point(
|
||||
{
|
||||
"mode": {"times": 1},
|
||||
"data": {"failCommands": ["find"], "errorCode": 391},
|
||||
}
|
||||
):
|
||||
# Start a new session.
|
||||
with client.start_session() as session:
|
||||
# In the started session perform a `find` operation that succeeds.
|
||||
client.test.test.find_one({}, session=session)
|
||||
|
||||
# Assert that the callback was called 2 times (once during the connection handshake, and again during reauthentication).
|
||||
self.assertEqual(self.request_called, 2)
|
||||
|
||||
def test_5_1_azure_with_no_username(self):
|
||||
if ENVIRON != "azure":
|
||||
raise unittest.SkipTest("Test is only supported on Azure")
|
||||
|
||||
@ -185,6 +185,7 @@ def async_only_test(f: str) -> bool:
|
||||
"test_concurrency.py",
|
||||
"test_async_cancellation.py",
|
||||
"test_async_loop_safety.py",
|
||||
"test_async_contextvars_reset.py",
|
||||
]
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user