# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared helper methods for pymongo, bson, and gridfs test suites."""
from __future__ import annotations

import asyncio
import threading
import traceback
from functools import wraps
from typing import Optional, no_type_check

from bson import SON
from pymongo import common
from pymongo._asyncio_task import create_task
from pymongo.read_preferences import ReadPreference

_IS_SYNC = True


def repl_set_step_down(client, **kwargs):
    """Run replSetStepDown, first unfreezing a secondary with replSetFreeze.

    Args:
        client: Object whose ``admin.command`` is used to send both commands.
        **kwargs: Extra fields merged into the ``replSetStepDown`` command
            document.
    """
    cmd = SON([("replSetStepDown", 1)])
    cmd.update(kwargs)

    # Unfreeze a secondary to ensure a speedy election.
    client.admin.command("replSetFreeze", 0, read_preference=ReadPreference.SECONDARY)
    client.admin.command(cmd)


class client_knobs:
    """Override timing constants in ``pymongo.common`` and restore them later.

    The four knobs map to ``common.HEARTBEAT_FREQUENCY``,
    ``common.MIN_HEARTBEAT_INTERVAL``, ``common.KILL_CURSOR_FREQUENCY`` and
    ``common.EVENTS_QUEUE_FREQUENCY``. A knob left as ``None`` is not
    overridden.

    An instance can be driven explicitly with :meth:`enable` /
    :meth:`disable`, used as a context manager, or applied as a function
    decorator.
    """

    def __init__(
        self,
        heartbeat_frequency=None,
        min_heartbeat_interval=None,
        kill_cursor_frequency=None,
        events_queue_frequency=None,
    ):
        self.heartbeat_frequency = heartbeat_frequency
        self.min_heartbeat_interval = min_heartbeat_interval
        self.kill_cursor_frequency = kill_cursor_frequency
        self.events_queue_frequency = events_queue_frequency

        # Values captured by enable() so disable() can put them back.
        self.old_heartbeat_frequency = None
        self.old_min_heartbeat_interval = None
        self.old_kill_cursor_frequency = None
        self.old_events_queue_frequency = None

        self._enabled = False
        self._stack = None

    def enable(self):
        """Snapshot the current ``common`` values, then apply the overrides."""
        self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
        self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL
        self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY
        self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY

        if self.heartbeat_frequency is not None:
            common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency

        if self.min_heartbeat_interval is not None:
            common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval

        if self.kill_cursor_frequency is not None:
            common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency

        if self.events_queue_frequency is not None:
            common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency

        self._enabled = True
        # Store the allocation traceback to catch non-disabled client_knobs.
        self._stack = "".join(traceback.format_stack())

    def __enter__(self):
        """Enable the overrides and return this instance for ``with ... as``."""
        self.enable()
        return self

    @no_type_check
    def disable(self):
        """Restore the values captured by the last :meth:`enable` call.

        Does nothing when the knobs are not currently enabled, so a stray
        ``disable()`` cannot overwrite the ``common`` constants with the
        ``None`` placeholders set in ``__init__``.
        """
        if not self._enabled:
            return

        common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
        common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
        common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency
        common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency
        self._enabled = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the original values; exceptions are not suppressed."""
        self.disable()

    def __call__(self, func):
        """Decorate *func* so each call runs with these knobs enabled."""

        @wraps(func)
        def wrap(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return wrap

    def __del__(self):
        """Complain if the instance is collected while still enabled."""
        if self._enabled:
            msg = (
                "ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, "
                "MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, "
                "EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format(
                    common.HEARTBEAT_FREQUENCY,
                    common.MIN_HEARTBEAT_INTERVAL,
                    common.KILL_CURSOR_FREQUENCY,
                    common.EVENTS_QUEUE_FREQUENCY,
                    self._stack,
                )
            )
            # Put the original values back before reporting the leak.
            self.disable()
            raise Exception(msg)


# Global knobs to speed up the test suite.
global_knobs = client_knobs(events_queue_frequency=0.05)


# Base class for the runner below: a real thread when _IS_SYNC is true,
# a plain object otherwise.
PARENT = threading.Thread if _IS_SYNC else object


class ConcurrentRunner(PARENT):
    """Call ``target(*args)`` in ``run`` and flag completion via ``stopped``.

    Recognised keyword arguments are ``name`` (default
    ``"ConcurrentRunner"``), ``target`` and ``args`` (default ``[]``). When
    ``_IS_SYNC`` is true all keyword arguments are also forwarded to the
    parent class constructor.
    """

    def __init__(self, **kwargs):
        if _IS_SYNC:
            super().__init__(**kwargs)
        self.name = kwargs.get("name", "ConcurrentRunner")
        self.stopped = False
        self.task = None
        self.target = kwargs.get("target")
        self.args = kwargs.get("args", [])

    # These methods are only defined when _IS_SYNC is false; otherwise the
    # parent class supplies start/join/is_alive.
    if not _IS_SYNC:

        def start(self):
            """Schedule ``run`` through ``create_task`` and keep the result."""
            self.task = create_task(self.run(), name=self.name)

        def join(self, timeout: Optional[float] = None):  # type: ignore[override]
            """Wait on the stored task, if one was started."""
            if self.task is None:
                return
            # NOTE(review): asyncio.wait is called without ``await`` from a
            # plain ``def`` -- confirm this branch is only meaningful in the
            # async variant of this module.
            asyncio.wait([self.task], timeout=timeout)

        def is_alive(self):
            """Return True until ``run`` has finished."""
            return not self.stopped

    def run(self):
        """Invoke the target and always mark the runner as stopped."""
        try:
            self.target(*self.args)
        finally:
            self.stopped = True


class ExceptionCatchingTask(ConcurrentRunner):
    """A Task that stores any exception encountered while running."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.exc = None

    def run(self):
        """Run the target, recording any raised exception before re-raising."""
        try:
            super().run()
        except BaseException as err:
            self.exc = err
            raise