diff --git a/test/asynchronous/test_periodic_executor.py b/test/asynchronous/test_periodic_executor.py new file mode 100644 index 000000000..15f75b0f4 --- /dev/null +++ b/test/asynchronous/test_periodic_executor.py @@ -0,0 +1,290 @@ +# Copyright 2026-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for periodic_executor.py.""" + +from __future__ import annotations + +import asyncio +import gc +import sys +import threading +import time + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncUnitTest, unittest + +import pymongo.periodic_executor as pe_module +from pymongo.periodic_executor import ( + AsyncPeriodicExecutor, + _register_executor, + _shutdown_executors, +) + +_IS_SYNC = False + + +def _make_executor(interval=30.0, min_interval=0.01, target=None, name="test"): + if target is None: + + async def target(): + return True + + return AsyncPeriodicExecutor( + interval=interval, min_interval=min_interval, target=target, name=name + ) + + +class _AsyncPeriodicExecutorTestBase(AsyncUnitTest): + async def asyncSetUp(self): + self.ex = _make_executor() + + async def asyncTearDown(self): + self.ex.close() + await self.ex.join(timeout=2) + + +class TestAsyncPeriodicExecutorRepr(AsyncUnitTest): + async def test_repr_contains_class_and_name(self): + ex = _make_executor(name="exec") + r = repr(ex) + self.assertIn("AsyncPeriodicExecutor", r) + self.assertIn("exec", r) + + +class TestAsyncPeriodicExecutorBasic(_AsyncPeriodicExecutorTestBase): + async def test_wake_sets_event(self): + self.assertFalse(self.ex._event) + self.ex.wake() + self.assertTrue(self.ex._event) + + async def test_update_interval(self): + self.ex.update_interval(60) + self.assertEqual(self.ex._interval, 60) + + async def test_skip_sleep(self): + self.assertFalse(self.ex._skip_sleep) + self.ex.skip_sleep() + self.assertTrue(self.ex._skip_sleep) + + +class TestAsyncPeriodicExecutorLifecycle(_AsyncPeriodicExecutorTestBase): + async def test_open_starts_worker(self): + self.ex.open() + if _IS_SYNC: + self.assertIsNotNone(self.ex._thread) + self.assertTrue(self.ex._thread.is_alive()) + else: + self.assertIsNotNone(self.ex._task) + + async def test_close_sets_stopped(self): + self.ex.open() + self.ex.close() + self.assertTrue(self.ex._stopped) + await self.ex.join(timeout=1) + + async def test_join_without_open_is_safe(self): + await self.ex.join(timeout=0.01) + + async def test_multiple_open_calls_have_no_effect(self): + self.ex.open() + if _IS_SYNC: + worker_id = id(self.ex._thread) + else: + worker_id = id(self.ex._task) + self.ex.open() + if _IS_SYNC: + self.assertEqual(worker_id, id(self.ex._thread)) + else: + self.assertEqual(worker_id, id(self.ex._task)) + + +class TestAsyncPeriodicExecutorTarget(_AsyncPeriodicExecutorTestBase): + async def test_target_returning_false_stops_executor(self): + if _IS_SYNC: + ran = threading.Event() + else: + ran = asyncio.Event() + + async def target(): + ran.set() + return False + + self.ex = _make_executor(target=target) + self.ex.open() + if _IS_SYNC: + self.assertTrue(ran.wait(timeout=2), "target never ran") + else: + await asyncio.wait_for(ran.wait(), timeout=2) + await self.ex.join(timeout=2) + self.assertTrue(self.ex._stopped) + + async def test_target_exception_stops_executor(self): + if _IS_SYNC: + ran = threading.Event() + captured_exc: list = [] + orig_excepthook = threading.excepthook + + def _capture_excepthook(args): + captured_exc.append(args.exc_value) + + threading.excepthook = _capture_excepthook + try: + + def target(): + ran.set() + raise RuntimeError("boom") + + self.ex = _make_executor(target=target) + self.ex.open() + self.assertTrue(ran.wait(timeout=2), "target never ran") + self.ex.join(timeout=2) + finally: + threading.excepthook = orig_excepthook + self.assertTrue(self.ex._stopped) + self.assertEqual(len(captured_exc), 1) + self.assertIsInstance(captured_exc[0], RuntimeError) + else: + ran = asyncio.Event() + + async def target(): + ran.set() + raise RuntimeError("async boom") + + self.ex = _make_executor(target=target) + self.ex.open() + await asyncio.wait_for(ran.wait(), timeout=2) + await self.ex.join(timeout=2) + self.assertTrue(self.ex._stopped) + if self.ex._task is not None and self.ex._task.done(): + self.ex._task.exception() + + async def test_skip_sleep_flag_skips_interval(self): + call_times = [] + + async def target(): + call_times.append(time.monotonic() if _IS_SYNC else asyncio.get_running_loop().time()) + if len(call_times) >= 2: + return False + return True + + self.ex = _make_executor(interval=30.0, min_interval=0.001, target=target) + self.ex.skip_sleep() + self.ex.open() + await self.ex.join(timeout=3) + self.assertGreaterEqual(len(call_times), 2) + self.assertLess(call_times[1] - call_times[0], 5.0) + + async def test_wake_causes_early_run(self): + call_count = [0] + if _IS_SYNC: + woken = threading.Event() + else: + woken = asyncio.Event() + + async def target(): + call_count[0] += 1 + if call_count[0] == 1: + woken.set() + if call_count[0] >= 2: + return False + return True + + self.ex = _make_executor(interval=30.0, min_interval=0.01, target=target) + self.ex.open() + if _IS_SYNC: + woken.wait(timeout=2) + else: + await asyncio.wait_for(woken.wait(), timeout=2) + self.ex.wake() + await self.ex.join(timeout=3) + self.assertGreaterEqual(call_count[0], 2) + + async def test_open_after_target_returns_false(self): + called = [0] + + async def target(): + called[0] += 1 + return False + + self.ex = _make_executor(target=target) + self.ex.open() + await self.ex.join(timeout=2) + self.assertTrue(self.ex._stopped) + if not _IS_SYNC: + first_task = self.ex._task + self.ex.open() + await self.ex.join(timeout=2) + self.assertGreaterEqual(called[0], 2) + if not _IS_SYNC: + self.assertIsNot(self.ex._task, first_task) + + +class TestShouldStop(AsyncUnitTest): + if _IS_SYNC: + + def test_returns_false_when_not_stopped(self): + ex = _make_executor() + self.assertFalse(ex._should_stop()) + self.assertFalse(ex._thread_will_exit) + + def test_returns_true_and_sets_thread_will_exit(self): + ex = _make_executor() + ex._stopped = True + self.assertTrue(ex._should_stop()) + self.assertTrue(ex._thread_will_exit) + + +class TestRegisterExecutor(AsyncUnitTest): + if _IS_SYNC: + + def setUp(self): + self._orig = set(pe_module._EXECUTORS) + + def tearDown(self): + pe_module._EXECUTORS.clear() + pe_module._EXECUTORS.update(self._orig) + + def test_register_adds_weakref(self): + ex = _make_executor() + before = len(pe_module._EXECUTORS) + _register_executor(ex) + self.assertEqual(len(pe_module._EXECUTORS), before + 1) + ref = next(r for r in pe_module._EXECUTORS if r() is ex) + del ex + gc.collect() + self.assertNotIn(ref, pe_module._EXECUTORS) + + def test_shutdown_executors_stops_running_executors(self): + ran = threading.Event() + + def target(): + ran.set() + return True + + ex = _make_executor(target=target) + ex.open() + self.assertTrue(ran.wait(timeout=2), "target never ran") + _shutdown_executors() + ex.join(timeout=2) + self.assertTrue(ex._stopped) + + def test_shutdown_executors_safe_when_empty(self): + pe_module._EXECUTORS.clear() + _shutdown_executors() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_periodic_executor.py b/test/test_periodic_executor.py index 636c49207..f07ded1fd 100644 --- a/test/test_periodic_executor.py +++ b/test/test_periodic_executor.py @@ -24,22 +24,19 @@ import time sys.path[0:0] = [""] -from test import unittest +from test import UnitTest, unittest import pymongo.periodic_executor as pe_module from pymongo.periodic_executor import ( - AsyncPeriodicExecutor, PeriodicExecutor, _register_executor, _shutdown_executors, ) -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- +_IS_SYNC = True -def _make_sync(interval=30.0, min_interval=0.01, target=None, name="test"): +def _make_executor(interval=30.0, min_interval=0.01, target=None, name="test"): if target is None: def target(): @@ -48,147 +45,152 @@ def _make_sync(interval=30.0, min_interval=0.01, target=None, name="test"): return PeriodicExecutor(interval=interval, min_interval=min_interval, target=target, name=name) -def _make_async(interval=30.0, min_interval=0.01, target=None, name="test"): - async def _default_target(): - return True +class _PeriodicExecutorTestBase(UnitTest): + def setUp(self): + self.ex = _make_executor() - if target is None: - target = _default_target - return AsyncPeriodicExecutor( - interval=interval, min_interval=min_interval, target=target, name=name - ) + def tearDown(self): + self.ex.close() + self.ex.join(timeout=2) -def _run(coroutine): - return asyncio.run(coroutine) - - -# --------------------------------------------------------------------------- -# PeriodicExecutor (sync / threading) -# --------------------------------------------------------------------------- - - -class TestPeriodicExecutorRepr(unittest.TestCase): +class TestPeriodicExecutorRepr(UnitTest): def test_repr_contains_class_and_name(self): - ex = _make_sync(name="myexec") + ex = _make_executor(name="exec") r = repr(ex) self.assertIn("PeriodicExecutor", r) - self.assertIn("myexec", r) + self.assertIn("exec", r) -class TestPeriodicExecutorLifecycle(unittest.TestCase): - def test_open_starts_thread(self): - ex = _make_sync() - ex.open() - try: - self.assertIsNotNone(ex._thread) - # Give thread a moment to start. - time.sleep(0.05) - self.assertTrue(ex._thread.is_alive()) - finally: - ex.close() - ex.join(timeout=2) - - def test_multiple_open_calls_have_no_effect(self): - ex = _make_sync() - ex.open() - thread_id = id(ex._thread) - ex.open() - try: - self.assertEqual(thread_id, id(ex._thread)) - finally: - ex.close() - ex.join(timeout=2) - - def test_close_sets_stopped(self): - ex = _make_sync() - ex.open() - ex.close() - self.assertTrue(ex._stopped) - ex.join(timeout=2) - - def test_join_without_open_is_safe(self): - ex = _make_sync() - ex.join(timeout=0.01) # should not raise - +class TestPeriodicExecutorBasic(_PeriodicExecutorTestBase): def test_wake_sets_event(self): - ex = _make_sync() - self.assertFalse(ex._event) - ex.wake() - self.assertTrue(ex._event) + self.assertFalse(self.ex._event) + self.ex.wake() + self.assertTrue(self.ex._event) def test_update_interval(self): - ex = _make_sync(interval=30.0) - ex.update_interval(60) - self.assertEqual(ex._interval, 60) + self.ex.update_interval(60) + self.assertEqual(self.ex._interval, 60) def test_skip_sleep(self): - ex = _make_sync() - self.assertFalse(ex._skip_sleep) - ex.skip_sleep() - self.assertTrue(ex._skip_sleep) + self.assertFalse(self.ex._skip_sleep) + self.ex.skip_sleep() + self.assertTrue(self.ex._skip_sleep) -class TestPeriodicExecutorTarget(unittest.TestCase): +class TestPeriodicExecutorLifecycle(_PeriodicExecutorTestBase): + def test_open_starts_worker(self): + self.ex.open() + if _IS_SYNC: + self.assertIsNotNone(self.ex._thread) + self.assertTrue(self.ex._thread.is_alive()) + else: + self.assertIsNotNone(self.ex._task) + + def test_close_sets_stopped(self): + self.ex.open() + self.ex.close() + self.assertTrue(self.ex._stopped) + self.ex.join(timeout=1) + + def test_join_without_open_is_safe(self): + self.ex.join(timeout=0.01) + + def test_multiple_open_calls_have_no_effect(self): + self.ex.open() + if _IS_SYNC: + worker_id = id(self.ex._thread) + else: + worker_id = id(self.ex._task) + self.ex.open() + if _IS_SYNC: + self.assertEqual(worker_id, id(self.ex._thread)) + else: + self.assertEqual(worker_id, id(self.ex._task)) + + +class TestPeriodicExecutorTarget(_PeriodicExecutorTestBase): def test_target_returning_false_stops_executor(self): - ran = threading.Event() + if _IS_SYNC: + ran = threading.Event() + else: + ran = asyncio.Event() def target(): ran.set() - return False # Signal stop. + return False - ex = _make_sync(target=target) - ex.open() - self.assertTrue(ran.wait(timeout=2), "target never ran") - ex.join(timeout=2) - self.assertTrue(ex._stopped) + self.ex = _make_executor(target=target) + self.ex.open() + if _IS_SYNC: + self.assertTrue(ran.wait(timeout=2), "target never ran") + else: + asyncio.wait_for(ran.wait(), timeout=2) + self.ex.join(timeout=2) + self.assertTrue(self.ex._stopped) def test_target_exception_stops_executor(self): - ran = threading.Event() - captured_exc = [] - orig_excepthook = threading.excepthook + if _IS_SYNC: + ran = threading.Event() + captured_exc: list = [] + orig_excepthook = threading.excepthook - def _capture_excepthook(args): - captured_exc.append(args.exc_value) + def _capture_excepthook(args): + captured_exc.append(args.exc_value) - threading.excepthook = _capture_excepthook - try: + threading.excepthook = _capture_excepthook + try: + + def target(): + ran.set() + raise RuntimeError("boom") + + self.ex = _make_executor(target=target) + self.ex.open() + self.assertTrue(ran.wait(timeout=2), "target never ran") + self.ex.join(timeout=2) + finally: + threading.excepthook = orig_excepthook + self.assertTrue(self.ex._stopped) + self.assertEqual(len(captured_exc), 1) + self.assertIsInstance(captured_exc[0], RuntimeError) + else: + ran = asyncio.Event() def target(): ran.set() - raise RuntimeError("boom") + raise RuntimeError("async boom") - ex = _make_sync(target=target) - ex.open() - self.assertTrue(ran.wait(timeout=2), "target never ran") - ex.join(timeout=2) - finally: - threading.excepthook = orig_excepthook - - self.assertTrue(ex._stopped) - self.assertEqual(len(captured_exc), 1) - self.assertIsInstance(captured_exc[0], RuntimeError) + self.ex = _make_executor(target=target) + self.ex.open() + asyncio.wait_for(ran.wait(), timeout=2) + self.ex.join(timeout=2) + self.assertTrue(self.ex._stopped) + if self.ex._task is not None and self.ex._task.done(): + self.ex._task.exception() def test_skip_sleep_flag_skips_interval(self): call_times = [] def target(): - call_times.append(time.monotonic()) + call_times.append(time.monotonic() if _IS_SYNC else asyncio.get_running_loop().time()) if len(call_times) >= 2: return False return True - ex = _make_sync(interval=30.0, min_interval=0.001, target=target) - ex.skip_sleep() - ex.open() - ex.join(timeout=2) - # First call should have skipped the 30s sleep. + self.ex = _make_executor(interval=30.0, min_interval=0.001, target=target) + self.ex.skip_sleep() + self.ex.open() + self.ex.join(timeout=3) self.assertGreaterEqual(len(call_times), 2) self.assertLess(call_times[1] - call_times[0], 5.0) def test_wake_causes_early_run(self): call_count = [0] - woken = threading.Event() + if _IS_SYNC: + woken = threading.Event() + else: + woken = asyncio.Event() def target(): call_count[0] += 1 @@ -198,227 +200,88 @@ class TestPeriodicExecutorTarget(unittest.TestCase): return False return True - ex = _make_sync(interval=30.0, min_interval=0.01, target=target) - ex.open() - woken.wait(timeout=2) - ex.wake() - ex.join(timeout=3) + self.ex = _make_executor(interval=30.0, min_interval=0.01, target=target) + self.ex.open() + if _IS_SYNC: + woken.wait(timeout=2) + else: + asyncio.wait_for(woken.wait(), timeout=2) + self.ex.wake() + self.ex.join(timeout=3) self.assertGreaterEqual(call_count[0], 2) - -class TestShouldStop(unittest.TestCase): - def test_returns_false_when_not_stopped(self): - ex = _make_sync() - self.assertFalse(ex._should_stop()) - self.assertFalse(ex._thread_will_exit) - - def test_returns_true_and_sets_thread_will_exit(self): - ex = _make_sync() - ex._stopped = True - self.assertTrue(ex._should_stop()) - self.assertTrue(ex._thread_will_exit) - - -class TestPeriodicExecutorOpenAfterExit(unittest.TestCase): - def test_reopen_after_target_returns_false(self): + def test_open_after_target_returns_false(self): called = [0] def target(): called[0] += 1 return False - ex = _make_sync(target=target) - ex.open() - ex.join(timeout=2) - self.assertTrue(ex._stopped) - # Re-open should start a new thread. - ex.open() - ex.join(timeout=2) + self.ex = _make_executor(target=target) + self.ex.open() + self.ex.join(timeout=2) + self.assertTrue(self.ex._stopped) + if not _IS_SYNC: + first_task = self.ex._task + self.ex.open() + self.ex.join(timeout=2) self.assertGreaterEqual(called[0], 2) + if not _IS_SYNC: + self.assertIsNot(self.ex._task, first_task) -# --------------------------------------------------------------------------- -# Module-level: _register_executor, _on_executor_deleted, _shutdown_executors -# --------------------------------------------------------------------------- +class TestShouldStop(UnitTest): + if _IS_SYNC: + + def test_returns_false_when_not_stopped(self): + ex = _make_executor() + self.assertFalse(ex._should_stop()) + self.assertFalse(ex._thread_will_exit) + + def test_returns_true_and_sets_thread_will_exit(self): + ex = _make_executor() + ex._stopped = True + self.assertTrue(ex._should_stop()) + self.assertTrue(ex._thread_will_exit) -class TestRegisterExecutor(unittest.TestCase): - def setUp(self): - self._orig = set(pe_module._EXECUTORS) +class TestRegisterExecutor(UnitTest): + if _IS_SYNC: - def tearDown(self): - pe_module._EXECUTORS.clear() - pe_module._EXECUTORS.update(self._orig) + def setUp(self): + self._orig = set(pe_module._EXECUTORS) - def test_register_adds_weakref(self): - ex = _make_sync() - before = len(pe_module._EXECUTORS) - _register_executor(ex) - self.assertEqual(len(pe_module._EXECUTORS), before + 1) - # Find the specific weakref we just registered. - ref = next(r for r in pe_module._EXECUTORS if r() is ex) - del ex - gc.collect() - # The weakref callback must have removed our specific ref. - self.assertNotIn(ref, pe_module._EXECUTORS) + def tearDown(self): + pe_module._EXECUTORS.clear() + pe_module._EXECUTORS.update(self._orig) - def test_shutdown_executors_stops_running_executors(self): - ex = _make_sync(interval=30.0) - ex.open() - time.sleep(0.05) - _shutdown_executors() - ex.join(timeout=2) - self.assertTrue(ex._stopped) + def test_register_adds_weakref(self): + ex = _make_executor() + before = len(pe_module._EXECUTORS) + _register_executor(ex) + self.assertEqual(len(pe_module._EXECUTORS), before + 1) + ref = next(r for r in pe_module._EXECUTORS if r() is ex) + del ex + gc.collect() + self.assertNotIn(ref, pe_module._EXECUTORS) - def test_shutdown_executors_safe_when_empty(self): - pe_module._EXECUTORS.clear() - _shutdown_executors() # Should not raise. + def test_shutdown_executors_stops_running_executors(self): + ran = threading.Event() - -# --------------------------------------------------------------------------- -# AsyncPeriodicExecutor -# --------------------------------------------------------------------------- - - -class TestAsyncPeriodicExecutorRepr(unittest.TestCase): - def test_repr_contains_class_and_name(self): - ex = _make_async(name="asyncexec") - r = repr(ex) - self.assertIn("AsyncPeriodicExecutor", r) - self.assertIn("asyncexec", r) - - -class TestAsyncPeriodicExecutorBasic(unittest.TestCase): - def test_wake_sets_event(self): - ex = _make_async() - ex.wake() - self.assertTrue(ex._event) - - def test_update_interval(self): - ex = _make_async(interval=30.0) - ex.update_interval(60) - self.assertEqual(ex._interval, 60) - - def test_skip_sleep(self): - ex = _make_async() - ex.skip_sleep() - self.assertTrue(ex._skip_sleep) - - -class TestAsyncPeriodicExecutorLifecycle(unittest.TestCase): - def test_open_creates_task(self): - async def _test(): - ex = _make_async() - ex.open() - self.assertIsNotNone(ex._task) - ex.close() - await ex.join(timeout=1) - - _run(_test()) - - def test_close_cancels_task(self): - async def _test(): - ex = _make_async() - ex.open() - ex.close() - await ex.join(timeout=1) - self.assertTrue(ex._stopped) - - _run(_test()) - - def test_join_without_open_is_safe(self): - async def _test(): - ex = _make_async() - await ex.join(timeout=0.01) # Should not raise. - - _run(_test()) - - def test_multiple_open_calls_have_no_effect(self): - async def _test(): - ex = _make_async() - ex.open() - task_id = id(ex._task) - ex.open() # Second open: same task still running. - self.assertEqual(task_id, id(ex._task)) - ex.close() - await ex.join(timeout=1) - - _run(_test()) - - -class TestAsyncPeriodicExecutorTarget(unittest.TestCase): - def test_target_returning_false_stops_executor(self): - async def _test(): - ran = asyncio.Event() - - async def target(): + def target(): ran.set() - return False - - ex = _make_async(target=target) - ex.open() - await asyncio.wait_for(ran.wait(), timeout=2) - await ex.join(timeout=2) - self.assertTrue(ex._stopped) - - _run(_test()) - - def test_target_exception_stops_executor(self): - async def _test(): - ran = asyncio.Event() - - async def target(): - ran.set() - raise RuntimeError("async boom") - - ex = _make_async(target=target) - ex.open() - await asyncio.wait_for(ran.wait(), timeout=2) - await ex.join(timeout=2) - self.assertTrue(ex._stopped) - # Retrieve the task exception to avoid "Task exception was never retrieved". - if ex._task is not None and ex._task.done(): - ex._task.exception() - - _run(_test()) - - def test_skip_sleep_flag_skips_interval(self): - async def _test(): - call_times = [] - - async def target(): - call_times.append(asyncio.get_running_loop().time()) - if len(call_times) >= 2: - return False return True - ex = _make_async(interval=30.0, min_interval=0.001, target=target) - ex.skip_sleep() + ex = _make_executor(target=target) ex.open() - await ex.join(timeout=3) - self.assertGreaterEqual(len(call_times), 2) - self.assertLess(call_times[1] - call_times[0], 5.0) + self.assertTrue(ran.wait(timeout=2), "target never ran") + _shutdown_executors() + ex.join(timeout=2) + self.assertTrue(ex._stopped) - _run(_test()) - - def test_open_after_target_returns_false_creates_new_task(self): - async def _test(): - call_count = [0] - - async def target(): - call_count[0] += 1 - return False - - ex = _make_async(target=target) - ex.open() - await ex.join(timeout=2) - first_task = ex._task - ex.open() - await ex.join(timeout=2) - self.assertGreaterEqual(call_count[0], 2) - self.assertIsNot(ex._task, first_task) - - _run(_test()) + def test_shutdown_executors_safe_when_empty(self): + pe_module._EXECUTORS.clear() + _shutdown_executors() if __name__ == "__main__": diff --git a/tools/synchro.py b/tools/synchro.py index ed794c596..5b5267b85 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -279,6 +279,7 @@ converted_tests = [ "unified_format.py", "utils_selection_tests.py", "utils.py", + "test_periodic_executor.py", ]