PYTHON-4864 - Create async version of SpecRunnerThread (#2094)
This commit is contained in:
parent
1fda6a2310
commit
b47143cd10
@ -15,6 +15,7 @@
|
||||
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import gc
|
||||
import multiprocessing
|
||||
@ -30,6 +31,8 @@ import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
from pymongo._asyncio_task import create_task
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
@ -369,3 +372,37 @@ class SystemCertsPatcher:
|
||||
os.environ.pop("SSL_CERT_FILE")
|
||||
else:
|
||||
os.environ["SSL_CERT_FILE"] = self.original_certs
|
||||
|
||||
|
||||
if _IS_SYNC:
|
||||
PARENT = threading.Thread
|
||||
else:
|
||||
PARENT = object
|
||||
|
||||
|
||||
class ConcurrentRunner(PARENT):
|
||||
def __init__(self, name, *args, **kwargs):
|
||||
if _IS_SYNC:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.name = name
|
||||
self.stopped = False
|
||||
self.task = None
|
||||
if "target" in kwargs:
|
||||
self.target = kwargs["target"]
|
||||
|
||||
if not _IS_SYNC:
|
||||
|
||||
async def start(self):
|
||||
self.task = create_task(self.run(), name=self.name)
|
||||
|
||||
async def join(self, timeout: float | None = 0): # type: ignore[override]
|
||||
if self.task is not None:
|
||||
await asyncio.wait([self.task], timeout=timeout)
|
||||
|
||||
def is_alive(self):
|
||||
return not self.stopped
|
||||
|
||||
async def run(self):
|
||||
if self.target:
|
||||
await self.target()
|
||||
self.stopped = True
|
||||
|
||||
@ -35,6 +35,7 @@ from test.asynchronous import (
|
||||
client_knobs,
|
||||
unittest,
|
||||
)
|
||||
from test.asynchronous.utils_spec_runner import SpecRunnerTask
|
||||
from test.unified_format_shared import (
|
||||
KMS_TLS_OPTS,
|
||||
PLACEHOLDER_MAP,
|
||||
@ -58,7 +59,6 @@ from test.utils import (
|
||||
snake_to_camel,
|
||||
wait_until,
|
||||
)
|
||||
from test.utils_spec_runner import SpecRunnerThread
|
||||
from test.version import Version
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
@ -382,8 +382,8 @@ class EntityMapUtil:
|
||||
return
|
||||
elif entity_type == "thread":
|
||||
name = spec["id"]
|
||||
thread = SpecRunnerThread(name)
|
||||
thread.start()
|
||||
thread = SpecRunnerTask(name)
|
||||
await thread.start()
|
||||
self[name] = thread
|
||||
return
|
||||
|
||||
@ -1177,16 +1177,16 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
|
||||
wait_until(primary_changed, "change primary", timeout=timeout)
|
||||
|
||||
def _testOperation_runOnThread(self, spec):
|
||||
async def _testOperation_runOnThread(self, spec):
|
||||
"""Run the 'runOnThread' operation."""
|
||||
thread = self.entity_map[spec["thread"]]
|
||||
thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
|
||||
await thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))
|
||||
|
||||
def _testOperation_waitForThread(self, spec):
|
||||
async def _testOperation_waitForThread(self, spec):
|
||||
"""Run the 'waitForThread' operation."""
|
||||
thread = self.entity_map[spec["thread"]]
|
||||
thread.stop()
|
||||
thread.join(10)
|
||||
await thread.stop()
|
||||
await thread.join(10)
|
||||
if thread.exc:
|
||||
raise thread.exc
|
||||
self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"]))
|
||||
|
||||
@ -18,11 +18,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
import threading
|
||||
import unittest
|
||||
from asyncio import iscoroutinefunction
|
||||
from collections import abc
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
|
||||
from test.asynchronous.helpers import ConcurrentRunner
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
CompareType,
|
||||
@ -47,6 +47,7 @@ from pymongo.asynchronous import client_session
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.cursor import AsyncCursor
|
||||
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
|
||||
from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.results import BulkWriteResult, _WriteResult
|
||||
@ -55,38 +56,36 @@ from pymongo.write_concern import WriteConcern
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class SpecRunnerThread(threading.Thread):
|
||||
class SpecRunnerTask(ConcurrentRunner):
|
||||
def __init__(self, name):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
super().__init__(name)
|
||||
self.exc = None
|
||||
self.daemon = True
|
||||
self.cond = threading.Condition()
|
||||
self.cond = _async_create_condition(_async_create_lock())
|
||||
self.ops = []
|
||||
self.stopped = False
|
||||
|
||||
def schedule(self, work):
|
||||
async def schedule(self, work):
|
||||
self.ops.append(work)
|
||||
with self.cond:
|
||||
async with self.cond:
|
||||
self.cond.notify()
|
||||
|
||||
def stop(self):
|
||||
async def stop(self):
|
||||
self.stopped = True
|
||||
with self.cond:
|
||||
async with self.cond:
|
||||
self.cond.notify()
|
||||
|
||||
def run(self):
|
||||
async def run(self):
|
||||
while not self.stopped or self.ops:
|
||||
if not self.ops:
|
||||
with self.cond:
|
||||
self.cond.wait(10)
|
||||
async with self.cond:
|
||||
await _async_cond_wait(self.cond, 10)
|
||||
if self.ops:
|
||||
try:
|
||||
work = self.ops.pop(0)
|
||||
work()
|
||||
await work()
|
||||
except Exception as exc:
|
||||
self.exc = exc
|
||||
self.stop()
|
||||
await self.stop()
|
||||
|
||||
|
||||
class AsyncSpecTestCreator:
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import gc
|
||||
import multiprocessing
|
||||
@ -30,6 +31,8 @@ import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
from pymongo._asyncio_task import create_task
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
@ -369,3 +372,37 @@ class SystemCertsPatcher:
|
||||
os.environ.pop("SSL_CERT_FILE")
|
||||
else:
|
||||
os.environ["SSL_CERT_FILE"] = self.original_certs
|
||||
|
||||
|
||||
if _IS_SYNC:
|
||||
PARENT = threading.Thread
|
||||
else:
|
||||
PARENT = object
|
||||
|
||||
|
||||
class ConcurrentRunner(PARENT):
|
||||
def __init__(self, name, *args, **kwargs):
|
||||
if _IS_SYNC:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.name = name
|
||||
self.stopped = False
|
||||
self.task = None
|
||||
if "target" in kwargs:
|
||||
self.target = kwargs["target"]
|
||||
|
||||
if not _IS_SYNC:
|
||||
|
||||
def start(self):
|
||||
self.task = create_task(self.run(), name=self.name)
|
||||
|
||||
def join(self, timeout: float | None = 0): # type: ignore[override]
|
||||
if self.task is not None:
|
||||
asyncio.wait([self.task], timeout=timeout)
|
||||
|
||||
def is_alive(self):
|
||||
return not self.stopped
|
||||
|
||||
def run(self):
|
||||
if self.target:
|
||||
self.target()
|
||||
self.stopped = True
|
||||
|
||||
@ -1167,7 +1167,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
def _testOperation_runOnThread(self, spec):
|
||||
"""Run the 'runOnThread' operation."""
|
||||
thread = self.entity_map[spec["thread"]]
|
||||
thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
|
||||
thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))
|
||||
|
||||
def _testOperation_waitForThread(self, spec):
|
||||
"""Run the 'waitForThread' operation."""
|
||||
|
||||
@ -18,11 +18,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
import threading
|
||||
import unittest
|
||||
from asyncio import iscoroutinefunction
|
||||
from collections import abc
|
||||
from test import IntegrationTest, client_context, client_knobs
|
||||
from test.helpers import ConcurrentRunner
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
CompareType,
|
||||
@ -44,6 +44,7 @@ from bson.son import SON
|
||||
from gridfs import GridFSBucket
|
||||
from gridfs.synchronous.grid_file import GridFSBucket
|
||||
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
|
||||
from pymongo.lock import _cond_wait, _create_condition, _create_lock
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.results import BulkWriteResult, _WriteResult
|
||||
@ -55,15 +56,13 @@ from pymongo.write_concern import WriteConcern
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class SpecRunnerThread(threading.Thread):
|
||||
class SpecRunnerThread(ConcurrentRunner):
|
||||
def __init__(self, name):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
super().__init__(name)
|
||||
self.exc = None
|
||||
self.daemon = True
|
||||
self.cond = threading.Condition()
|
||||
self.cond = _create_condition(_create_lock())
|
||||
self.ops = []
|
||||
self.stopped = False
|
||||
|
||||
def schedule(self, work):
|
||||
self.ops.append(work)
|
||||
@ -79,7 +78,7 @@ class SpecRunnerThread(threading.Thread):
|
||||
while not self.stopped or self.ops:
|
||||
if not self.ops:
|
||||
with self.cond:
|
||||
self.cond.wait(10)
|
||||
_cond_wait(self.cond, 10)
|
||||
if self.ops:
|
||||
try:
|
||||
work = self.ops.pop(0)
|
||||
|
||||
@ -119,6 +119,7 @@ replacements = {
|
||||
"_async_create_lock": "_create_lock",
|
||||
"_async_create_condition": "_create_condition",
|
||||
"_async_cond_wait": "_cond_wait",
|
||||
"SpecRunnerTask": "SpecRunnerThread",
|
||||
"AsyncMockConnection": "MockConnection",
|
||||
"AsyncMockPool": "MockPool",
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user