Refine thread worker heartbeat handling
This commit is contained in:
parent
7e901a0d2b
commit
0f274d2a4e
@ -11,7 +11,7 @@ from pytest_mock import MockerFixture
|
||||
from uvicorn import Config
|
||||
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
|
||||
from uvicorn.server import Server
|
||||
from uvicorn.supervisors.multithread import Multithread, Thread, ThreadServer
|
||||
from uvicorn.supervisors.multithread import Multithread, Thread
|
||||
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
@ -34,6 +34,7 @@ class FakeThread:
|
||||
self.joined = False
|
||||
self.healthy = True
|
||||
self.join_result = True
|
||||
self.ready_for_healthcheck = True
|
||||
|
||||
def is_alive(self) -> bool:
|
||||
return self.alive
|
||||
@ -41,6 +42,9 @@ class FakeThread:
|
||||
def is_healthy(self, timeout: float) -> bool:
|
||||
return self.healthy
|
||||
|
||||
def is_ready_for_healthcheck(self) -> bool:
|
||||
return self.ready_for_healthcheck
|
||||
|
||||
def start(self) -> None:
|
||||
self.started = True
|
||||
|
||||
@ -80,9 +84,12 @@ def test_thread_terminate_sets_server_exit_flag() -> None:
|
||||
target = thread._get_target()
|
||||
|
||||
assert thread.server is not None
|
||||
assert isinstance(thread.server, ThreadServer)
|
||||
assert isinstance(thread.server, Server)
|
||||
assert target.__self__ is thread.server
|
||||
assert thread.server.should_exit is False
|
||||
assert thread.config.callback_progress is not None
|
||||
assert thread.config.callback_progress.__self__ is thread
|
||||
assert thread.config.callback_progress.__func__ is thread.record_heartbeat.__func__
|
||||
|
||||
thread.terminate()
|
||||
|
||||
@ -101,9 +108,10 @@ def test_thread_record_heartbeat_and_is_healthy() -> None:
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_thread_server_records_heartbeat_on_tick() -> None:
|
||||
async def test_server_progress_callback_records_heartbeat_on_tick() -> None:
|
||||
thread = Thread(Config(app=app), target=lambda sockets: None, sockets=[])
|
||||
server = ThreadServer(config=Config(app=app), worker_thread=thread)
|
||||
config = Config(app=app, callback_progress=thread.record_heartbeat)
|
||||
server = Server(config=config)
|
||||
before = thread.last_heartbeat
|
||||
|
||||
thread.last_heartbeat -= 5
|
||||
@ -193,6 +201,20 @@ def test_multithread_keep_subthread_alive_replaces_unhealthy_thread(mocker: Mock
|
||||
assert supervisor.threads[0].started is True
|
||||
|
||||
|
||||
def test_multithread_keep_subthread_alive_skips_healthcheck_until_ready(mocker: MockerFixture) -> None:
|
||||
mocker.patch("uvicorn.supervisors.multithread.Thread", FakeThread)
|
||||
supervisor = Multithread(Config(app=app, workers=1), target=lambda sockets: None, sockets=[])
|
||||
supervisor.init_threads()
|
||||
|
||||
thread = supervisor.threads[0]
|
||||
thread.healthy = False
|
||||
thread.ready_for_healthcheck = False
|
||||
|
||||
supervisor.keep_subthread_alive()
|
||||
|
||||
assert supervisor.threads[0] is thread
|
||||
|
||||
|
||||
def test_multithread_keep_subthread_alive_replaces_unhealthy_thread_without_blocking_join(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
|
||||
@ -119,7 +119,7 @@ def test_run_match_config_params() -> None:
|
||||
config_params = {
|
||||
key: repr(value)
|
||||
for key, value in inspect.signature(Config.__init__).parameters.items()
|
||||
if key not in ("self", "timeout_notify", "callback_notify")
|
||||
if key not in ("self", "timeout_notify", "callback_notify", "callback_progress")
|
||||
}
|
||||
run_params = {
|
||||
key: repr(value) for key, value in inspect.signature(run).parameters.items() if key not in ("app_dir",)
|
||||
|
||||
@ -237,6 +237,7 @@ class Config:
|
||||
timeout_graceful_shutdown: int | None = None,
|
||||
timeout_worker_healthcheck: int = 5,
|
||||
callback_notify: Callable[..., Awaitable[None]] | None = None,
|
||||
callback_progress: Callable[[], None] | None = None,
|
||||
ssl_keyfile: str | os.PathLike[str] | None = None,
|
||||
ssl_certfile: str | os.PathLike[str] | None = None,
|
||||
ssl_keyfile_password: str | None = None,
|
||||
@ -284,6 +285,7 @@ class Config:
|
||||
self.timeout_graceful_shutdown = timeout_graceful_shutdown
|
||||
self.timeout_worker_healthcheck = timeout_worker_healthcheck
|
||||
self.callback_notify = callback_notify
|
||||
self.callback_progress = callback_progress
|
||||
self.ssl_keyfile = ssl_keyfile
|
||||
self.ssl_certfile = ssl_certfile
|
||||
self.ssl_keyfile_password = ssl_keyfile_password
|
||||
|
||||
@ -241,6 +241,9 @@ class Server:
|
||||
async def on_tick(self, counter: int) -> bool:
|
||||
# Update the default headers, once per second.
|
||||
if counter % 10 == 0:
|
||||
if self.config.callback_progress is not None:
|
||||
self.config.callback_progress()
|
||||
|
||||
current_time = time.time()
|
||||
current_date = formatdate(current_time, usegmt=True).encode()
|
||||
|
||||
|
||||
@ -20,17 +20,6 @@ from uvicorn.supervisors.multiprocess import SIGNALS
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
|
||||
class ThreadServer(Server):
|
||||
def __init__(self, config: Config, worker_thread: Thread) -> None:
|
||||
super().__init__(config)
|
||||
self.worker_thread = worker_thread
|
||||
|
||||
async def on_tick(self, counter: int) -> bool:
|
||||
if counter % 10 == 0:
|
||||
self.worker_thread.record_heartbeat()
|
||||
return await super().on_tick(counter)
|
||||
|
||||
|
||||
class Thread:
|
||||
def __init__(
|
||||
self,
|
||||
@ -47,7 +36,8 @@ class Thread:
|
||||
|
||||
def _get_target(self) -> Callable[[list[socket] | None], None]:
|
||||
if inspect.ismethod(self.real_target) and isinstance(self.real_target.__self__, Server):
|
||||
self.server = ThreadServer(config=self.config, worker_thread=self)
|
||||
self.config.callback_progress = self.record_heartbeat
|
||||
self.server = Server(config=self.config)
|
||||
return self.server.run
|
||||
return self.real_target
|
||||
|
||||
@ -80,6 +70,9 @@ class Thread:
|
||||
def is_healthy(self, timeout: float) -> bool:
|
||||
return time.monotonic() - self.last_heartbeat <= timeout
|
||||
|
||||
def is_ready_for_healthcheck(self) -> bool:
|
||||
return self.server is None or self.server.started
|
||||
|
||||
|
||||
class Multithread:
|
||||
def __init__(
|
||||
@ -151,6 +144,9 @@ class Multithread:
|
||||
self._replace_thread(idx, thread, reason="Child thread died")
|
||||
continue
|
||||
|
||||
if not thread.is_ready_for_healthcheck():
|
||||
continue
|
||||
|
||||
if not thread.is_healthy(timeout=self.config.timeout_worker_healthcheck):
|
||||
self._replace_thread(idx, thread, reason="Worker thread failed healthcheck")
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user