Refine thread worker heartbeat handling

This commit is contained in:
Marcelo Trylesinski 2026-04-12 17:11:40 +02:00
parent 7e901a0d2b
commit 0f274d2a4e
5 changed files with 40 additions and 17 deletions

View File

@ -11,7 +11,7 @@ from pytest_mock import MockerFixture
from uvicorn import Config
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
from uvicorn.server import Server
from uvicorn.supervisors.multithread import Multithread, Thread, ThreadServer
from uvicorn.supervisors.multithread import Multithread, Thread
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
@ -34,6 +34,7 @@ class FakeThread:
self.joined = False
self.healthy = True
self.join_result = True
self.ready_for_healthcheck = True
def is_alive(self) -> bool:
return self.alive
@ -41,6 +42,9 @@ class FakeThread:
def is_healthy(self, timeout: float) -> bool:
return self.healthy
def is_ready_for_healthcheck(self) -> bool:
return self.ready_for_healthcheck
def start(self) -> None:
self.started = True
@ -80,9 +84,12 @@ def test_thread_terminate_sets_server_exit_flag() -> None:
target = thread._get_target()
assert thread.server is not None
assert isinstance(thread.server, ThreadServer)
assert isinstance(thread.server, Server)
assert target.__self__ is thread.server
assert thread.server.should_exit is False
assert thread.config.callback_progress is not None
assert thread.config.callback_progress.__self__ is thread
assert thread.config.callback_progress.__func__ is thread.record_heartbeat.__func__
thread.terminate()
@ -101,9 +108,10 @@ def test_thread_record_heartbeat_and_is_healthy() -> None:
@pytest.mark.anyio
async def test_thread_server_records_heartbeat_on_tick() -> None:
async def test_server_progress_callback_records_heartbeat_on_tick() -> None:
thread = Thread(Config(app=app), target=lambda sockets: None, sockets=[])
server = ThreadServer(config=Config(app=app), worker_thread=thread)
config = Config(app=app, callback_progress=thread.record_heartbeat)
server = Server(config=config)
before = thread.last_heartbeat
thread.last_heartbeat -= 5
@ -193,6 +201,20 @@ def test_multithread_keep_subthread_alive_replaces_unhealthy_thread(mocker: Mock
assert supervisor.threads[0].started is True
def test_multithread_keep_subthread_alive_skips_healthcheck_until_ready(mocker: MockerFixture) -> None:
mocker.patch("uvicorn.supervisors.multithread.Thread", FakeThread)
supervisor = Multithread(Config(app=app, workers=1), target=lambda sockets: None, sockets=[])
supervisor.init_threads()
thread = supervisor.threads[0]
thread.healthy = False
thread.ready_for_healthcheck = False
supervisor.keep_subthread_alive()
assert supervisor.threads[0] is thread
def test_multithread_keep_subthread_alive_replaces_unhealthy_thread_without_blocking_join(
mocker: MockerFixture,
) -> None:

View File

@ -119,7 +119,7 @@ def test_run_match_config_params() -> None:
config_params = {
key: repr(value)
for key, value in inspect.signature(Config.__init__).parameters.items()
if key not in ("self", "timeout_notify", "callback_notify")
if key not in ("self", "timeout_notify", "callback_notify", "callback_progress")
}
run_params = {
key: repr(value) for key, value in inspect.signature(run).parameters.items() if key not in ("app_dir",)

View File

@ -237,6 +237,7 @@ class Config:
timeout_graceful_shutdown: int | None = None,
timeout_worker_healthcheck: int = 5,
callback_notify: Callable[..., Awaitable[None]] | None = None,
callback_progress: Callable[[], None] | None = None,
ssl_keyfile: str | os.PathLike[str] | None = None,
ssl_certfile: str | os.PathLike[str] | None = None,
ssl_keyfile_password: str | None = None,
@ -284,6 +285,7 @@ class Config:
self.timeout_graceful_shutdown = timeout_graceful_shutdown
self.timeout_worker_healthcheck = timeout_worker_healthcheck
self.callback_notify = callback_notify
self.callback_progress = callback_progress
self.ssl_keyfile = ssl_keyfile
self.ssl_certfile = ssl_certfile
self.ssl_keyfile_password = ssl_keyfile_password

View File

@ -241,6 +241,9 @@ class Server:
async def on_tick(self, counter: int) -> bool:
# Update the default headers, once per second.
if counter % 10 == 0:
if self.config.callback_progress is not None:
self.config.callback_progress()
current_time = time.time()
current_date = formatdate(current_time, usegmt=True).encode()

View File

@ -20,17 +20,6 @@ from uvicorn.supervisors.multiprocess import SIGNALS
logger = logging.getLogger("uvicorn.error")
class ThreadServer(Server):
def __init__(self, config: Config, worker_thread: Thread) -> None:
super().__init__(config)
self.worker_thread = worker_thread
async def on_tick(self, counter: int) -> bool:
if counter % 10 == 0:
self.worker_thread.record_heartbeat()
return await super().on_tick(counter)
class Thread:
def __init__(
self,
@ -47,7 +36,8 @@ class Thread:
def _get_target(self) -> Callable[[list[socket] | None], None]:
if inspect.ismethod(self.real_target) and isinstance(self.real_target.__self__, Server):
self.server = ThreadServer(config=self.config, worker_thread=self)
self.config.callback_progress = self.record_heartbeat
self.server = Server(config=self.config)
return self.server.run
return self.real_target
@ -80,6 +70,9 @@ class Thread:
def is_healthy(self, timeout: float) -> bool:
return time.monotonic() - self.last_heartbeat <= timeout
def is_ready_for_healthcheck(self) -> bool:
return self.server is None or self.server.started
class Multithread:
def __init__(
@ -151,6 +144,9 @@ class Multithread:
self._replace_thread(idx, thread, reason="Child thread died")
continue
if not thread.is_ready_for_healthcheck():
continue
if not thread.is_healthy(timeout=self.config.timeout_worker_healthcheck):
self._replace_thread(idx, thread, reason="Worker thread failed healthcheck")