Support custom IOLOOPs (#2435)
Co-authored-by: Thomas Grainger <tagrain@gmail.com> Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
parent
1aa02117ba
commit
52b3ec62a3
@ -101,7 +101,7 @@ filterwarnings = [
|
||||
parallel = true
|
||||
source_pkgs = ["uvicorn", "tests"]
|
||||
plugins = ["coverage_conditional_plugin"]
|
||||
omit = ["uvicorn/workers.py", "uvicorn/__main__.py"]
|
||||
omit = ["uvicorn/workers.py", "uvicorn/__main__.py", "uvicorn/_compat.py"]
|
||||
|
||||
[tool.coverage.report]
|
||||
precision = 2
|
||||
|
||||
7
tests/custom_loop_utils.py
Normal file
7
tests/custom_loop_utils.py
Normal file
@ -0,0 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
class CustomLoop(asyncio.SelectorEventLoop):
|
||||
pass
|
||||
@ -1,10 +1,11 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import importlib
|
||||
|
||||
import pytest
|
||||
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.loops.auto import auto_loop_setup
|
||||
from uvicorn.loops.auto import auto_loop_factory
|
||||
from uvicorn.protocols.http.auto import AutoHTTPProtocol
|
||||
from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol
|
||||
from uvicorn.server import ServerState
|
||||
@ -33,10 +34,10 @@ async def app(scope, receive, send):
|
||||
|
||||
|
||||
def test_loop_auto():
|
||||
auto_loop_setup()
|
||||
policy = asyncio.get_event_loop_policy()
|
||||
assert isinstance(policy, asyncio.events.BaseDefaultEventLoopPolicy)
|
||||
assert type(policy).__module__.startswith(expected_loop)
|
||||
loop_factory = auto_loop_factory(use_subprocess=True)
|
||||
with contextlib.closing(loop_factory()) as loop:
|
||||
assert isinstance(loop, asyncio.AbstractEventLoop)
|
||||
assert type(loop).__module__.startswith(expected_loop)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
30
tests/test_compat.py
Normal file
30
tests/test_compat.py
Normal file
@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.custom_loop_utils import CustomLoop
|
||||
from tests.utils import get_asyncio_default_loop_per_os
|
||||
from uvicorn._compat import asyncio_run
|
||||
|
||||
|
||||
async def assert_event_loop(expected_loop_class: type[AbstractEventLoop]):
|
||||
assert isinstance(asyncio.get_running_loop(), expected_loop_class)
|
||||
|
||||
|
||||
def test_asyncio_run__default_loop_factory() -> None:
|
||||
asyncio_run(assert_event_loop(get_asyncio_default_loop_per_os()), loop_factory=None)
|
||||
|
||||
|
||||
def test_asyncio_run__custom_loop_factory() -> None:
|
||||
asyncio_run(assert_event_loop(CustomLoop), loop_factory=CustomLoop)
|
||||
|
||||
|
||||
def test_asyncio_run__passing_a_non_awaitable_callback_should_throw_error() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
asyncio_run(
|
||||
lambda: None, # type: ignore
|
||||
loop_factory=CustomLoop,
|
||||
)
|
||||
@ -8,6 +8,7 @@ import os
|
||||
import socket
|
||||
import sys
|
||||
from collections.abc import Iterator
|
||||
from contextlib import closing
|
||||
from pathlib import Path
|
||||
from typing import IO, Any, Callable, Literal
|
||||
from unittest.mock import MagicMock
|
||||
@ -16,16 +17,10 @@ import pytest
|
||||
import yaml
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.utils import as_cwd
|
||||
from uvicorn._types import (
|
||||
ASGIApplication,
|
||||
ASGIReceiveCallable,
|
||||
ASGISendCallable,
|
||||
Environ,
|
||||
Scope,
|
||||
StartResponse,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from tests.custom_loop_utils import CustomLoop
|
||||
from tests.utils import as_cwd, get_asyncio_default_loop_per_os
|
||||
from uvicorn._types import ASGIApplication, ASGIReceiveCallable, ASGISendCallable, Environ, Scope, StartResponse
|
||||
from uvicorn.config import Config, LoopFactoryType
|
||||
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
||||
from uvicorn.middleware.wsgi import WSGIMiddleware
|
||||
from uvicorn.protocols.http.h11_impl import H11Protocol
|
||||
@ -545,3 +540,48 @@ def test_warn_when_using_reload_and_workers(caplog: pytest.LogCaptureFixture) ->
|
||||
Config(app=asgi_app, reload=True, workers=2)
|
||||
assert len(caplog.records) == 1
|
||||
assert '"workers" flag is ignored when reloading is enabled.' in caplog.records[0].message
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("loop_type", "expected_loop_factory"),
|
||||
[
|
||||
("none", None),
|
||||
("asyncio", get_asyncio_default_loop_per_os()),
|
||||
],
|
||||
)
|
||||
def test_get_loop_factory(loop_type: LoopFactoryType, expected_loop_factory: Any):
|
||||
config = Config(app=asgi_app, loop=loop_type)
|
||||
loop_factory = config.get_loop_factory()
|
||||
if loop_factory is None:
|
||||
assert expected_loop_factory is loop_factory
|
||||
else:
|
||||
loop = loop_factory()
|
||||
with closing(loop):
|
||||
assert loop is not None
|
||||
assert isinstance(loop, expected_loop_factory)
|
||||
|
||||
|
||||
def test_custom_loop__importable_custom_loop_setup_function() -> None:
|
||||
config = Config(app=asgi_app, loop="tests.custom_loop_utils:CustomLoop")
|
||||
config.load()
|
||||
loop_factory = config.get_loop_factory()
|
||||
assert loop_factory, "Loop factory should be set"
|
||||
event_loop = loop_factory()
|
||||
with closing(event_loop):
|
||||
assert event_loop is not None
|
||||
assert isinstance(event_loop, CustomLoop)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
|
||||
def test_custom_loop__not_importable_custom_loop_setup_function(caplog: pytest.LogCaptureFixture) -> None:
|
||||
config = Config(app=asgi_app, loop="tests.test_config:non_existing_setup_function")
|
||||
config.load()
|
||||
with pytest.raises(SystemExit):
|
||||
config.get_loop_factory()
|
||||
error_messages = [
|
||||
record.message for record in caplog.records if record.name == "uvicorn.error" and record.levelname == "ERROR"
|
||||
]
|
||||
assert (
|
||||
'Error loading custom loop setup function. Attribute "non_existing_setup_function" not found in module "tests.test_config".' # noqa: E501
|
||||
== error_messages.pop(0)
|
||||
)
|
||||
|
||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from pathlib import Path
|
||||
@ -44,3 +45,11 @@ def as_cwd(path: Path):
|
||||
yield
|
||||
finally:
|
||||
os.chdir(prev_cwd)
|
||||
|
||||
|
||||
def get_asyncio_default_loop_per_os() -> type[asyncio.AbstractEventLoop]:
|
||||
"""Get the default asyncio loop per OS."""
|
||||
if sys.platform == "win32":
|
||||
return asyncio.ProactorEventLoop # type: ignore # pragma: nocover
|
||||
else:
|
||||
return asyncio.SelectorEventLoop # pragma: nocover
|
||||
|
||||
84
uvicorn/_compat.py
Normal file
84
uvicorn/_compat.py
Normal file
@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any, TypeVar
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
asyncio_run = asyncio.run
|
||||
elif sys.version_info >= (3, 11):
|
||||
|
||||
def asyncio_run(
|
||||
main: Coroutine[Any, Any, _T],
|
||||
*,
|
||||
debug: bool = False,
|
||||
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
|
||||
) -> _T:
|
||||
# asyncio.run from Python 3.12
|
||||
# https://docs.python.org/3/license.html#psf-license
|
||||
with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner:
|
||||
return runner.run(main)
|
||||
|
||||
else:
|
||||
# modified version of asyncio.run from Python 3.10 to add loop_factory kwarg
|
||||
# https://docs.python.org/3/license.html#psf-license
|
||||
def asyncio_run(
|
||||
main: Coroutine[Any, Any, _T],
|
||||
*,
|
||||
debug: bool = False,
|
||||
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
|
||||
) -> _T:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError("asyncio.run() cannot be called from a running event loop")
|
||||
|
||||
if not asyncio.iscoroutine(main):
|
||||
raise ValueError(f"a coroutine was expected, got {main!r}")
|
||||
|
||||
if loop_factory is None:
|
||||
loop = asyncio.new_event_loop()
|
||||
else:
|
||||
loop = loop_factory()
|
||||
try:
|
||||
if loop_factory is None:
|
||||
asyncio.set_event_loop(loop)
|
||||
if debug is not None:
|
||||
loop.set_debug(debug)
|
||||
return loop.run_until_complete(main)
|
||||
finally:
|
||||
try:
|
||||
_cancel_all_tasks(loop)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
loop.run_until_complete(loop.shutdown_default_executor())
|
||||
finally:
|
||||
if loop_factory is None:
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
|
||||
to_cancel = asyncio.all_tasks(loop)
|
||||
if not to_cancel:
|
||||
return
|
||||
|
||||
for task in to_cancel:
|
||||
task.cancel()
|
||||
|
||||
loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
|
||||
|
||||
for task in to_cancel:
|
||||
if task.cancelled():
|
||||
continue
|
||||
if task.exception() is not None:
|
||||
loop.call_exception_handler(
|
||||
{
|
||||
"message": "unhandled exception during asyncio.run() shutdown",
|
||||
"exception": task.exception(),
|
||||
"task": task,
|
||||
}
|
||||
)
|
||||
@ -27,7 +27,7 @@ from uvicorn.middleware.wsgi import WSGIMiddleware
|
||||
HTTPProtocolType = Literal["auto", "h11", "httptools"]
|
||||
WSProtocolType = Literal["auto", "none", "websockets", "websockets-sansio", "wsproto"]
|
||||
LifespanType = Literal["auto", "on", "off"]
|
||||
LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"]
|
||||
LoopFactoryType = Literal["none", "auto", "asyncio", "uvloop"]
|
||||
InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"]
|
||||
|
||||
LOG_LEVELS: dict[str, int] = {
|
||||
@ -55,11 +55,11 @@ LIFESPAN: dict[LifespanType, str] = {
|
||||
"on": "uvicorn.lifespan.on:LifespanOn",
|
||||
"off": "uvicorn.lifespan.off:LifespanOff",
|
||||
}
|
||||
LOOP_SETUPS: dict[LoopSetupType, str | None] = {
|
||||
LOOP_FACTORIES: dict[LoopFactoryType | str, str | None] = {
|
||||
"none": None,
|
||||
"auto": "uvicorn.loops.auto:auto_loop_setup",
|
||||
"asyncio": "uvicorn.loops.asyncio:asyncio_setup",
|
||||
"uvloop": "uvicorn.loops.uvloop:uvloop_setup",
|
||||
"auto": "uvicorn.loops.auto:auto_loop_factory",
|
||||
"asyncio": "uvicorn.loops.asyncio:asyncio_loop_factory",
|
||||
"uvloop": "uvicorn.loops.uvloop:uvloop_loop_factory",
|
||||
}
|
||||
INTERFACES: list[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"]
|
||||
|
||||
@ -182,7 +182,7 @@ class Config:
|
||||
port: int = 8000,
|
||||
uds: str | None = None,
|
||||
fd: int | None = None,
|
||||
loop: LoopSetupType = "auto",
|
||||
loop: LoopFactoryType | str = "auto",
|
||||
http: type[asyncio.Protocol] | HTTPProtocolType = "auto",
|
||||
ws: type[asyncio.Protocol] | WSProtocolType = "auto",
|
||||
ws_max_size: int = 16 * 1024 * 1024,
|
||||
@ -473,10 +473,18 @@ class Config:
|
||||
|
||||
self.loaded = True
|
||||
|
||||
def setup_event_loop(self) -> None:
|
||||
loop_setup: Callable | None = import_from_string(LOOP_SETUPS[self.loop])
|
||||
if loop_setup is not None:
|
||||
loop_setup(use_subprocess=self.use_subprocess)
|
||||
def get_loop_factory(self) -> Callable[[], asyncio.AbstractEventLoop] | None:
|
||||
if self.loop in LOOP_FACTORIES:
|
||||
loop_factory: Callable | None = import_from_string(LOOP_FACTORIES[self.loop])
|
||||
else:
|
||||
try:
|
||||
return import_from_string(self.loop)
|
||||
except ImportFromStringError as exc:
|
||||
logger.error("Error loading custom loop setup function. %s" % exc)
|
||||
sys.exit(1)
|
||||
if loop_factory is None:
|
||||
return None
|
||||
return loop_factory(use_subprocess=self.use_subprocess)
|
||||
|
||||
def bind_socket(self) -> socket.socket:
|
||||
logger_args: list[str | int]
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
def asyncio_setup(use_subprocess: bool = False) -> None:
|
||||
if sys.platform == "win32" and use_subprocess:
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # pragma: full coverage
|
||||
def asyncio_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
|
||||
if sys.platform == "win32" and not use_subprocess:
|
||||
return asyncio.ProactorEventLoop
|
||||
return asyncio.SelectorEventLoop
|
||||
|
||||
@ -1,11 +1,17 @@
|
||||
def auto_loop_setup(use_subprocess: bool = False) -> None:
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
def auto_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
|
||||
try:
|
||||
import uvloop # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
from uvicorn.loops.asyncio import asyncio_setup as loop_setup
|
||||
from uvicorn.loops.asyncio import asyncio_loop_factory as loop_factory
|
||||
|
||||
loop_setup(use_subprocess=use_subprocess)
|
||||
return loop_factory(use_subprocess=use_subprocess)
|
||||
else: # pragma: no cover
|
||||
from uvicorn.loops.uvloop import uvloop_setup
|
||||
from uvicorn.loops.uvloop import uvloop_loop_factory
|
||||
|
||||
uvloop_setup(use_subprocess=use_subprocess)
|
||||
return uvloop_loop_factory(use_subprocess=use_subprocess)
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
|
||||
import uvloop
|
||||
|
||||
|
||||
def uvloop_setup(use_subprocess: bool = False) -> None:
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
def uvloop_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]:
|
||||
return uvloop.new_event_loop
|
||||
|
||||
@ -20,14 +20,14 @@ from uvicorn.config import (
|
||||
LIFESPAN,
|
||||
LOG_LEVELS,
|
||||
LOGGING_CONFIG,
|
||||
LOOP_SETUPS,
|
||||
LOOP_FACTORIES,
|
||||
SSL_PROTOCOL_VERSION,
|
||||
WS_PROTOCOLS,
|
||||
Config,
|
||||
HTTPProtocolType,
|
||||
InterfaceType,
|
||||
LifespanType,
|
||||
LoopSetupType,
|
||||
LoopFactoryType,
|
||||
WSProtocolType,
|
||||
)
|
||||
from uvicorn.server import Server
|
||||
@ -37,7 +37,7 @@ LEVEL_CHOICES = click.Choice(list(LOG_LEVELS.keys()))
|
||||
HTTP_CHOICES = click.Choice(list(HTTP_PROTOCOLS.keys()))
|
||||
WS_CHOICES = click.Choice(list(WS_PROTOCOLS.keys()))
|
||||
LIFESPAN_CHOICES = click.Choice(list(LIFESPAN.keys()))
|
||||
LOOP_CHOICES = click.Choice([key for key in LOOP_SETUPS.keys() if key != "none"])
|
||||
LOOP_CHOICES = click.Choice([key for key in LOOP_FACTORIES.keys() if key != "none"])
|
||||
INTERFACE_CHOICES = click.Choice(INTERFACES)
|
||||
|
||||
STARTUP_FAILURE = 3
|
||||
@ -120,7 +120,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
"--loop",
|
||||
type=LOOP_CHOICES,
|
||||
default="auto",
|
||||
help="Event loop implementation.",
|
||||
help="Event loop factory implementation.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
@ -367,7 +367,7 @@ def main(
|
||||
port: int,
|
||||
uds: str,
|
||||
fd: int,
|
||||
loop: LoopSetupType,
|
||||
loop: LoopFactoryType | str,
|
||||
http: HTTPProtocolType,
|
||||
ws: WSProtocolType,
|
||||
ws_max_size: int,
|
||||
@ -468,7 +468,7 @@ def run(
|
||||
port: int = 8000,
|
||||
uds: str | None = None,
|
||||
fd: int | None = None,
|
||||
loop: LoopSetupType = "auto",
|
||||
loop: LoopFactoryType | str = "auto",
|
||||
http: type[asyncio.Protocol] | HTTPProtocolType = "auto",
|
||||
ws: type[asyncio.Protocol] | WSProtocolType = "auto",
|
||||
ws_max_size: int = 16777216,
|
||||
|
||||
@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Union
|
||||
|
||||
import click
|
||||
|
||||
from uvicorn._compat import asyncio_run
|
||||
from uvicorn.config import Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -63,8 +64,7 @@ class Server:
|
||||
self._captured_signals: list[int] = []
|
||||
|
||||
def run(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
self.config.setup_event_loop()
|
||||
return asyncio.run(self.serve(sockets=sockets))
|
||||
return asyncio_run(self.serve(sockets=sockets), loop_factory=self.config.get_loop_factory())
|
||||
|
||||
async def serve(self, sockets: list[socket.socket] | None = None) -> None:
|
||||
with self.capture_signals():
|
||||
|
||||
@ -10,6 +10,7 @@ from typing import Any
|
||||
from gunicorn.arbiter import Arbiter
|
||||
from gunicorn.workers.base import Worker
|
||||
|
||||
from uvicorn._compat import asyncio_run
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.server import Server
|
||||
|
||||
@ -70,10 +71,6 @@ class UvicornWorker(Worker):
|
||||
|
||||
self.config = Config(**config_kwargs)
|
||||
|
||||
def init_process(self) -> None:
|
||||
self.config.setup_event_loop()
|
||||
super().init_process()
|
||||
|
||||
def init_signals(self) -> None:
|
||||
# Reset signals so Gunicorn doesn't swallow subprocess return codes
|
||||
# other signals are set up by Server.install_signal_handlers()
|
||||
@ -104,7 +101,7 @@ class UvicornWorker(Worker):
|
||||
sys.exit(Arbiter.WORKER_BOOT_ERROR)
|
||||
|
||||
def run(self) -> None:
|
||||
return asyncio.run(self._serve())
|
||||
return asyncio_run(self._serve(), loop_factory=self.config.get_loop_factory())
|
||||
|
||||
async def callback_notify(self) -> None:
|
||||
self.notify()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user