PYTHON-4864 - Create async version of SpecRunnerThread (#2094)

This commit is contained in:
Noah Stapp 2025-02-04 07:42:13 -05:00 committed by GitHub
parent 1fda6a2310
commit b47143cd10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 104 additions and 31 deletions

View File

@ -15,6 +15,7 @@
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
from __future__ import annotations
import asyncio
import base64
import gc
import multiprocessing
@ -30,6 +31,8 @@ import unittest
import warnings
from asyncio import iscoroutinefunction
from pymongo._asyncio_task import create_task
try:
import ipaddress
@ -369,3 +372,37 @@ class SystemCertsPatcher:
os.environ.pop("SSL_CERT_FILE")
else:
os.environ["SSL_CERT_FILE"] = self.original_certs
if _IS_SYNC:
PARENT = threading.Thread
else:
PARENT = object
class ConcurrentRunner(PARENT):
def __init__(self, name, *args, **kwargs):
if _IS_SYNC:
super().__init__(*args, **kwargs)
self.name = name
self.stopped = False
self.task = None
if "target" in kwargs:
self.target = kwargs["target"]
if not _IS_SYNC:
async def start(self):
self.task = create_task(self.run(), name=self.name)
async def join(self, timeout: float | None = 0): # type: ignore[override]
if self.task is not None:
await asyncio.wait([self.task], timeout=timeout)
def is_alive(self):
return not self.stopped
async def run(self):
if self.target:
await self.target()
self.stopped = True

View File

@ -35,6 +35,7 @@ from test.asynchronous import (
client_knobs,
unittest,
)
from test.asynchronous.utils_spec_runner import SpecRunnerTask
from test.unified_format_shared import (
KMS_TLS_OPTS,
PLACEHOLDER_MAP,
@ -58,7 +59,6 @@ from test.utils import (
snake_to_camel,
wait_until,
)
from test.utils_spec_runner import SpecRunnerThread
from test.version import Version
from typing import Any, Dict, List, Mapping, Optional
@ -382,8 +382,8 @@ class EntityMapUtil:
return
elif entity_type == "thread":
name = spec["id"]
thread = SpecRunnerThread(name)
thread.start()
thread = SpecRunnerTask(name)
await thread.start()
self[name] = thread
return
@ -1177,16 +1177,16 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
wait_until(primary_changed, "change primary", timeout=timeout)
def _testOperation_runOnThread(self, spec):
async def _testOperation_runOnThread(self, spec):
"""Run the 'runOnThread' operation."""
thread = self.entity_map[spec["thread"]]
thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
await thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))
def _testOperation_waitForThread(self, spec):
async def _testOperation_waitForThread(self, spec):
"""Run the 'waitForThread' operation."""
thread = self.entity_map[spec["thread"]]
thread.stop()
thread.join(10)
await thread.stop()
await thread.join(10)
if thread.exc:
raise thread.exc
self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"]))

View File

@ -18,11 +18,11 @@ from __future__ import annotations
import asyncio
import functools
import os
import threading
import unittest
from asyncio import iscoroutinefunction
from collections import abc
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
from test.asynchronous.helpers import ConcurrentRunner
from test.utils import (
CMAPListener,
CompareType,
@ -47,6 +47,7 @@ from pymongo.asynchronous import client_session
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult, _WriteResult
@ -55,38 +56,36 @@ from pymongo.write_concern import WriteConcern
_IS_SYNC = False
class SpecRunnerThread(threading.Thread):
class SpecRunnerTask(ConcurrentRunner):
def __init__(self, name):
super().__init__()
self.name = name
super().__init__(name)
self.exc = None
self.daemon = True
self.cond = threading.Condition()
self.cond = _async_create_condition(_async_create_lock())
self.ops = []
self.stopped = False
def schedule(self, work):
async def schedule(self, work):
self.ops.append(work)
with self.cond:
async with self.cond:
self.cond.notify()
def stop(self):
async def stop(self):
self.stopped = True
with self.cond:
async with self.cond:
self.cond.notify()
def run(self):
async def run(self):
while not self.stopped or self.ops:
if not self.ops:
with self.cond:
self.cond.wait(10)
async with self.cond:
await _async_cond_wait(self.cond, 10)
if self.ops:
try:
work = self.ops.pop(0)
work()
await work()
except Exception as exc:
self.exc = exc
self.stop()
await self.stop()
class AsyncSpecTestCreator:

View File

@ -15,6 +15,7 @@
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
from __future__ import annotations
import asyncio
import base64
import gc
import multiprocessing
@ -30,6 +31,8 @@ import unittest
import warnings
from asyncio import iscoroutinefunction
from pymongo._asyncio_task import create_task
try:
import ipaddress
@ -369,3 +372,37 @@ class SystemCertsPatcher:
os.environ.pop("SSL_CERT_FILE")
else:
os.environ["SSL_CERT_FILE"] = self.original_certs
if _IS_SYNC:
PARENT = threading.Thread
else:
PARENT = object
class ConcurrentRunner(PARENT):
def __init__(self, name, *args, **kwargs):
if _IS_SYNC:
super().__init__(*args, **kwargs)
self.name = name
self.stopped = False
self.task = None
if "target" in kwargs:
self.target = kwargs["target"]
if not _IS_SYNC:
def start(self):
self.task = create_task(self.run(), name=self.name)
def join(self, timeout: float | None = 0): # type: ignore[override]
if self.task is not None:
asyncio.wait([self.task], timeout=timeout)
def is_alive(self):
return not self.stopped
def run(self):
if self.target:
self.target()
self.stopped = True

View File

@ -1167,7 +1167,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
def _testOperation_runOnThread(self, spec):
"""Run the 'runOnThread' operation."""
thread = self.entity_map[spec["thread"]]
thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))
def _testOperation_waitForThread(self, spec):
"""Run the 'waitForThread' operation."""

View File

@ -18,11 +18,11 @@ from __future__ import annotations
import asyncio
import functools
import os
import threading
import unittest
from asyncio import iscoroutinefunction
from collections import abc
from test import IntegrationTest, client_context, client_knobs
from test.helpers import ConcurrentRunner
from test.utils import (
CMAPListener,
CompareType,
@ -44,6 +44,7 @@ from bson.son import SON
from gridfs import GridFSBucket
from gridfs.synchronous.grid_file import GridFSBucket
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
from pymongo.lock import _cond_wait, _create_condition, _create_lock
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult, _WriteResult
@ -55,15 +56,13 @@ from pymongo.write_concern import WriteConcern
_IS_SYNC = True
class SpecRunnerThread(threading.Thread):
class SpecRunnerThread(ConcurrentRunner):
def __init__(self, name):
super().__init__()
self.name = name
super().__init__(name)
self.exc = None
self.daemon = True
self.cond = threading.Condition()
self.cond = _create_condition(_create_lock())
self.ops = []
self.stopped = False
def schedule(self, work):
self.ops.append(work)
@ -79,7 +78,7 @@ class SpecRunnerThread(threading.Thread):
while not self.stopped or self.ops:
if not self.ops:
with self.cond:
self.cond.wait(10)
_cond_wait(self.cond, 10)
if self.ops:
try:
work = self.ops.pop(0)

View File

@ -119,6 +119,7 @@ replacements = {
"_async_create_lock": "_create_lock",
"_async_create_condition": "_create_condition",
"_async_cond_wait": "_cond_wait",
"SpecRunnerTask": "SpecRunnerThread",
"AsyncMockConnection": "MockConnection",
"AsyncMockPool": "MockPool",
}