[v4.13] PYTHON-5406 - AsyncPeriodicExecutor must reset CSOT contextvars before executing (#2373)

Co-authored-by: Noah Stapp <noah.stapp@mongodb.com>
This commit is contained in:
Steven Silvester 2025-06-10 12:26:43 -05:00 committed by GitHub
parent 14417adc3f
commit 09a32f6d40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 53 additions and 0 deletions

View File

@ -32,6 +32,12 @@ RTT: ContextVar[float] = ContextVar("RTT", default=0.0)
DEADLINE: ContextVar[float] = ContextVar("DEADLINE", default=float("inf"))
def reset_all() -> None:
TIMEOUT.set(None)
RTT.set(0.0)
DEADLINE.set(float("inf"))
def get_timeout() -> Optional[float]:
return TIMEOUT.get(None)

View File

@ -23,6 +23,7 @@ import time
import weakref
from typing import Any, Optional
from pymongo import _csot
from pymongo._asyncio_task import create_task
from pymongo.lock import _create_lock
@ -93,6 +94,8 @@ class AsyncPeriodicExecutor:
self._skip_sleep = True
async def _run(self) -> None:
# The CSOT contextvars must be cleared inside the executor task before execution begins
_csot.reset_all()
while not self._stopped:
if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined]
raise asyncio.CancelledError

View File

@ -0,0 +1,43 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test that AsyncPeriodicExecutors do not copy ContextVars from their parents."""
from __future__ import annotations
import asyncio
import sys
from test.asynchronous.utils import async_get_pool
from test.utils_shared import delay, one
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest
class TestAsyncContextVarsReset(AsyncIntegrationTest):
async def test_context_vars_are_reset_in_executor(self):
if sys.version_info < (3, 11):
self.skipTest("Test requires asyncio.Task.get_context (added in Python 3.11)")
client = self.simple_client()
await client.db.test.insert_one({"x": 1})
for server in client._topology._servers.values():
for context in [
c
for c in server._monitor._executor._task.get_context()
if c.name in ["TIMEOUT", "RTT", "DEADLINE"]
]:
self.assertIn(context.get(), [None, float("inf"), 0.0])
await client.db.test.delete_many({})

View File

@ -185,6 +185,7 @@ def async_only_test(f: str) -> bool:
"test_concurrency.py",
"test_async_cancellation.py",
"test_async_loop_safety.py",
"test_async_contextvars_reset.py",
]