Support custom IOLOOPs (#2435)

Co-authored-by: Thomas Grainger <tagrain@gmail.com>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
Nir Geller 2025-07-01 20:55:34 +03:00 committed by GitHub
parent 1aa02117ba
commit 52b3ec62a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 238 additions and 52 deletions

View File

@ -101,7 +101,7 @@ filterwarnings = [
parallel = true
source_pkgs = ["uvicorn", "tests"]
plugins = ["coverage_conditional_plugin"]
omit = ["uvicorn/workers.py", "uvicorn/__main__.py"]
omit = ["uvicorn/workers.py", "uvicorn/__main__.py", "uvicorn/_compat.py"]
[tool.coverage.report]
precision = 2

View File

@ -0,0 +1,7 @@
from __future__ import annotations
import asyncio
class CustomLoop(asyncio.SelectorEventLoop):
pass

View File

@ -1,10 +1,11 @@
import asyncio
import contextlib
import importlib
import pytest
from uvicorn.config import Config
from uvicorn.loops.auto import auto_loop_setup
from uvicorn.loops.auto import auto_loop_factory
from uvicorn.protocols.http.auto import AutoHTTPProtocol
from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol
from uvicorn.server import ServerState
@ -33,10 +34,10 @@ async def app(scope, receive, send):
def test_loop_auto():
auto_loop_setup()
policy = asyncio.get_event_loop_policy()
assert isinstance(policy, asyncio.events.BaseDefaultEventLoopPolicy)
assert type(policy).__module__.startswith(expected_loop)
loop_factory = auto_loop_factory(use_subprocess=True)
with contextlib.closing(loop_factory()) as loop:
assert isinstance(loop, asyncio.AbstractEventLoop)
assert type(loop).__module__.startswith(expected_loop)
@pytest.mark.anyio

30
tests/test_compat.py Normal file
View File

@ -0,0 +1,30 @@
from __future__ import annotations
import asyncio
from asyncio import AbstractEventLoop
import pytest
from tests.custom_loop_utils import CustomLoop
from tests.utils import get_asyncio_default_loop_per_os
from uvicorn._compat import asyncio_run
async def assert_event_loop(expected_loop_class: type[AbstractEventLoop]):
assert isinstance(asyncio.get_running_loop(), expected_loop_class)
def test_asyncio_run__default_loop_factory() -> None:
asyncio_run(assert_event_loop(get_asyncio_default_loop_per_os()), loop_factory=None)
def test_asyncio_run__custom_loop_factory() -> None:
asyncio_run(assert_event_loop(CustomLoop), loop_factory=CustomLoop)
def test_asyncio_run__passing_a_non_awaitable_callback_should_throw_error() -> None:
with pytest.raises(ValueError):
asyncio_run(
lambda: None, # type: ignore
loop_factory=CustomLoop,
)

View File

@ -8,6 +8,7 @@ import os
import socket
import sys
from collections.abc import Iterator
from contextlib import closing
from pathlib import Path
from typing import IO, Any, Callable, Literal
from unittest.mock import MagicMock
@ -16,16 +17,10 @@ import pytest
import yaml
from pytest_mock import MockerFixture
from tests.utils import as_cwd
from uvicorn._types import (
ASGIApplication,
ASGIReceiveCallable,
ASGISendCallable,
Environ,
Scope,
StartResponse,
)
from uvicorn.config import Config
from tests.custom_loop_utils import CustomLoop
from tests.utils import as_cwd, get_asyncio_default_loop_per_os
from uvicorn._types import ASGIApplication, ASGIReceiveCallable, ASGISendCallable, Environ, Scope, StartResponse
from uvicorn.config import Config, LoopFactoryType
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
from uvicorn.middleware.wsgi import WSGIMiddleware
from uvicorn.protocols.http.h11_impl import H11Protocol
@ -545,3 +540,48 @@ def test_warn_when_using_reload_and_workers(caplog: pytest.LogCaptureFixture) ->
Config(app=asgi_app, reload=True, workers=2)
assert len(caplog.records) == 1
assert '"workers" flag is ignored when reloading is enabled.' in caplog.records[0].message
@pytest.mark.parametrize(
("loop_type", "expected_loop_factory"),
[
("none", None),
("asyncio", get_asyncio_default_loop_per_os()),
],
)
def test_get_loop_factory(loop_type: LoopFactoryType, expected_loop_factory: Any):
config = Config(app=asgi_app, loop=loop_type)
loop_factory = config.get_loop_factory()
if loop_factory is None:
assert expected_loop_factory is loop_factory
else:
loop = loop_factory()
with closing(loop):
assert loop is not None
assert isinstance(loop, expected_loop_factory)
def test_custom_loop__importable_custom_loop_setup_function() -> None:
config = Config(app=asgi_app, loop="tests.custom_loop_utils:CustomLoop")
config.load()
loop_factory = config.get_loop_factory()
assert loop_factory, "Loop factory should be set"
event_loop = loop_factory()
with closing(event_loop):
assert event_loop is not None
assert isinstance(event_loop, CustomLoop)
@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
def test_custom_loop__not_importable_custom_loop_setup_function(caplog: pytest.LogCaptureFixture) -> None:
config = Config(app=asgi_app, loop="tests.test_config:non_existing_setup_function")
config.load()
with pytest.raises(SystemExit):
config.get_loop_factory()
error_messages = [
record.message for record in caplog.records if record.name == "uvicorn.error" and record.levelname == "ERROR"
]
assert (
'Error loading custom loop setup function. Attribute "non_existing_setup_function" not found in module "tests.test_config".' # noqa: E501
== error_messages.pop(0)
)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import os
import signal
import sys
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager, contextmanager
from pathlib import Path
@ -44,3 +45,11 @@ def as_cwd(path: Path):
yield
finally:
os.chdir(prev_cwd)
def get_asyncio_default_loop_per_os() -> type[asyncio.AbstractEventLoop]:
"""Get the default asyncio loop per OS."""
if sys.platform == "win32":
return asyncio.ProactorEventLoop # type: ignore # pragma: nocover
else:
return asyncio.SelectorEventLoop # pragma: nocover

84
uvicorn/_compat.py Normal file
View File

@ -0,0 +1,84 @@
from __future__ import annotations
import asyncio
import sys
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar
_T = TypeVar("_T")
if sys.version_info >= (3, 12):
asyncio_run = asyncio.run
elif sys.version_info >= (3, 11):
def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
# asyncio.run from Python 3.12
# https://docs.python.org/3/license.html#psf-license
with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner:
return runner.run(main)
else:
# modified version of asyncio.run from Python 3.10 to add loop_factory kwarg
# https://docs.python.org/3/license.html#psf-license
def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError("asyncio.run() cannot be called from a running event loop")
if not asyncio.iscoroutine(main):
raise ValueError(f"a coroutine was expected, got {main!r}")
if loop_factory is None:
loop = asyncio.new_event_loop()
else:
loop = loop_factory()
try:
if loop_factory is None:
asyncio.set_event_loop(loop)
if debug is not None:
loop.set_debug(debug)
return loop.run_until_complete(main)
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(loop.shutdown_default_executor())
finally:
if loop_factory is None:
asyncio.set_event_loop(None)
loop.close()
def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return
for task in to_cancel:
task.cancel()
loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)

View File

@ -27,7 +27,7 @@ from uvicorn.middleware.wsgi import WSGIMiddleware
HTTPProtocolType = Literal["auto", "h11", "httptools"]
WSProtocolType = Literal["auto", "none", "websockets", "websockets-sansio", "wsproto"]
LifespanType = Literal["auto", "on", "off"]
LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"]
LoopFactoryType = Literal["none", "auto", "asyncio", "uvloop"]
InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"]
LOG_LEVELS: dict[str, int] = {
@ -55,11 +55,11 @@ LIFESPAN: dict[LifespanType, str] = {
"on": "uvicorn.lifespan.on:LifespanOn",
"off": "uvicorn.lifespan.off:LifespanOff",
}
LOOP_SETUPS: dict[LoopSetupType, str | None] = {
LOOP_FACTORIES: dict[LoopFactoryType | str, str | None] = {
"none": None,
"auto": "uvicorn.loops.auto:auto_loop_setup",
"asyncio": "uvicorn.loops.asyncio:asyncio_setup",
"uvloop": "uvicorn.loops.uvloop:uvloop_setup",
"auto": "uvicorn.loops.auto:auto_loop_factory",
"asyncio": "uvicorn.loops.asyncio:asyncio_loop_factory",
"uvloop": "uvicorn.loops.uvloop:uvloop_loop_factory",
}
INTERFACES: list[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"]
@ -182,7 +182,7 @@ class Config:
port: int = 8000,
uds: str | None = None,
fd: int | None = None,
loop: LoopSetupType = "auto",
loop: LoopFactoryType | str = "auto",
http: type[asyncio.Protocol] | HTTPProtocolType = "auto",
ws: type[asyncio.Protocol] | WSProtocolType = "auto",
ws_max_size: int = 16 * 1024 * 1024,
@ -473,10 +473,18 @@ class Config:
self.loaded = True
def setup_event_loop(self) -> None:
loop_setup: Callable | None = import_from_string(LOOP_SETUPS[self.loop])
if loop_setup is not None:
loop_setup(use_subprocess=self.use_subprocess)
def get_loop_factory(self) -> Callable[[], asyncio.AbstractEventLoop] | None:
if self.loop in LOOP_FACTORIES:
loop_factory: Callable | None = import_from_string(LOOP_FACTORIES[self.loop])
else:
try:
return import_from_string(self.loop)
except ImportFromStringError as exc:
logger.error("Error loading custom loop setup function. %s" % exc)
sys.exit(1)
if loop_factory is None:
return None
return loop_factory(use_subprocess=self.use_subprocess)
def bind_socket(self) -> socket.socket:
logger_args: list[str | int]

View File

@ -1,10 +1,11 @@
from __future__ import annotations
import asyncio
import logging
import sys
logger = logging.getLogger("uvicorn.error")
from collections.abc import Callable
def asyncio_setup(use_subprocess: bool = False) -> None:
if sys.platform == "win32" and use_subprocess:
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # pragma: full coverage
def asyncio_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
if sys.platform == "win32" and not use_subprocess:
return asyncio.ProactorEventLoop
return asyncio.SelectorEventLoop

View File

@ -1,11 +1,17 @@
def auto_loop_setup(use_subprocess: bool = False) -> None:
from __future__ import annotations
import asyncio
from collections.abc import Callable
def auto_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
try:
import uvloop # noqa
except ImportError: # pragma: no cover
from uvicorn.loops.asyncio import asyncio_setup as loop_setup
from uvicorn.loops.asyncio import asyncio_loop_factory as loop_factory
loop_setup(use_subprocess=use_subprocess)
return loop_factory(use_subprocess=use_subprocess)
else: # pragma: no cover
from uvicorn.loops.uvloop import uvloop_setup
from uvicorn.loops.uvloop import uvloop_loop_factory
uvloop_setup(use_subprocess=use_subprocess)
return uvloop_loop_factory(use_subprocess=use_subprocess)

View File

@ -1,7 +1,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
import uvloop
def uvloop_setup(use_subprocess: bool = False) -> None:
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
def uvloop_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
return uvloop.new_event_loop

View File

@ -20,14 +20,14 @@ from uvicorn.config import (
LIFESPAN,
LOG_LEVELS,
LOGGING_CONFIG,
LOOP_SETUPS,
LOOP_FACTORIES,
SSL_PROTOCOL_VERSION,
WS_PROTOCOLS,
Config,
HTTPProtocolType,
InterfaceType,
LifespanType,
LoopSetupType,
LoopFactoryType,
WSProtocolType,
)
from uvicorn.server import Server
@ -37,7 +37,7 @@ LEVEL_CHOICES = click.Choice(list(LOG_LEVELS.keys()))
HTTP_CHOICES = click.Choice(list(HTTP_PROTOCOLS.keys()))
WS_CHOICES = click.Choice(list(WS_PROTOCOLS.keys()))
LIFESPAN_CHOICES = click.Choice(list(LIFESPAN.keys()))
LOOP_CHOICES = click.Choice([key for key in LOOP_SETUPS.keys() if key != "none"])
LOOP_CHOICES = click.Choice([key for key in LOOP_FACTORIES.keys() if key != "none"])
INTERFACE_CHOICES = click.Choice(INTERFACES)
STARTUP_FAILURE = 3
@ -120,7 +120,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
"--loop",
type=LOOP_CHOICES,
default="auto",
help="Event loop implementation.",
help="Event loop factory implementation.",
show_default=True,
)
@click.option(
@ -367,7 +367,7 @@ def main(
port: int,
uds: str,
fd: int,
loop: LoopSetupType,
loop: LoopFactoryType | str,
http: HTTPProtocolType,
ws: WSProtocolType,
ws_max_size: int,
@ -468,7 +468,7 @@ def run(
port: int = 8000,
uds: str | None = None,
fd: int | None = None,
loop: LoopSetupType = "auto",
loop: LoopFactoryType | str = "auto",
http: type[asyncio.Protocol] | HTTPProtocolType = "auto",
ws: type[asyncio.Protocol] | WSProtocolType = "auto",
ws_max_size: int = 16777216,

View File

@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Union
import click
from uvicorn._compat import asyncio_run
from uvicorn.config import Config
if TYPE_CHECKING:
@ -63,8 +64,7 @@ class Server:
self._captured_signals: list[int] = []
def run(self, sockets: list[socket.socket] | None = None) -> None:
self.config.setup_event_loop()
return asyncio.run(self.serve(sockets=sockets))
return asyncio_run(self.serve(sockets=sockets), loop_factory=self.config.get_loop_factory())
async def serve(self, sockets: list[socket.socket] | None = None) -> None:
with self.capture_signals():

View File

@ -10,6 +10,7 @@ from typing import Any
from gunicorn.arbiter import Arbiter
from gunicorn.workers.base import Worker
from uvicorn._compat import asyncio_run
from uvicorn.config import Config
from uvicorn.server import Server
@ -70,10 +71,6 @@ class UvicornWorker(Worker):
self.config = Config(**config_kwargs)
def init_process(self) -> None:
self.config.setup_event_loop()
super().init_process()
def init_signals(self) -> None:
# Reset signals so Gunicorn doesn't swallow subprocess return codes
# other signals are set up by Server.install_signal_handlers()
@ -104,7 +101,7 @@ class UvicornWorker(Worker):
sys.exit(Arbiter.WORKER_BOOT_ERROR)
def run(self) -> None:
return asyncio.run(self._serve())
return asyncio_run(self._serve(), loop_factory=self.config.get_loop_factory())
async def callback_notify(self) -> None:
self.notify()