[v4.13] PYTHON-5406 - AsyncPeriodicExecutor must reset CSOT contextvars before executing (#2373)
Co-authored-by: Noah Stapp <noah.stapp@mongodb.com>
This commit is contained in:
parent
14417adc3f
commit
09a32f6d40
@ -32,6 +32,12 @@ RTT: ContextVar[float] = ContextVar("RTT", default=0.0)
|
||||
DEADLINE: ContextVar[float] = ContextVar("DEADLINE", default=float("inf"))
|
||||
|
||||
|
||||
def reset_all() -> None:
|
||||
TIMEOUT.set(None)
|
||||
RTT.set(0.0)
|
||||
DEADLINE.set(float("inf"))
|
||||
|
||||
|
||||
def get_timeout() -> Optional[float]:
|
||||
return TIMEOUT.get(None)
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ import time
|
||||
import weakref
|
||||
from typing import Any, Optional
|
||||
|
||||
from pymongo import _csot
|
||||
from pymongo._asyncio_task import create_task
|
||||
from pymongo.lock import _create_lock
|
||||
|
||||
@ -93,6 +94,8 @@ class AsyncPeriodicExecutor:
|
||||
self._skip_sleep = True
|
||||
|
||||
async def _run(self) -> None:
|
||||
# The CSOT contextvars must be cleared inside the executor task before execution begins
|
||||
_csot.reset_all()
|
||||
while not self._stopped:
|
||||
if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined]
|
||||
raise asyncio.CancelledError
|
||||
|
||||
43
test/asynchronous/test_async_contextvars_reset.py
Normal file
43
test/asynchronous/test_async_contextvars_reset.py
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright 2025-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test that AsyncPeriodicExecutors do not copy ContextVars from their parents."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from test.asynchronous.utils import async_get_pool
|
||||
from test.utils_shared import delay, one
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest
|
||||
|
||||
|
||||
class TestAsyncContextVarsReset(AsyncIntegrationTest):
|
||||
async def test_context_vars_are_reset_in_executor(self):
|
||||
if sys.version_info < (3, 11):
|
||||
self.skipTest("Test requires asyncio.Task.get_context (added in Python 3.11)")
|
||||
|
||||
client = self.simple_client()
|
||||
|
||||
await client.db.test.insert_one({"x": 1})
|
||||
for server in client._topology._servers.values():
|
||||
for context in [
|
||||
c
|
||||
for c in server._monitor._executor._task.get_context()
|
||||
if c.name in ["TIMEOUT", "RTT", "DEADLINE"]
|
||||
]:
|
||||
self.assertIn(context.get(), [None, float("inf"), 0.0])
|
||||
await client.db.test.delete_many({})
|
||||
@ -185,6 +185,7 @@ def async_only_test(f: str) -> bool:
|
||||
"test_concurrency.py",
|
||||
"test_async_cancellation.py",
|
||||
"test_async_loop_safety.py",
|
||||
"test_async_contextvars_reset.py",
|
||||
]
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user