Compare commits
9 Commits
main
...
fm/async-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
46555b5bed | ||
|
|
57534e605d | ||
|
|
c4f5c76462 | ||
|
|
bf17e875e7 | ||
|
|
ed4ffac972 | ||
|
|
2a01073135 | ||
|
|
77f48d1351 | ||
|
|
a14dd4b38b | ||
|
|
34fd9cd460 |
@ -8,8 +8,10 @@ twine
|
||||
wheel
|
||||
|
||||
# Testing
|
||||
anyio==2.0.2
|
||||
autoflake
|
||||
black
|
||||
curio==1.4
|
||||
flake8
|
||||
isort
|
||||
pytest
|
||||
@ -17,6 +19,7 @@ pytest-mock
|
||||
requests
|
||||
mypy
|
||||
trustme
|
||||
trio==0.17.*
|
||||
cryptography
|
||||
coverage
|
||||
|
||||
|
||||
2
setup.py
2
setup.py
@ -47,6 +47,8 @@ minimal_requirements = [
|
||||
"click==7.*",
|
||||
"h11>=0.8",
|
||||
"typing-extensions;" + env_marker_below_38,
|
||||
"async_generator; python_version<'3.7'",
|
||||
"async_exit_stack; python_version<'3.7'",
|
||||
]
|
||||
|
||||
extra_requirements = [
|
||||
|
||||
0
tests/async_agnostic/__init__.py
Normal file
0
tests/async_agnostic/__init__.py
Normal file
32
tests/async_agnostic/conftest.py
Normal file
32
tests/async_agnostic/conftest.py
Normal file
@ -0,0 +1,32 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
pytest.param(("asyncio", {"use_uvloop": True}), id="asyncio+uvloop"),
|
||||
pytest.param(("asyncio", {"use_uvloop": False}), id="asyncio"),
|
||||
pytest.param(
|
||||
("trio", {"restrict_keyboard_interrupt_to_checkpoints": True}), id="trio"
|
||||
),
|
||||
pytest.param(("curio", {}), id="curio"),
|
||||
],
|
||||
)
|
||||
def anyio_backend(request: Any) -> str:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def ensure_event_loop():
|
||||
try:
|
||||
asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
# We use anyio; its pytest plugin shuts down and unsets the asyncio event loop
|
||||
# after each test cases. For some reason on Windows when the loop is unset
|
||||
# asyncio won't create one again by itself. So we give it a hand.
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
yield
|
||||
85
tests/async_agnostic/test_default_headers.py
Normal file
85
tests/async_agnostic/test_default_headers.py
Normal file
@ -0,0 +1,85 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from uvicorn import Config
|
||||
from uvicorn._async_agnostic import Server
|
||||
|
||||
from .utils import HTTP11_IMPLEMENTATIONS
|
||||
|
||||
|
||||
async def app(scope: dict, receive: Callable, send: Callable) -> None:
|
||||
assert scope["type"] == "http"
|
||||
await send({"type": "http.response.start", "status": 200, "headers": []})
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
|
||||
@pytest.mark.parametrize("async_library", ["asyncio", "trio", "curio"])
|
||||
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
|
||||
def test_default_headers(async_library: str, http: str) -> None:
|
||||
config = Config(
|
||||
app=app,
|
||||
async_library=async_library,
|
||||
http=http,
|
||||
limit_max_requests=1,
|
||||
)
|
||||
|
||||
with Server(config=config).run_in_thread():
|
||||
response = requests.get("http://localhost:8000")
|
||||
assert response.headers["server"] == "uvicorn" and response.headers["date"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("async_library", ["asyncio", "trio", "curio"])
|
||||
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
|
||||
def test_override_server_header(async_library: str, http: str) -> None:
|
||||
config = Config(
|
||||
app=app,
|
||||
async_library=async_library,
|
||||
http=http,
|
||||
limit_max_requests=1,
|
||||
headers=[("Server", "overridden")],
|
||||
)
|
||||
|
||||
with Server(config=config).run_in_thread():
|
||||
response = requests.get("http://localhost:8000")
|
||||
assert response.headers["server"] == "overridden" and response.headers["date"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("async_library", ["asyncio", "trio", "curio"])
|
||||
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
|
||||
def test_override_server_header_multiple_times(async_library: str, http: str) -> None:
|
||||
config = Config(
|
||||
app=app,
|
||||
async_library=async_library,
|
||||
http=http,
|
||||
limit_max_requests=1,
|
||||
headers=[("Server", "overridden"), ("Server", "another-value")],
|
||||
)
|
||||
|
||||
with Server(config=config).run_in_thread():
|
||||
response = requests.get("http://localhost:8000")
|
||||
assert (
|
||||
response.headers["server"] == "overridden, another-value"
|
||||
and response.headers["date"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("async_library", ["asyncio", "trio", "curio"])
|
||||
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
|
||||
def test_add_additional_header(async_library: str, http: str) -> None:
|
||||
config = Config(
|
||||
app=app,
|
||||
async_library=async_library,
|
||||
http=http,
|
||||
limit_max_requests=1,
|
||||
headers=[("X-Additional", "new-value")],
|
||||
)
|
||||
|
||||
with Server(config=config).run_in_thread():
|
||||
response = requests.get("http://localhost:8000")
|
||||
assert (
|
||||
response.headers["x-additional"] == "new-value"
|
||||
and response.headers["server"] == "uvicorn"
|
||||
and response.headers["date"]
|
||||
)
|
||||
235
tests/async_agnostic/test_http.py
Normal file
235
tests/async_agnostic/test_http.py
Normal file
@ -0,0 +1,235 @@
|
||||
import logging
|
||||
from typing import Any, Callable, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from uvicorn._async_agnostic.backends.auto import AutoBackend
|
||||
from uvicorn._async_agnostic.backends.base import AsyncSocket
|
||||
from uvicorn._async_agnostic.http11.handler import handle_http11
|
||||
from uvicorn._async_agnostic.state import ServerState
|
||||
from uvicorn.config import Config
|
||||
|
||||
from ..response import Response
|
||||
from .utils import HTTP11_IMPLEMENTATIONS
|
||||
|
||||
SIMPLE_GET_REQUEST = b"\r\n".join(
|
||||
[
|
||||
b"GET / HTTP/1.1",
|
||||
b"Host: example.org",
|
||||
b"",
|
||||
b"",
|
||||
]
|
||||
)
|
||||
|
||||
SIMPLE_HEAD_REQUEST = b"\r\n".join(
|
||||
[
|
||||
b"HEAD / HTTP/1.1",
|
||||
b"Host: example.org",
|
||||
b"",
|
||||
b"",
|
||||
]
|
||||
)
|
||||
|
||||
SIMPLE_POST_REQUEST = b"\r\n".join(
|
||||
[
|
||||
b"POST / HTTP/1.1",
|
||||
b"Host: example.org",
|
||||
b"Content-Type: application/json",
|
||||
b"Content-Length: 18",
|
||||
b"",
|
||||
b'{"hello": "world"}',
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class MockSocket(AsyncSocket):
|
||||
"""
|
||||
An in-memory socket for testing purposes.
|
||||
"""
|
||||
|
||||
def __init__(self, request: bytes, prevent_keepalive_loop: bool = True) -> None:
|
||||
self._backend = AutoBackend()
|
||||
self._request = request
|
||||
self._prevent_keepalive_loop = prevent_keepalive_loop
|
||||
self._response = b""
|
||||
self._readable = False
|
||||
self._is_closed = False
|
||||
self._response_received = self._backend.create_event()
|
||||
|
||||
def get_remote_addr(self) -> Tuple[str, int]:
|
||||
return ("127.0.0.1", 8000)
|
||||
|
||||
def get_local_addr(self) -> Tuple[str, int]:
|
||||
return ("127.0.0.1", 42424)
|
||||
|
||||
is_ssl = False
|
||||
|
||||
@property
|
||||
def response(self) -> bytes:
|
||||
return self._response
|
||||
|
||||
async def read(self, n: int) -> bytes:
|
||||
if self._readable:
|
||||
return b""
|
||||
|
||||
if not self._request:
|
||||
while not self._readable:
|
||||
await self._backend.sleep(0.01)
|
||||
return b""
|
||||
|
||||
data, self._request = self._request[:n], self._request[n:]
|
||||
|
||||
if not self._request and self._prevent_keepalive_loop:
|
||||
await self._response_received.set()
|
||||
self._readable = True
|
||||
|
||||
return data
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
if data == b"":
|
||||
await self._response_received.set()
|
||||
if self._prevent_keepalive_loop:
|
||||
# Simulate a client disconnect right after having received
|
||||
# the response so that the HTTP handler doesn't run a keep-alive
|
||||
# cycle for nothing.
|
||||
self.simulate_client_disconnect()
|
||||
return
|
||||
|
||||
self._response += data
|
||||
|
||||
async def wait_response_received(self) -> None:
|
||||
await self._response_received.wait()
|
||||
|
||||
def simulate_client_disconnect(self) -> None:
|
||||
self._readable = True
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
# Simulate instantaneous acknowledgement by client.
|
||||
self._readable = True
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self._is_closed = True
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self._is_closed
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
|
||||
async def test_get_request(http: str) -> None:
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
sock = MockSocket(SIMPLE_GET_REQUEST)
|
||||
|
||||
await handle_http11(sock, ServerState(), Config(app=app, http=http))
|
||||
|
||||
assert b"HTTP/1.1 200 OK" in sock.response
|
||||
assert b"Hello, world" in sock.response
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
|
||||
async def test_head_request(http: str) -> None:
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
sock = MockSocket(SIMPLE_HEAD_REQUEST)
|
||||
|
||||
await handle_http11(sock, ServerState(), Config(app=app, http=http))
|
||||
|
||||
assert b"HTTP/1.1 200 OK" in sock.response
|
||||
assert b"Hello, world" not in sock.response
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
|
||||
async def test_post_request(http: str) -> None:
|
||||
async def app(scope: dict, receive: Callable, send: Callable) -> None:
|
||||
body = b""
|
||||
while True:
|
||||
message = await receive()
|
||||
body += message.get("body", b"")
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
response = Response(b"Body: " + body, media_type="text/plain")
|
||||
await response(scope, receive, send)
|
||||
|
||||
sock = MockSocket(SIMPLE_POST_REQUEST)
|
||||
|
||||
await handle_http11(sock, ServerState(), Config(app=app, http=http))
|
||||
|
||||
assert b"HTTP/1.1 200 OK" in sock.response
|
||||
assert b'Body: {"hello": "world"}' in sock.response
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
|
||||
@pytest.mark.parametrize("path", ["/", "/?foo", "/?foo=bar", "/?foo=bar&baz=1"])
|
||||
async def test_request_logging(http: str, path: str, caplog: Any) -> None:
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
get_request_with_query_string = b"\r\n".join(
|
||||
[
|
||||
"GET {} HTTP/1.1".format(path).encode("ascii"),
|
||||
b"Host: example.org",
|
||||
b"",
|
||||
b"",
|
||||
]
|
||||
)
|
||||
|
||||
sock = MockSocket(get_request_with_query_string)
|
||||
state = ServerState()
|
||||
config = Config(app=app, http=http) # Keep this here -- configures initial logging.
|
||||
|
||||
with caplog.at_level(logging.INFO, logger="uvicorn.access"):
|
||||
logging.getLogger("uvicorn.access").propagate = True
|
||||
await handle_http11(sock, state, config)
|
||||
|
||||
assert '"GET {} HTTP/1.1" 200'.format(path) in caplog.records[0].message
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
|
||||
async def test_keepalive(http: str) -> None:
|
||||
app = Response(b"", status_code=204)
|
||||
sock = MockSocket(SIMPLE_GET_REQUEST, prevent_keepalive_loop=False)
|
||||
|
||||
config = Config(app=app, http=http)
|
||||
|
||||
async with AutoBackend().start_soon(handle_http11, sock, ServerState(), config):
|
||||
await sock.wait_response_received()
|
||||
assert b"HTTP/1.1 204 No Content" in sock.response
|
||||
assert not sock.is_closed
|
||||
sock.simulate_client_disconnect()
|
||||
|
||||
assert sock.is_closed
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
|
||||
async def test_keepalive_timeout(http: str) -> None:
|
||||
app = Response(b"", status_code=204)
|
||||
sock = MockSocket(SIMPLE_GET_REQUEST, prevent_keepalive_loop=False)
|
||||
|
||||
backend = AutoBackend()
|
||||
config = Config(app=app, http=http, timeout_keep_alive=0.1)
|
||||
|
||||
async with backend.start_soon(handle_http11, sock, ServerState(), config):
|
||||
await sock.wait_response_received()
|
||||
assert b"HTTP/1.1 204 No Content" in sock.response
|
||||
assert not sock.is_closed
|
||||
|
||||
await backend.sleep(0.01)
|
||||
assert not sock.is_closed
|
||||
|
||||
await backend.sleep(0.2)
|
||||
assert sock.is_closed
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
|
||||
async def test_close(http: str) -> None:
|
||||
app = Response(b"", status_code=204, headers={"connection": "close"})
|
||||
sock = MockSocket(SIMPLE_GET_REQUEST, prevent_keepalive_loop=False)
|
||||
|
||||
await handle_http11(sock, state=ServerState(), config=Config(app=app, http=http))
|
||||
|
||||
assert b"HTTP/1.1 204 No Content" in sock.response
|
||||
assert sock.is_closed
|
||||
167
tests/async_agnostic/test_lifespan.py
Normal file
167
tests/async_agnostic/test_lifespan.py
Normal file
@ -0,0 +1,167 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from uvicorn._async_agnostic.backends.auto import AutoBackend
|
||||
from uvicorn._async_agnostic.exceptions import LifespanFailure
|
||||
from uvicorn._async_agnostic.lifespan import Lifespan
|
||||
from uvicorn.config import Config
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_lifespan_on() -> None:
|
||||
startup_complete = False
|
||||
shutdown_complete = False
|
||||
|
||||
async def app(scope: dict, receive: Callable, send: Callable) -> None:
|
||||
nonlocal startup_complete, shutdown_complete
|
||||
message = await receive()
|
||||
assert message["type"] == "lifespan.startup"
|
||||
startup_complete = True
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
message = await receive()
|
||||
assert message["type"] == "lifespan.shutdown"
|
||||
shutdown_complete = True
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
|
||||
config = Config(app=app, lifespan="on")
|
||||
lifespan = Lifespan(config)
|
||||
|
||||
async with AutoBackend().start_soon(lifespan.main):
|
||||
assert not startup_complete
|
||||
assert not shutdown_complete
|
||||
await lifespan.startup()
|
||||
assert startup_complete
|
||||
assert not shutdown_complete
|
||||
await lifespan.shutdown()
|
||||
assert startup_complete
|
||||
assert shutdown_complete
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_lifespan_off() -> None:
|
||||
async def app(scope: dict, receive: Callable, send: Callable) -> None:
|
||||
pass # pragma: no cover
|
||||
|
||||
config = Config(app=app, lifespan="off")
|
||||
lifespan = Lifespan(config)
|
||||
|
||||
async with AutoBackend().start_soon(lifespan.main):
|
||||
await lifespan.startup()
|
||||
await lifespan.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_lifespan_auto() -> None:
|
||||
startup_complete = False
|
||||
shutdown_complete = False
|
||||
|
||||
async def app(scope: dict, receive: Callable, send: Callable) -> None:
|
||||
nonlocal startup_complete, shutdown_complete
|
||||
message = await receive()
|
||||
assert message["type"] == "lifespan.startup"
|
||||
startup_complete = True
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
message = await receive()
|
||||
assert message["type"] == "lifespan.shutdown"
|
||||
shutdown_complete = True
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
|
||||
config = Config(app=app, lifespan="auto")
|
||||
lifespan = Lifespan(config)
|
||||
|
||||
async with AutoBackend().start_soon(lifespan.main):
|
||||
assert not startup_complete
|
||||
assert not shutdown_complete
|
||||
await lifespan.startup()
|
||||
assert startup_complete
|
||||
assert not shutdown_complete
|
||||
await lifespan.shutdown()
|
||||
assert startup_complete
|
||||
assert shutdown_complete
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_lifespan_auto_with_error() -> None:
|
||||
async def app(scope: dict, receive: Callable, send: Callable) -> None:
|
||||
assert scope["type"] == "http"
|
||||
|
||||
config = Config(app=app, lifespan="auto")
|
||||
lifespan = Lifespan(config)
|
||||
|
||||
async with AutoBackend().start_soon(lifespan.main):
|
||||
await lifespan.startup()
|
||||
await lifespan.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_lifespan_on_with_error() -> None:
|
||||
async def app(scope: dict, receive: Callable, send: Callable) -> None:
|
||||
assert scope["type"] == "lifespan"
|
||||
raise RuntimeError("Oops")
|
||||
|
||||
config = Config(app=app, lifespan="on")
|
||||
lifespan = Lifespan(config)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Oops"):
|
||||
async with AutoBackend().start_soon(lifespan.main):
|
||||
await lifespan.startup()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("mode", ("auto", "on"))
|
||||
@pytest.mark.parametrize("raise_exception", (True, False))
|
||||
async def test_lifespan_with_failed_startup(mode: str, raise_exception: bool) -> None:
|
||||
async def app(scope: dict, receive: Callable, send: Callable) -> None:
|
||||
message = await receive()
|
||||
assert message["type"] == "lifespan.startup"
|
||||
await send({"type": "lifespan.startup.failed"})
|
||||
if raise_exception:
|
||||
# App should be able to re-raise an exception if startup failed.
|
||||
raise RuntimeError()
|
||||
|
||||
config = Config(app=app, lifespan=mode)
|
||||
lifespan = Lifespan(config)
|
||||
|
||||
with pytest.raises(LifespanFailure):
|
||||
async with AutoBackend().start_soon(lifespan.main):
|
||||
await lifespan.startup()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("mode", ("auto", "on"))
|
||||
async def test_lifespan_scope_asgi3app(mode: str) -> None:
|
||||
async def asgi3app(scope: dict, receive: Callable, send: Callable) -> None:
|
||||
assert scope == {
|
||||
"type": "lifespan",
|
||||
"asgi": {"version": "3.0", "spec_version": "2.0"},
|
||||
}
|
||||
|
||||
config = Config(app=asgi3app, lifespan=mode)
|
||||
lifespan = Lifespan(config)
|
||||
|
||||
async with AutoBackend().start_soon(lifespan.main):
|
||||
await lifespan.startup()
|
||||
await lifespan.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("mode", ("auto", "on"))
|
||||
async def test_lifespan_scope_asgi2app(mode: str) -> None:
|
||||
def asgi2app(scope: dict) -> Callable:
|
||||
assert scope == {
|
||||
"type": "lifespan",
|
||||
"asgi": {"version": "2.0", "spec_version": "2.0"},
|
||||
}
|
||||
|
||||
async def asgi(receive: Callable, send: Callable) -> None:
|
||||
pass
|
||||
|
||||
return asgi
|
||||
|
||||
config = Config(app=asgi2app, lifespan=mode)
|
||||
lifespan = Lifespan(config)
|
||||
|
||||
async with AutoBackend().start_soon(lifespan.main):
|
||||
await lifespan.startup()
|
||||
await lifespan.shutdown()
|
||||
88
tests/async_agnostic/test_main.py
Normal file
88
tests/async_agnostic/test_main.py
Normal file
@ -0,0 +1,88 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from uvicorn._async_agnostic import Server
|
||||
from uvicorn.config import Config
|
||||
|
||||
from .utils import HTTP11_IMPLEMENTATIONS
|
||||
|
||||
|
||||
async def app(scope: dict, receive: Callable, send: Callable) -> None:
|
||||
assert scope["type"] == "http"
|
||||
await send({"type": "http.response.start", "status": 204, "headers": []})
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"host, url",
|
||||
[
|
||||
pytest.param(None, "http://127.0.0.1:8000", id="default"),
|
||||
pytest.param("localhost", "http://127.0.0.1:8000", id="hostname"),
|
||||
pytest.param("::1", "http://[::1]:8000", id="ipv6"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
|
||||
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
|
||||
def test_run(host: str, url: str, async_library: str, http: str) -> None:
|
||||
config = Config(
|
||||
app=app,
|
||||
host=host,
|
||||
async_library=async_library,
|
||||
http=http,
|
||||
limit_max_requests=1,
|
||||
)
|
||||
|
||||
with Server(config).run_in_thread():
|
||||
response = requests.get(url)
|
||||
assert response.status_code == 204
|
||||
|
||||
|
||||
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
|
||||
def test_run_multiprocess(async_library: str) -> None:
|
||||
config = Config(
|
||||
app=app,
|
||||
workers=2,
|
||||
async_library=async_library,
|
||||
limit_max_requests=1,
|
||||
)
|
||||
|
||||
with Server(config).run_in_thread():
|
||||
response = requests.get("http://127.0.0.1:8000")
|
||||
assert response.status_code == 204
|
||||
|
||||
|
||||
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
|
||||
def test_run_reload(async_library: str) -> None:
|
||||
config = Config(
|
||||
app=app,
|
||||
reload=True,
|
||||
async_library=async_library,
|
||||
limit_max_requests=1,
|
||||
)
|
||||
|
||||
with Server(config).run_in_thread():
|
||||
response = requests.get("http://127.0.0.1:8000")
|
||||
assert response.status_code == 204
|
||||
|
||||
|
||||
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
|
||||
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
|
||||
def test_run_with_shutdown(async_library: str, http: str) -> None:
|
||||
async def app(scope: dict, receive: Callable, send: Callable) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown_immediately() -> None:
|
||||
pass
|
||||
|
||||
config = Config(
|
||||
app=app,
|
||||
workers=2,
|
||||
shutdown_trigger=shutdown_immediately,
|
||||
async_library=async_library,
|
||||
http=http,
|
||||
)
|
||||
|
||||
with Server(config).run_in_thread():
|
||||
pass
|
||||
94
tests/async_agnostic/test_ssl.py
Normal file
94
tests/async_agnostic/test_ssl.py
Normal file
@ -0,0 +1,94 @@
|
||||
import contextlib
|
||||
import sys
|
||||
import warnings
|
||||
from functools import partialmethod
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from urllib3.exceptions import InsecureRequestWarning
|
||||
|
||||
from uvicorn._async_agnostic import Server
|
||||
from uvicorn.config import Config
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_ssl_verification(session=requests.Session): # type: ignore
|
||||
old_request = session.request
|
||||
session.request = partialmethod(old_request, verify=False)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", InsecureRequestWarning)
|
||||
yield
|
||||
|
||||
session.request = old_request
|
||||
|
||||
|
||||
async def app(scope: dict, receive: Callable, send: Callable) -> None:
|
||||
assert scope["type"] == "http"
|
||||
await send({"type": "http.response.start", "status": 204, "headers": []})
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Skipping SSL test on Windows"
|
||||
)
|
||||
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
|
||||
def test_run(
|
||||
async_library: str,
|
||||
tls_ca_certificate_pem_path: str,
|
||||
tls_ca_certificate_private_key_path: str,
|
||||
) -> None:
|
||||
config = Config(
|
||||
app=app,
|
||||
async_library=async_library,
|
||||
limit_max_requests=1,
|
||||
ssl_keyfile=tls_ca_certificate_private_key_path,
|
||||
ssl_certfile=tls_ca_certificate_pem_path,
|
||||
)
|
||||
|
||||
with Server(config).run_in_thread():
|
||||
with no_ssl_verification():
|
||||
response = requests.get("https://127.0.0.1:8000")
|
||||
assert response.status_code == 204
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Skipping SSL test on Windows"
|
||||
)
|
||||
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
|
||||
def test_run_chain(async_library: str, tls_certificate_pem_path: str) -> None:
|
||||
config = Config(
|
||||
app=app,
|
||||
async_library=async_library,
|
||||
limit_max_requests=1,
|
||||
ssl_certfile=tls_certificate_pem_path,
|
||||
)
|
||||
|
||||
with Server(config).run_in_thread():
|
||||
with no_ssl_verification():
|
||||
response = requests.get("https://127.0.0.1:8000")
|
||||
assert response.status_code == 204
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"), reason="Skipping SSL test on Windows"
|
||||
)
|
||||
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
|
||||
def test_run_password(
|
||||
async_library: str,
|
||||
tls_ca_certificate_pem_path: str,
|
||||
tls_ca_certificate_private_key_encrypted_path: str,
|
||||
) -> None:
|
||||
config = Config(
|
||||
app=app,
|
||||
async_library=async_library,
|
||||
limit_max_requests=1,
|
||||
ssl_keyfile=tls_ca_certificate_private_key_encrypted_path,
|
||||
ssl_certfile=tls_ca_certificate_pem_path,
|
||||
ssl_keyfile_password="uvicorn password for the win",
|
||||
)
|
||||
with Server(config).run_in_thread():
|
||||
with no_ssl_verification():
|
||||
response = requests.get("https://127.0.0.1:8000")
|
||||
assert response.status_code == 204
|
||||
8
tests/async_agnostic/utils.py
Normal file
8
tests/async_agnostic/utils.py
Normal file
@ -0,0 +1,8 @@
|
||||
HTTP11_IMPLEMENTATIONS = ["h11"]
|
||||
|
||||
try:
|
||||
import httptools # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
pass
|
||||
else:
|
||||
HTTP11_IMPLEMENTATIONS.append("httptools")
|
||||
5
uvicorn/_async_agnostic/__init__.py
Normal file
5
uvicorn/_async_agnostic/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from .server import Server
|
||||
|
||||
__all__ = [
|
||||
"Server",
|
||||
]
|
||||
124
uvicorn/_async_agnostic/asgi.py
Normal file
124
uvicorn/_async_agnostic/asgi.py
Normal file
@ -0,0 +1,124 @@
|
||||
import logging
|
||||
from typing import AsyncIterator, Awaitable, Callable
|
||||
|
||||
from .backends.auto import AutoBackend
|
||||
from .http11.connection import HTTP11Connection
|
||||
from .utils import STATUS_PHRASES, get_path_with_query_string
|
||||
|
||||
|
||||
class ASGIRequestResponseCycle:
|
||||
def __init__(
|
||||
self,
|
||||
conn: HTTP11Connection,
|
||||
scope: dict,
|
||||
request_body: AsyncIterator[bytes],
|
||||
send_response_body: Callable[[bytes], Awaitable[None]],
|
||||
access_log: bool,
|
||||
) -> None:
|
||||
self._conn = conn
|
||||
self._scope = scope
|
||||
self._request_body = request_body
|
||||
self._do_send_response_body = send_response_body
|
||||
self._access_log = access_log
|
||||
|
||||
self._backend = AutoBackend()
|
||||
self._logger = logging.getLogger("uvicorn.error")
|
||||
self._access_logger = logging.getLogger("uvicorn.access")
|
||||
self._response_started = False
|
||||
self._response_complete = False
|
||||
|
||||
# ASGI exception wrapper
|
||||
async def run_asgi(self, app: Callable) -> None:
|
||||
try:
|
||||
result = await app(self._scope, self._receive, self._send)
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
if result is not None:
|
||||
msg = "ASGI callable should return None, but returned '%s'."
|
||||
raise RuntimeError(msg % result)
|
||||
|
||||
if not self._response_started:
|
||||
msg = "ASGI callable returned without starting response."
|
||||
raise RuntimeError(msg)
|
||||
|
||||
if not self._response_complete:
|
||||
msg = "ASGI callable returned without completing response."
|
||||
raise RuntimeError(msg)
|
||||
|
||||
async def _send_response(self, message: dict) -> None:
|
||||
if message["type"] != "http.response.start":
|
||||
msg = "Expected ASGI message 'http.response.start', but got '%s'."
|
||||
raise RuntimeError(msg % message["type"])
|
||||
|
||||
self._response_started = True
|
||||
self._waiting_for_100_continue = False
|
||||
|
||||
status_code = message["status"]
|
||||
headers = self._conn.basic_headers() + message.get("headers", [])
|
||||
reason = STATUS_PHRASES[status_code]
|
||||
|
||||
if self._access_log:
|
||||
self._access_logger.info(
|
||||
'%s - "%s %s HTTP/%s" %d',
|
||||
self._conn.client,
|
||||
self._scope["method"],
|
||||
get_path_with_query_string(self._scope),
|
||||
self._scope["http_version"],
|
||||
status_code,
|
||||
extra={"status_code": status_code, "scope": self._scope},
|
||||
)
|
||||
|
||||
await self._conn.send_response(
|
||||
status_code=status_code, headers=headers, reason=reason
|
||||
)
|
||||
|
||||
async def _send_response_body(self, message: dict) -> None:
|
||||
if message["type"] != "http.response.body":
|
||||
msg = "Expected ASGI message 'http.response.body', but got '%s'."
|
||||
raise RuntimeError(msg % message["type"])
|
||||
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
if self._scope["method"] == "HEAD":
|
||||
body = b""
|
||||
|
||||
await self._do_send_response_body(body)
|
||||
|
||||
if not more_body:
|
||||
if body != b"":
|
||||
await self._do_send_response_body(b"")
|
||||
self._response_complete = True
|
||||
|
||||
# ASGI interface
|
||||
|
||||
async def _send(self, message: dict) -> None:
|
||||
if not self._response_started:
|
||||
await self._send_response(message)
|
||||
|
||||
elif not self._response_complete:
|
||||
await self._send_response_body(message)
|
||||
|
||||
else:
|
||||
# Response already sent
|
||||
msg = "Unexpected ASGI message '%s' sent, after response already completed."
|
||||
raise RuntimeError(msg % message["type"])
|
||||
|
||||
async def _receive(self) -> dict:
|
||||
if self._response_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
try:
|
||||
chunk = await self._request_body.__anext__()
|
||||
except StopAsyncIteration:
|
||||
chunk = b""
|
||||
more_body = False
|
||||
else:
|
||||
more_body = True
|
||||
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": chunk,
|
||||
"more_body": more_body,
|
||||
}
|
||||
0
uvicorn/_async_agnostic/backends/__init__.py
Normal file
0
uvicorn/_async_agnostic/backends/__init__.py
Normal file
355
uvicorn/_async_agnostic/backends/asyncio.py
Normal file
355
uvicorn/_async_agnostic/backends/asyncio.py
Normal file
@ -0,0 +1,355 @@
|
||||
import asyncio
|
||||
import os
|
||||
import platform
|
||||
import signal
|
||||
import socket
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
from ...config import Config
|
||||
from ..compat import asynccontextmanager
|
||||
from ..state import ServerState
|
||||
from ..utils import get_sock_local_addr, get_sock_remote_addr
|
||||
from .base import (
|
||||
AsyncBackend,
|
||||
AsyncListener,
|
||||
AsyncSocket,
|
||||
Event,
|
||||
Queue,
|
||||
TaskHandle,
|
||||
TaskStatus,
|
||||
)
|
||||
|
||||
HIGH_WATER_LIMIT = 2 ** 16
|
||||
|
||||
|
||||
class AsyncioEvent(Event):
|
||||
def __init__(self) -> None:
|
||||
self._event = asyncio.Event()
|
||||
|
||||
async def set(self) -> None:
|
||||
self._event.set()
|
||||
|
||||
def is_set(self) -> bool:
|
||||
return self._event.is_set()
|
||||
|
||||
async def wait(self) -> None:
|
||||
await self._event.wait()
|
||||
|
||||
def clear(self) -> None:
|
||||
self._event.clear()
|
||||
|
||||
|
||||
class AsyncioQueue(Queue):
|
||||
def __init__(self, size: int) -> None:
|
||||
self._queue: "asyncio.Queue[Any]" = asyncio.Queue(size)
|
||||
self._closed = False
|
||||
|
||||
async def get(self) -> Any:
|
||||
if self._closed:
|
||||
raise RuntimeError("Queue closed")
|
||||
return await self._queue.get()
|
||||
|
||||
async def put(self, item: Any) -> None:
|
||||
await self._queue.put(item)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self._closed = True
|
||||
|
||||
|
||||
class AsyncioSocket(AsyncSocket):
|
||||
def __init__(
|
||||
self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
self._stream_reader = stream_reader
|
||||
self._stream_writer = stream_writer
|
||||
|
||||
def get_local_addr(self) -> Optional[Tuple[str, int]]:
|
||||
sock = self._stream_writer.get_extra_info("socket")
|
||||
if sock is not None:
|
||||
return get_sock_local_addr(sock)
|
||||
|
||||
info = self._stream_writer.get_extra_info("peername")
|
||||
try:
|
||||
host, port = info
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
return str(host), int(port)
|
||||
|
||||
def get_remote_addr(self) -> Optional[Tuple[str, int]]:
|
||||
sock = self._stream_writer.get_extra_info("socket")
|
||||
if sock is not None:
|
||||
return get_sock_remote_addr(sock)
|
||||
|
||||
info = self._stream_writer.get_extra_info("peername")
|
||||
try:
|
||||
host, port = info
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
return str(host), int(port)
|
||||
|
||||
@property
|
||||
def is_ssl(self) -> bool:
|
||||
transport = self._stream_writer.transport
|
||||
return bool(transport.get_extra_info("sslcontext"))
|
||||
|
||||
async def read(self, n: int) -> bytes:
|
||||
return await self._stream_reader.read(n)
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
self._stream_writer.write(data)
|
||||
await self._stream_writer.drain()
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
try:
|
||||
self._stream_writer.write_eof()
|
||||
except (NotImplementedError, OSError, RuntimeError):
|
||||
pass # Likely SSL connection
|
||||
|
||||
async def aclose(self) -> None:
|
||||
try:
|
||||
self._stream_writer.close()
|
||||
await self._stream_writer.wait_closed()
|
||||
except (BrokenPipeError, ConnectionResetError):
|
||||
pass # Already closed
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self._stream_writer.is_closing()
|
||||
|
||||
|
||||
class AsyncioListener(AsyncListener):
|
||||
def __init__(self, sock: socket.SocketType) -> None:
|
||||
self._sock = sock
|
||||
|
||||
@property
|
||||
def socket(self) -> socket.SocketType:
|
||||
return self._sock
|
||||
|
||||
|
||||
class AsyncioTaskHandle(TaskHandle):
|
||||
def __init__(self, cancel_event: asyncio.Event) -> None:
|
||||
self._cancel_event = cancel_event
|
||||
|
||||
async def cancel(self) -> None:
|
||||
self._cancel_event.set()
|
||||
|
||||
|
||||
class AsyncioBackend(AsyncBackend):
|
||||
def create_event(self) -> Event:
|
||||
return AsyncioEvent()
|
||||
|
||||
def create_queue(self, size: int) -> Queue:
|
||||
return AsyncioQueue(size)
|
||||
|
||||
async def sleep(self, seconds: float) -> None:
|
||||
await asyncio.sleep(seconds)
|
||||
|
||||
def _get_event_loop(self) -> asyncio.AbstractEventLoop:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
# We're probably in a new thread.
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
return loop
|
||||
|
||||
def run(self, async_fn: Callable, *args: Any) -> None:
|
||||
loop = self._get_event_loop()
|
||||
loop.run_until_complete(async_fn(*args))
|
||||
|
||||
async def move_on_after(
|
||||
self, seconds: float, async_fn: Callable, *args: Any
|
||||
) -> None:
|
||||
try:
|
||||
await asyncio.wait_for(async_fn(*args), seconds)
|
||||
except asyncio.TimeoutError:
|
||||
return
|
||||
|
||||
@asynccontextmanager
|
||||
async def start_soon(
|
||||
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
|
||||
) -> AsyncIterator[None]:
|
||||
loop = self._get_event_loop()
|
||||
task = loop.create_task(async_fn(*args))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if cancel_on_exit:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
@asynccontextmanager
|
||||
async def start(
|
||||
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
|
||||
) -> AsyncIterator[Any]:
|
||||
loop = self._get_event_loop()
|
||||
task_status = self.create_task_status()
|
||||
task = loop.create_task(async_fn(*args, task_status=task_status))
|
||||
try:
|
||||
value = await task_status.get_value()
|
||||
yield value
|
||||
finally:
|
||||
if cancel_on_exit:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def wait_then_call(
|
||||
self,
|
||||
seconds: float,
|
||||
async_fn: Callable,
|
||||
*args: Any,
|
||||
task_status: TaskStatus = TaskStatus.IGNORED,
|
||||
) -> None:
|
||||
cancel_event = asyncio.Event()
|
||||
await task_status.started(AsyncioTaskHandle(cancel_event))
|
||||
|
||||
# Wait for the user to cancel the callback, or for the timeout to expire.
|
||||
# Be sure to clean up after asyncio.
|
||||
tasks: set = {
|
||||
asyncio.create_task(asyncio.sleep(seconds)),
|
||||
asyncio.create_task(cancel_event.wait()),
|
||||
}
|
||||
try:
|
||||
_, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
except asyncio.CancelledError:
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
else:
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
if cancel_event.is_set():
|
||||
return
|
||||
|
||||
await async_fn(*args)
|
||||
|
||||
async def serve_tcp(
|
||||
self,
|
||||
handler: Callable[[AsyncSocket, ServerState, Config], Awaitable[None]],
|
||||
state: ServerState,
|
||||
config: Config,
|
||||
*,
|
||||
sockets: List[socket.SocketType] = None,
|
||||
wait_close: Callable,
|
||||
on_close: Callable = None,
|
||||
task_status: TaskStatus = TaskStatus.IGNORED,
|
||||
) -> None:
|
||||
async def asyncio_handler(
|
||||
stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
sock = AsyncioSocket(stream_reader, stream_writer)
|
||||
await handler(sock, state, config)
|
||||
|
||||
servers = []
|
||||
|
||||
if sockets is not None:
|
||||
# Explicitly passed a list of open sockets.
|
||||
|
||||
def _win32_share_socket(sock: socket.SocketType) -> socket.SocketType:
|
||||
# Windows requires the socket be explicitly shared across
|
||||
# multiple workers (processes).
|
||||
from socket import fromshare # type: ignore
|
||||
|
||||
sock_data = sock.share(os.getpid()) # type: ignore
|
||||
return fromshare(sock_data)
|
||||
|
||||
for sock in sockets:
|
||||
if config.workers > 1 and platform.system() == "Windows":
|
||||
sock = _win32_share_socket(sock)
|
||||
server = await asyncio.start_server(
|
||||
asyncio_handler,
|
||||
sock=sock,
|
||||
ssl=config.ssl,
|
||||
backlog=config.backlog,
|
||||
)
|
||||
servers.append(server)
|
||||
|
||||
listener_sockets = sockets
|
||||
|
||||
elif config.fd is not None:
|
||||
# Use an existing socket, from a file descriptor.
|
||||
sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server = await asyncio.start_server(
|
||||
asyncio_handler, sock=sock, ssl=config.ssl, backlog=config.backlog
|
||||
)
|
||||
assert server.sockets is not None
|
||||
listener_sockets = server.sockets
|
||||
servers.append(server)
|
||||
|
||||
elif config.uds is not None:
|
||||
# Create a socket using UNIX domain socket.
|
||||
uds_perms = 0o666
|
||||
if os.path.exists(config.uds):
|
||||
uds_perms = os.stat(config.uds).st_mode
|
||||
server = await asyncio.start_unix_server(
|
||||
asyncio_handler,
|
||||
path=config.uds,
|
||||
ssl=config.ssl,
|
||||
backlog=config.backlog,
|
||||
)
|
||||
os.chmod(config.uds, uds_perms)
|
||||
assert server.sockets is not None
|
||||
listener_sockets = server.sockets
|
||||
servers.append(server)
|
||||
|
||||
else:
|
||||
# Standard case. Create a socket from a host/port pair.
|
||||
server = await asyncio.start_server(
|
||||
asyncio_handler,
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
ssl=config.ssl,
|
||||
backlog=config.backlog,
|
||||
)
|
||||
assert server.sockets is not None
|
||||
listener_sockets = server.sockets
|
||||
servers.append(server)
|
||||
|
||||
listeners = [AsyncioListener(sock) for sock in listener_sockets]
|
||||
await task_status.started(listeners)
|
||||
|
||||
await wait_close()
|
||||
|
||||
# Stop accepting new connections.
|
||||
for server in servers:
|
||||
server.close()
|
||||
for sock in sockets or []:
|
||||
sock.close()
|
||||
|
||||
# Run any custom shutdown behavior.
|
||||
if on_close is not None:
|
||||
await on_close()
|
||||
|
||||
for server in servers:
|
||||
await server.wait_closed()
|
||||
|
||||
async def listen_signals(
|
||||
self, *signals: signal.Signals, handler: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
signal_event = asyncio.Event()
|
||||
|
||||
def wrapped_handler(*args: Any) -> None:
|
||||
signal_event.set()
|
||||
|
||||
loop = self._get_event_loop()
|
||||
try:
|
||||
for sig in signals:
|
||||
loop.add_signal_handler(sig, wrapped_handler, sig, None)
|
||||
except NotImplementedError:
|
||||
# Windows
|
||||
for sig in signals:
|
||||
signal.signal(sig, wrapped_handler)
|
||||
|
||||
while True:
|
||||
await signal_event.wait()
|
||||
signal_event.clear()
|
||||
await handler()
|
||||
101
uvicorn/_async_agnostic/backends/auto.py
Normal file
101
uvicorn/_async_agnostic/backends/auto.py
Normal file
@ -0,0 +1,101 @@
|
||||
import signal
|
||||
import socket
|
||||
from typing import Any, AsyncContextManager, Awaitable, Callable, List
|
||||
|
||||
import sniffio
|
||||
|
||||
from ...config import Config
|
||||
from ..state import ServerState
|
||||
from .base import AsyncBackend, AsyncSocket, Event, Queue, TaskStatus
|
||||
|
||||
|
||||
def select_async_backend(library: str) -> AsyncBackend:
|
||||
if library == "asyncio":
|
||||
from .asyncio import AsyncioBackend
|
||||
|
||||
return AsyncioBackend()
|
||||
|
||||
if library == "trio":
|
||||
from .trio import TrioBackend
|
||||
|
||||
return TrioBackend()
|
||||
|
||||
if library == "curio":
|
||||
from .curio import CurioBackend
|
||||
|
||||
return CurioBackend()
|
||||
|
||||
raise NotImplementedError(library)
|
||||
|
||||
|
||||
class AutoBackend(AsyncBackend):
|
||||
@property
|
||||
def _backend(self) -> AsyncBackend:
|
||||
if not hasattr(self, "_backend_impl"):
|
||||
library = sniffio.current_async_library()
|
||||
self._backend_impl = select_async_backend(library)
|
||||
return self._backend_impl
|
||||
|
||||
def create_event(self) -> Event:
|
||||
return self._backend.create_event()
|
||||
|
||||
def create_queue(self, size: int) -> Queue:
|
||||
return self._backend.create_queue(size)
|
||||
|
||||
async def sleep(self, seconds: float) -> None:
|
||||
await self._backend.sleep(seconds)
|
||||
|
||||
def run(self, async_fn: Callable, *args: Any) -> None:
|
||||
self._backend.run(async_fn)
|
||||
|
||||
async def move_on_after(
|
||||
self, seconds: float, async_fn: Callable, *args: Any
|
||||
) -> None:
|
||||
await self._backend.move_on_after(seconds, async_fn, *args)
|
||||
|
||||
def start_soon(
|
||||
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
|
||||
) -> AsyncContextManager[None]:
|
||||
return self._backend.start_soon(async_fn, *args, cancel_on_exit=cancel_on_exit)
|
||||
|
||||
def start(
|
||||
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
|
||||
) -> AsyncContextManager[Any]:
|
||||
return self._backend.start(async_fn, *args, cancel_on_exit=cancel_on_exit)
|
||||
|
||||
async def wait_then_call(
|
||||
self,
|
||||
seconds: float,
|
||||
async_fn: Callable,
|
||||
*args: Any,
|
||||
task_status: TaskStatus = TaskStatus.IGNORED,
|
||||
) -> None:
|
||||
await self._backend.wait_then_call(
|
||||
seconds, async_fn, *args, task_status=task_status
|
||||
)
|
||||
|
||||
async def serve_tcp(
|
||||
self,
|
||||
handler: Callable[[AsyncSocket, ServerState, Config], Awaitable[None]],
|
||||
state: ServerState,
|
||||
config: Config,
|
||||
*,
|
||||
sockets: List[socket.SocketType] = None,
|
||||
wait_close: Callable,
|
||||
on_close: Callable = None,
|
||||
task_status: TaskStatus = TaskStatus.IGNORED,
|
||||
) -> None:
|
||||
await self._backend.serve_tcp(
|
||||
handler,
|
||||
state,
|
||||
config,
|
||||
sockets=sockets,
|
||||
wait_close=wait_close,
|
||||
on_close=on_close,
|
||||
task_status=task_status,
|
||||
)
|
||||
|
||||
async def listen_signals(
|
||||
self, *signals: signal.Signals, handler: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
await self._backend.listen_signals(*signals, handler=handler)
|
||||
160
uvicorn/_async_agnostic/backends/base.py
Normal file
160
uvicorn/_async_agnostic/backends/base.py
Normal file
@ -0,0 +1,160 @@
|
||||
import signal
|
||||
import socket
|
||||
from typing import Any, AsyncContextManager, Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
from ...config import Config
|
||||
from ..state import ServerState
|
||||
|
||||
|
||||
class Event:
|
||||
async def set(self) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def is_set(self) -> bool:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
async def wait(self) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def clear(self) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
|
||||
class Queue:
|
||||
async def get(self) -> Any:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
async def put(self, item: Any) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
async def aclose(self) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
|
||||
class AsyncSocket:
|
||||
def get_local_addr(self) -> Optional[Tuple[str, int]]:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def get_remote_addr(self) -> Optional[Tuple[str, int]]:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
@property
|
||||
def is_ssl(self) -> bool:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
async def read(self, n: int) -> bytes:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
async def aclose(self) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
|
||||
class AsyncListener:
|
||||
@property
|
||||
def socket(self) -> socket.SocketType:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
|
||||
class TaskStatus:
|
||||
IGNORED: "IgnoredTaskStatus"
|
||||
|
||||
def __init__(self, event: Event) -> None:
|
||||
self._value_event = event
|
||||
|
||||
async def started(self, value: Any = None) -> None:
|
||||
self._value = value
|
||||
await self._value_event.set()
|
||||
|
||||
async def get_value(self) -> Any:
|
||||
await self._value_event.wait()
|
||||
assert hasattr(self, "_value")
|
||||
return self._value
|
||||
|
||||
|
||||
class TaskHandle:
|
||||
async def cancel(self) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
|
||||
class IgnoredTaskStatus(TaskStatus):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(None) # type: ignore
|
||||
|
||||
async def started(self, value: Any = None) -> None:
|
||||
pass
|
||||
|
||||
async def get_value(self) -> Any:
|
||||
return None
|
||||
|
||||
|
||||
TaskStatus.IGNORED = IgnoredTaskStatus()
|
||||
|
||||
|
||||
class AsyncBackend:
|
||||
def create_event(self) -> Event:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def create_queue(self, size: int) -> Queue:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def create_task_status(self) -> TaskStatus:
|
||||
event = self.create_event()
|
||||
return TaskStatus(event)
|
||||
|
||||
async def sleep(self, seconds: float) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def run(self, async_fn: Callable, *args: Any) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
async def move_on_after(
|
||||
self, seconds: float, async_fn: Callable, *args: Any
|
||||
) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def start_soon(
|
||||
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
|
||||
) -> AsyncContextManager[None]:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def start(
|
||||
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
|
||||
) -> AsyncContextManager[Any]:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
async def wait_then_call(
|
||||
self,
|
||||
seconds: float,
|
||||
async_fn: Callable,
|
||||
*args: Any,
|
||||
task_status: TaskStatus = TaskStatus.IGNORED,
|
||||
) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
async def serve_tcp(
|
||||
self,
|
||||
handler: Callable[[AsyncSocket, ServerState, Config], Awaitable[None]],
|
||||
state: ServerState,
|
||||
config: Config,
|
||||
*,
|
||||
sockets: List[socket.SocketType] = None,
|
||||
wait_close: Callable,
|
||||
on_close: Callable = None,
|
||||
task_status: TaskStatus = TaskStatus.IGNORED,
|
||||
) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
async def listen_signals(
|
||||
self, *signals: signal.Signals, handler: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
259
uvicorn/_async_agnostic/backends/curio.py
Normal file
259
uvicorn/_async_agnostic/backends/curio.py
Normal file
@ -0,0 +1,259 @@
|
||||
import functools
|
||||
import signal
|
||||
import socket
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
import curio
|
||||
|
||||
from ...config import Config
|
||||
from ..compat import asynccontextmanager
|
||||
from ..state import ServerState
|
||||
from ..utils import get_sock_local_addr, get_sock_remote_addr
|
||||
from .base import (
|
||||
AsyncBackend,
|
||||
AsyncListener,
|
||||
AsyncSocket,
|
||||
Event,
|
||||
Queue,
|
||||
TaskHandle,
|
||||
TaskStatus,
|
||||
)
|
||||
|
||||
|
||||
class CurioEvent(Event):
|
||||
def __init__(self) -> None:
|
||||
# TODO: consider curio.UniversalEvent
|
||||
self._event = curio.Event()
|
||||
|
||||
async def set(self) -> None:
|
||||
await self._event.set()
|
||||
|
||||
def is_set(self) -> bool:
|
||||
return self._event.is_set()
|
||||
|
||||
async def wait(self) -> None:
|
||||
await self._event.wait()
|
||||
|
||||
def clear(self) -> None:
|
||||
self._event.clear()
|
||||
|
||||
|
||||
class CurioSocket(AsyncSocket):
|
||||
def __init__(self, sock: curio.io.Socket, is_ssl: bool) -> None:
|
||||
self._sock = sock
|
||||
self._stream = sock.as_stream()
|
||||
self._is_closed = False
|
||||
self._is_ssl = is_ssl
|
||||
|
||||
def get_local_addr(self) -> Optional[Tuple[str, int]]:
|
||||
return get_sock_local_addr(self._sock)
|
||||
|
||||
def get_remote_addr(self) -> Optional[Tuple[str, int]]:
|
||||
return get_sock_remote_addr(self._sock)
|
||||
|
||||
@property
|
||||
def is_ssl(self) -> bool:
|
||||
return self._is_ssl
|
||||
|
||||
async def read(self, n: int) -> bytes:
|
||||
try:
|
||||
return await self._stream.read(n)
|
||||
except ValueError: # TODO
|
||||
return b""
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
try:
|
||||
await self._stream.write(data)
|
||||
except ValueError: # TODO
|
||||
pass
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
pass
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._sock.close()
|
||||
self._is_closed = True
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self._is_closed
|
||||
|
||||
|
||||
class CurioQueue(Queue):
|
||||
def __init__(self, size: int) -> None:
|
||||
self._queue = curio.Queue(size)
|
||||
|
||||
async def get(self) -> Any:
|
||||
return await self._queue.get()
|
||||
|
||||
async def put(self, item: Any) -> None:
|
||||
await self._queue.put(item)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class CurioListener(AsyncListener):
|
||||
def __init__(self, sock: socket.SocketType) -> None:
|
||||
self._sock = sock
|
||||
|
||||
@property
|
||||
def socket(self) -> socket.SocketType:
|
||||
return self._sock
|
||||
|
||||
|
||||
class CurioTaskHandle(TaskHandle):
|
||||
def __init__(self, cancel_event: curio.Event) -> None:
|
||||
self._cancel_event = cancel_event
|
||||
|
||||
async def cancel(self) -> None:
|
||||
await self._cancel_event.set()
|
||||
|
||||
|
||||
class CurioBackend(AsyncBackend):
|
||||
def create_event(self) -> Event:
|
||||
return CurioEvent()
|
||||
|
||||
def create_queue(self, size: int) -> Queue:
|
||||
return CurioQueue(size)
|
||||
|
||||
async def sleep(self, seconds: float) -> None:
|
||||
await curio.sleep(seconds)
|
||||
|
||||
def run(self, async_fn: Callable, *args: Any) -> None:
|
||||
curio.run(async_fn, *args)
|
||||
|
||||
async def move_on_after(
|
||||
self, seconds: float, async_fn: Callable, *args: Any
|
||||
) -> None:
|
||||
async with curio.ignore_after(seconds):
|
||||
await async_fn(*args)
|
||||
|
||||
@asynccontextmanager
|
||||
async def start_soon(
|
||||
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
|
||||
) -> AsyncIterator[None]:
|
||||
async with curio.TaskGroup() as g:
|
||||
task = await g.spawn(async_fn, *args)
|
||||
yield
|
||||
|
||||
if cancel_on_exit:
|
||||
await task.cancel()
|
||||
else:
|
||||
await task.wait()
|
||||
|
||||
try:
|
||||
task.result
|
||||
except curio.TaskCancelled:
|
||||
pass
|
||||
|
||||
@asynccontextmanager
|
||||
async def start(
|
||||
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
|
||||
) -> AsyncIterator[Any]:
|
||||
async with curio.TaskGroup() as g:
|
||||
task_status = self.create_task_status()
|
||||
async_fn = functools.partial(async_fn, task_status=task_status)
|
||||
task = await g.spawn(async_fn, *args)
|
||||
value = await task_status.get_value()
|
||||
yield value
|
||||
|
||||
if cancel_on_exit:
|
||||
await task.cancel()
|
||||
else:
|
||||
await task.wait()
|
||||
|
||||
try:
|
||||
task.result
|
||||
except curio.TaskCancelled:
|
||||
pass
|
||||
|
||||
async def wait_then_call(
|
||||
self,
|
||||
seconds: float,
|
||||
async_fn: Callable,
|
||||
*args: Any,
|
||||
task_status: TaskStatus = TaskStatus.IGNORED,
|
||||
) -> None:
|
||||
cancel_event = curio.Event()
|
||||
await task_status.started(CurioTaskHandle(cancel_event))
|
||||
|
||||
async with curio.TaskGroup(wait=any) as g:
|
||||
await g.spawn(curio.sleep, seconds)
|
||||
await g.spawn(cancel_event.wait)
|
||||
|
||||
if cancel_event.is_set():
|
||||
return
|
||||
|
||||
await async_fn()
|
||||
|
||||
async def serve_tcp(
|
||||
self,
|
||||
handler: Callable[[AsyncSocket, ServerState, Config], Awaitable[None]],
|
||||
state: ServerState,
|
||||
config: Config,
|
||||
*,
|
||||
sockets: List[socket.SocketType] = None,
|
||||
wait_close: Callable,
|
||||
on_close: Callable = None,
|
||||
task_status: TaskStatus = TaskStatus.IGNORED,
|
||||
) -> None:
|
||||
async def client_connected_task(csock: curio.io.Socket, addr: Any) -> None:
|
||||
sock = CurioSocket(csock, is_ssl=bool(config.ssl))
|
||||
await handler(sock, state, config)
|
||||
|
||||
async with curio.TaskGroup() as g:
|
||||
if sockets is not None:
|
||||
# Explicitly passed a list of open sockets.
|
||||
listener_sockets = sockets
|
||||
|
||||
elif config.fd is not None:
|
||||
# Use an existing socket, from a file descriptor.
|
||||
sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
listener_sockets = [sock]
|
||||
|
||||
elif config.uds is not None:
|
||||
# Create a socket using UNIX domain socket.
|
||||
sock = curio.unix_server_socket(config.uds, backlog=config.backlog)
|
||||
listener_sockets = [sock]
|
||||
|
||||
else:
|
||||
# Standard case. Create a socket from a host/port pair.
|
||||
sock = curio.tcp_server_socket(
|
||||
config.host, config.port, backlog=config.backlog
|
||||
)
|
||||
listener_sockets = [sock]
|
||||
|
||||
for sock in listener_sockets:
|
||||
run_server = functools.partial(curio.network.run_server, ssl=config.ssl)
|
||||
await g.spawn(run_server, sock, client_connected_task)
|
||||
|
||||
value = [CurioListener(sock) for sock in listener_sockets]
|
||||
await task_status.started(value)
|
||||
|
||||
await wait_close()
|
||||
|
||||
# Run any custom shutdown behavior.
|
||||
if on_close is not None:
|
||||
await on_close()
|
||||
|
||||
# Connections are properly closed, we can go ahead and hard-stop
|
||||
# the servers.
|
||||
await g.cancel_remaining()
|
||||
|
||||
async def listen_signals(
|
||||
self, *signals: signal.Signals, handler: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
# https://curio.readthedocs.io/en/latest/howto.html#how-do-you-catch-signals
|
||||
signal_event = curio.UniversalEvent()
|
||||
|
||||
def wrapped_handler(*args: Any) -> None:
|
||||
signal_event.set()
|
||||
|
||||
for sig in signals:
|
||||
signal.signal(sig, wrapped_handler)
|
||||
|
||||
while True:
|
||||
await signal_event.wait()
|
||||
signal_event.clear()
|
||||
await handler()
|
||||
289
uvicorn/_async_agnostic/backends/trio.py
Normal file
289
uvicorn/_async_agnostic/backends/trio.py
Normal file
@ -0,0 +1,289 @@
|
||||
import functools
|
||||
import signal
|
||||
import socket
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from ...config import Config
|
||||
from ..compat import asynccontextmanager
|
||||
from ..exceptions import BrokenSocket
|
||||
from ..state import ServerState
|
||||
from ..utils import get_sock_local_addr, get_sock_remote_addr
|
||||
from .base import (
|
||||
AsyncBackend,
|
||||
AsyncListener,
|
||||
AsyncSocket,
|
||||
Event,
|
||||
Queue,
|
||||
TaskHandle,
|
||||
TaskStatus,
|
||||
)
|
||||
|
||||
|
||||
class TrioEvent(Event):
|
||||
def __init__(self) -> None:
|
||||
self._event = trio.Event()
|
||||
|
||||
async def set(self) -> None:
|
||||
self._event.set()
|
||||
|
||||
def is_set(self) -> bool:
|
||||
return self._event.is_set()
|
||||
|
||||
async def wait(self) -> None:
|
||||
await self._event.wait()
|
||||
|
||||
def clear(self) -> None:
|
||||
self._event = trio.Event()
|
||||
|
||||
|
||||
class TrioSocket(AsyncSocket):
|
||||
def __init__(self, stream: Union[trio.SocketStream, trio.SSLStream]) -> None:
|
||||
self._stream = stream
|
||||
self._is_closed = False
|
||||
|
||||
def _unwrap_ssl(self) -> trio.SocketStream:
|
||||
stream: trio.abc.Stream = self._stream
|
||||
while isinstance(stream, trio.SSLStream):
|
||||
stream = stream.transport_stream
|
||||
assert isinstance(stream, trio.SocketStream)
|
||||
return stream
|
||||
|
||||
def get_local_addr(self) -> Optional[Tuple[str, int]]:
|
||||
stream = self._unwrap_ssl()
|
||||
return get_sock_local_addr(stream.socket)
|
||||
|
||||
def get_remote_addr(self) -> Optional[Tuple[str, int]]:
|
||||
stream = self._unwrap_ssl()
|
||||
return get_sock_remote_addr(stream.socket)
|
||||
|
||||
@property
|
||||
def is_ssl(self) -> bool:
|
||||
return isinstance(self._stream, trio.SSLStream)
|
||||
|
||||
async def read(self, n: int) -> bytes:
|
||||
try:
|
||||
return await self._stream.receive_some(n)
|
||||
except (trio.BrokenResourceError, trio.ClosedResourceError):
|
||||
return b""
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
try:
|
||||
await self._stream.send_all(data)
|
||||
except trio.BrokenResourceError:
|
||||
pass
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
stream = self._unwrap_ssl()
|
||||
try:
|
||||
await stream.send_eof()
|
||||
except trio.BrokenResourceError:
|
||||
raise BrokenSocket()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._stream.aclose()
|
||||
self._is_closed = True
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self._is_closed
|
||||
|
||||
|
||||
class TrioQueue(Queue):
|
||||
def __init__(self, size: int) -> None:
|
||||
self._send_channel, self._receive_channel = trio.open_memory_channel[Any](size)
|
||||
|
||||
async def get(self) -> Any:
|
||||
return await self._receive_channel.receive()
|
||||
|
||||
async def put(self, item: Any) -> None:
|
||||
try:
|
||||
await self._send_channel.send(item)
|
||||
except (trio.BrokenResourceError, trio.ClosedResourceError):
|
||||
pass # Already closed.
|
||||
|
||||
async def aclose(self) -> None:
|
||||
try:
|
||||
await self._receive_channel.aclose()
|
||||
except trio.ClosedResourceError:
|
||||
pass # Already closed.
|
||||
|
||||
try:
|
||||
await self._send_channel.aclose()
|
||||
except trio.ClosedResourceError:
|
||||
pass # Already closed.
|
||||
|
||||
|
||||
class TrioListener(AsyncListener):
|
||||
def __init__(self, listener: trio.abc.Listener) -> None:
|
||||
self._listener = listener
|
||||
|
||||
def _unwrap_ssl(self) -> trio.SocketListener:
|
||||
listener = self._listener
|
||||
# Unwrap SSL.
|
||||
while isinstance(listener, trio.SSLListener):
|
||||
listener = listener.transport_listener
|
||||
assert isinstance(listener, trio.SocketListener)
|
||||
return listener
|
||||
|
||||
@property
|
||||
def socket(self) -> socket.SocketType:
|
||||
listener = self._unwrap_ssl()
|
||||
return listener.socket
|
||||
|
||||
|
||||
class TrioTaskHandle(TaskHandle):
|
||||
def __init__(self, cancel_scope: trio.CancelScope) -> None:
|
||||
self._cancel_scope = cancel_scope
|
||||
|
||||
async def cancel(self) -> None:
|
||||
self._cancel_scope.cancel()
|
||||
|
||||
|
||||
class TrioBackend(AsyncBackend):
|
||||
def create_event(self) -> Event:
|
||||
return TrioEvent()
|
||||
|
||||
def create_queue(self, size: int) -> Queue:
|
||||
return TrioQueue(size)
|
||||
|
||||
async def sleep(self, seconds: float) -> None:
|
||||
await trio.sleep(seconds)
|
||||
|
||||
def run(self, async_fn: Callable, *args: Any) -> None:
|
||||
trio.run(async_fn, *args)
|
||||
|
||||
async def move_on_after(
|
||||
self, seconds: float, async_fn: Callable, *args: Any
|
||||
) -> None:
|
||||
with trio.move_on_after(seconds):
|
||||
await async_fn(*args)
|
||||
|
||||
@asynccontextmanager
|
||||
async def start_soon(
|
||||
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
|
||||
) -> AsyncIterator[None]:
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(async_fn, *args)
|
||||
yield
|
||||
if cancel_on_exit:
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
@asynccontextmanager
|
||||
async def start(
|
||||
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
|
||||
) -> AsyncIterator[Any]:
|
||||
async with trio.open_nursery() as nursery:
|
||||
task_status = self.create_task_status()
|
||||
async_fn = functools.partial(async_fn, task_status=task_status)
|
||||
nursery.start_soon(async_fn, *args)
|
||||
value = await task_status.get_value()
|
||||
yield value
|
||||
if cancel_on_exit:
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
async def wait_then_call(
|
||||
self,
|
||||
seconds: float,
|
||||
async_fn: Callable,
|
||||
*args: Any,
|
||||
task_status: TaskStatus = TaskStatus.IGNORED,
|
||||
) -> None:
|
||||
cancel_scope = trio.CancelScope()
|
||||
await task_status.started(TrioTaskHandle(cancel_scope))
|
||||
with cancel_scope:
|
||||
await trio.sleep(seconds)
|
||||
cancel_scope.shield = True
|
||||
await async_fn()
|
||||
|
||||
async def serve_tcp(
|
||||
self,
|
||||
handler: Callable[[AsyncSocket, ServerState, Config], Awaitable[None]],
|
||||
state: ServerState,
|
||||
config: Config,
|
||||
*,
|
||||
sockets: List[socket.SocketType] = None,
|
||||
wait_close: Callable,
|
||||
on_close: Callable = None,
|
||||
task_status: TaskStatus = TaskStatus.IGNORED,
|
||||
) -> None:
|
||||
async def trio_handler(stream: trio.SocketStream) -> None:
|
||||
sock = TrioSocket(stream)
|
||||
await handler(sock, state, config)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
listeners: Sequence[trio.abc.Listener] = []
|
||||
|
||||
if sockets is not None:
|
||||
# Explicitly passed a list of open sockets.
|
||||
# (We need to wrap them around the trio equivalents.)
|
||||
trio_sockets = []
|
||||
for sock in sockets:
|
||||
sock.listen(config.backlog)
|
||||
trio_socket = trio.socket.fromfd(
|
||||
sock.fileno(), sock.family, sock.type
|
||||
)
|
||||
trio_sockets.append(trio_socket)
|
||||
listeners = [trio.SocketListener(sock) for sock in trio_sockets]
|
||||
|
||||
elif config.fd is not None:
|
||||
# Use an existing socket, from a file descriptor.
|
||||
sock = trio.socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
listeners = [trio.SocketListener(sock)]
|
||||
|
||||
elif config.uds is not None:
|
||||
# Create a socket using UNIX domain socket.
|
||||
# XXX: trio does not provide highlevel support for UDS servers yet.
|
||||
# See: https://github.com/python-trio/trio/issues/279
|
||||
# This may land soon-ish, though.
|
||||
# See: https://github.com/python-trio/trio/pull/1433
|
||||
# Use the API proposed in the RFC above as a draft.
|
||||
listeners = await trio.open_unix_listeners( # type: ignore
|
||||
config.uds, backlog=config.backlog
|
||||
)
|
||||
|
||||
else:
|
||||
# Standard case. Create a socket from a host/port pair.
|
||||
listeners = await trio.open_tcp_listeners(
|
||||
config.port, host=config.host, backlog=config.backlog
|
||||
)
|
||||
|
||||
if config.ssl:
|
||||
listeners = [
|
||||
trio.SSLListener(listener, config.ssl, https_compatible=False)
|
||||
for listener in listeners
|
||||
]
|
||||
|
||||
listeners = await nursery.start(
|
||||
trio.serve_listeners, trio_handler, listeners
|
||||
)
|
||||
value = [TrioListener(listener) for listener in listeners]
|
||||
await task_status.started(value)
|
||||
|
||||
await wait_close()
|
||||
|
||||
# Run any custom shutdown behavior.
|
||||
if on_close is not None:
|
||||
await on_close()
|
||||
|
||||
# Connections are properly closed, we can go ahead and hard-stop
|
||||
# the server.
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
async def listen_signals(
|
||||
self, *signals: signal.Signals, handler: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
with trio.open_signal_receiver(*signals) as signal_receiver:
|
||||
async for _ in signal_receiver:
|
||||
await handler()
|
||||
12
uvicorn/_async_agnostic/compat.py
Normal file
12
uvicorn/_async_agnostic/compat.py
Normal file
@ -0,0 +1,12 @@
|
||||
try:
|
||||
from contextlib import AsyncExitStack
|
||||
except ImportError: # pragma: no cover
|
||||
from async_exit_stack import AsyncExitStack # type: ignore
|
||||
|
||||
try:
|
||||
from contextlib import asynccontextmanager
|
||||
except ImportError: # pragma: no cover
|
||||
from async_generator import asynccontextmanager # type: ignore
|
||||
|
||||
|
||||
__all__ = ["AsyncExitStack", "asynccontextmanager"]
|
||||
10
uvicorn/_async_agnostic/exceptions.py
Normal file
10
uvicorn/_async_agnostic/exceptions.py
Normal file
@ -0,0 +1,10 @@
|
||||
class BrokenSocket(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ProtocolError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class LifespanFailure(Exception):
|
||||
pass
|
||||
0
uvicorn/_async_agnostic/http11/__init__.py
Normal file
0
uvicorn/_async_agnostic/http11/__init__.py
Normal file
230
uvicorn/_async_agnostic/http11/connection.py
Normal file
230
uvicorn/_async_agnostic/http11/connection.py
Normal file
@ -0,0 +1,230 @@
|
||||
import itertools
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, AsyncIterator, List, Optional, Tuple
|
||||
|
||||
from ..backends.auto import AutoBackend
|
||||
from ..backends.base import AsyncSocket
|
||||
from ..exceptions import BrokenSocket, ProtocolError
|
||||
from ..utils import STATUS_PHRASES, find_upgrade_header, to_internet_date
|
||||
from .parsers.base import Event, HTTP11Parser
|
||||
|
||||
TRACE_LOG_LEVEL = 5
|
||||
NEXT_ID = itertools.count()
|
||||
|
||||
|
||||
class HTTP11Connection:
|
||||
MAX_RECV = 2 ** 16
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sock: AsyncSocket,
|
||||
default_headers: List[Tuple[bytes, bytes]],
|
||||
parser: HTTP11Parser,
|
||||
) -> None:
|
||||
self._sock = sock
|
||||
self._default_headers = default_headers
|
||||
self._parser = parser
|
||||
|
||||
self._obj_id = next(NEXT_ID)
|
||||
self._logger = logging.getLogger("uvicorn.error")
|
||||
self._backend = AutoBackend()
|
||||
|
||||
def trace(self, msg: str, *args: Any) -> None:
|
||||
self._logger.log(TRACE_LOG_LEVEL, f"conn(%s): {msg}", self._obj_id, *args)
|
||||
|
||||
def debug(self, msg: str, *args: Any) -> None:
|
||||
self._logger.debug(f"conn(%s): {msg}", self._obj_id, *args)
|
||||
|
||||
@property
|
||||
def scheme(self) -> str:
|
||||
return "https" if self._sock.is_ssl else "http"
|
||||
|
||||
@property
|
||||
def server(self) -> Optional[Tuple[str, int]]:
|
||||
return self._sock.get_local_addr()
|
||||
|
||||
@property
|
||||
def client(self) -> Optional[Tuple[str, int]]:
|
||||
return self._sock.get_remote_addr()
|
||||
|
||||
def basic_headers(self) -> List[Tuple[bytes, bytes]]:
|
||||
return [
|
||||
(b"date", to_internet_date(time.time()).encode("utf-8")),
|
||||
] + self._default_headers
|
||||
|
||||
# State machine helpers
|
||||
|
||||
def states(self) -> dict:
|
||||
return self._parser.states()
|
||||
|
||||
async def _send_event(self, event: Event) -> None:
|
||||
if event["type"] == "Response":
|
||||
self.trace(
|
||||
"send_event event=Response("
|
||||
"status_code=%d, headers=<Headers(...)>, reason=%s)",
|
||||
event["status_code"],
|
||||
event["reason"],
|
||||
)
|
||||
elif event["type"] == "Data":
|
||||
self.trace("send_event event=Data(<%d bytes>)", len(event["data"]))
|
||||
elif event["type"] == "EndOfMessage":
|
||||
self.trace("send_event event=EndOfMessage(headers=<Headers(...)>")
|
||||
else:
|
||||
assert event["type"] == "ConnectionClosed"
|
||||
self.trace("send_event event=ConnectionClosed()")
|
||||
|
||||
data = self._parser.send(event)
|
||||
if data is None:
|
||||
assert event["type"] == "ConnectionClosed", event
|
||||
await self._sock.write(b"")
|
||||
await self.shutdown_and_clean_up()
|
||||
else:
|
||||
await self._sock.write(data)
|
||||
|
||||
async def _read_from_peer(self) -> None:
|
||||
if self._parser.they_are_waiting_for_100_continue:
|
||||
self.trace("Sending 100 Continue")
|
||||
await self._send_event({"type": "InformationalResponse"})
|
||||
|
||||
data = await self._sock.read(self.MAX_RECV)
|
||||
self.trace("read_data Data(<%d bytes>)", len(data))
|
||||
self._parser.receive_data(data)
|
||||
|
||||
async def _receive_event(self) -> Any:
|
||||
while True:
|
||||
try:
|
||||
event = self._parser.next_event()
|
||||
except ProtocolError as exc:
|
||||
raise ProtocolError(f"Invalid HTTP request received: {exc}")
|
||||
|
||||
if event["type"] == "NEED_DATA":
|
||||
await self._read_from_peer()
|
||||
continue
|
||||
|
||||
if event["type"] == "Request":
|
||||
self.trace(
|
||||
"receive_event event=Request("
|
||||
"http_version=%s, method=%s, target=%s, headers=...)",
|
||||
event["http_version"],
|
||||
event["method"],
|
||||
event["target"],
|
||||
)
|
||||
elif event["type"] == "Data":
|
||||
self.trace("receive_event event=Data(<%d bytes>)", len(event["data"]))
|
||||
elif event["type"] == "EndOfMessage":
|
||||
self.trace("receive_event event=EndOfMessage(headers=<Headers(...)>")
|
||||
else:
|
||||
assert event["type"] == "ConnectionClosed"
|
||||
self.trace("receive_event event=ConnectionClosed()")
|
||||
|
||||
return event
|
||||
|
||||
async def read_request(
|
||||
self,
|
||||
) -> Tuple[bytes, bytes, bytes, List[Tuple[bytes, bytes]], Optional[bytes]]:
|
||||
event = await self._receive_event()
|
||||
|
||||
if event["type"] == "ConnectionClosed":
|
||||
raise BrokenSocket("Client has disconnected")
|
||||
|
||||
assert event["type"] == "Request"
|
||||
|
||||
http_version: bytes = event["http_version"]
|
||||
method: bytes = event["method"]
|
||||
path: bytes = event["target"]
|
||||
headers = [(key.lower(), value) for key, value in event["headers"]]
|
||||
upgrade = find_upgrade_header(headers)
|
||||
|
||||
return (http_version, method, path, headers, upgrade)
|
||||
|
||||
async def aiter_request_body(self) -> AsyncIterator[bytes]:
|
||||
async def receive_data() -> bytes:
|
||||
event = await self._receive_event()
|
||||
if event["type"] == "EndOfMessage":
|
||||
return b""
|
||||
assert event["type"] == "Data"
|
||||
return event["data"]
|
||||
|
||||
async def request_body(data: bytes) -> AsyncIterator[bytes]:
|
||||
while data:
|
||||
yield data
|
||||
data = await receive_data()
|
||||
|
||||
# Read at least one event so that we get a chance of seeing `EndOfMessage`
|
||||
# right away in case the client does not send a body (eg HEAD or GET requests).
|
||||
initial = await receive_data()
|
||||
|
||||
return request_body(initial)
|
||||
|
||||
async def send_response(
|
||||
self, status_code: int, headers: List[Tuple[bytes, bytes]], reason: bytes = b""
|
||||
) -> None:
|
||||
if not reason:
|
||||
reason = STATUS_PHRASES[status_code]
|
||||
event = {
|
||||
"type": "Response",
|
||||
"status_code": status_code,
|
||||
"headers": headers,
|
||||
"reason": reason,
|
||||
}
|
||||
await self._send_event(event)
|
||||
|
||||
async def send_simple_response(
|
||||
self, status_code: int, content_type: str, body: bytes
|
||||
) -> None:
|
||||
self.trace("send_simple_response %d (%d bytes)", status_code, len(body))
|
||||
headers = self.basic_headers() + [
|
||||
(b"Content-Type", content_type.encode("utf-8")),
|
||||
(b"Content-Length", str(len(body)).encode("utf-8")),
|
||||
]
|
||||
await self.send_response(status_code=status_code, headers=headers)
|
||||
await self._send_event({"type": "Data", "data": body})
|
||||
await self._send_event({"type": "EndOfMessage"})
|
||||
|
||||
async def send_response_body(self, chunk: bytes) -> None:
|
||||
if chunk:
|
||||
event = {"type": "Data", "data": chunk}
|
||||
else:
|
||||
event = {"type": "EndOfMessage"}
|
||||
await self._send_event(event)
|
||||
|
||||
def set_keepalive(self) -> None:
|
||||
try:
|
||||
self._parser.start_next_cycle()
|
||||
except ProtocolError:
|
||||
raise
|
||||
|
||||
async def trigger_shutdown(self) -> None:
|
||||
self.trace("triggering shutdown")
|
||||
states = self._parser.states()
|
||||
if states["server"] in {"IDLE", "DONE"}:
|
||||
await self._send_event({"type": "ConnectionClosed"})
|
||||
|
||||
async def shutdown_and_clean_up(self) -> None:
|
||||
self.trace("shutting down")
|
||||
|
||||
try:
|
||||
await self._sock.send_eof()
|
||||
except BrokenSocket:
|
||||
self.trace("failed to send EOF: client is already gone")
|
||||
return
|
||||
|
||||
self.trace("EOF sent")
|
||||
|
||||
# Wait and read for a bit to give them a chance to see that we closed
|
||||
# things, but eventually give up and just close the socket.
|
||||
async def attempt_read_until_eof() -> None:
|
||||
try:
|
||||
while True:
|
||||
data = await self._sock.read(self.MAX_RECV)
|
||||
if not data:
|
||||
self.trace("EOF acknowledged by peer")
|
||||
break
|
||||
except Exception:
|
||||
pass # It broke.
|
||||
|
||||
try:
|
||||
await self._backend.move_on_after(5, attempt_read_until_eof)
|
||||
finally:
|
||||
await self._sock.aclose()
|
||||
155
uvicorn/_async_agnostic/http11/handler.py
Normal file
155
uvicorn/_async_agnostic/http11/handler.py
Normal file
@ -0,0 +1,155 @@
|
||||
import logging
|
||||
from typing import AsyncIterator, List, Tuple
|
||||
from urllib.parse import unquote
|
||||
|
||||
from uvicorn.config import Config
|
||||
|
||||
from ..asgi import ASGIRequestResponseCycle
|
||||
from ..backends.base import AsyncSocket
|
||||
from ..exceptions import BrokenSocket, ProtocolError
|
||||
from ..state import ServerState
|
||||
from .connection import HTTP11Connection
|
||||
from .keepalive import KeepAlive
|
||||
from .parsers.h11 import H11Parser
|
||||
|
||||
try:
|
||||
import httptools
|
||||
|
||||
from .parsers.httptools import HttpToolsParser
|
||||
except ImportError: # pragma: no cover
|
||||
httptools = None # type: ignore
|
||||
HttpToolsParser = None # type: ignore
|
||||
|
||||
|
||||
def create_http11_connection(
|
||||
sock: AsyncSocket, state: ServerState, config: Config
|
||||
) -> HTTP11Connection:
|
||||
use_httptools = config.http == "httptools" or (
|
||||
config.http == "auto" and httptools is not None
|
||||
)
|
||||
parser = HttpToolsParser() if use_httptools else H11Parser()
|
||||
return HTTP11Connection(sock, default_headers=state.default_headers, parser=parser)
|
||||
|
||||
|
||||
async def handle_http11(sock: AsyncSocket, state: ServerState, config: Config) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
conn = create_http11_connection(sock, state, config)
|
||||
keepalive = KeepAlive(conn, config)
|
||||
state.connections.add(conn)
|
||||
conn.debug("Connection made")
|
||||
|
||||
while True:
|
||||
assert conn.states() == {"client": "IDLE", "server": "IDLE"}
|
||||
try:
|
||||
(
|
||||
http_version,
|
||||
method,
|
||||
path,
|
||||
headers,
|
||||
upgrade,
|
||||
) = await conn.read_request()
|
||||
assert http_version == b"1.1", http_version
|
||||
assert upgrade is None, "WebSocket not supported yet"
|
||||
request_body = await conn.aiter_request_body()
|
||||
await send_h11_response(
|
||||
conn,
|
||||
config=config,
|
||||
method=method,
|
||||
path=path,
|
||||
headers=headers,
|
||||
request_body=request_body,
|
||||
)
|
||||
except BrokenSocket:
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.error("Error while responding to request: %s", exc, exc_info=exc)
|
||||
await maybe_send_error_response(conn)
|
||||
else:
|
||||
state.total_requests += 1
|
||||
|
||||
states = conn.states()
|
||||
if states["server"] == "MUST_CLOSE":
|
||||
conn.trace("Connection is not reusable, shutting down")
|
||||
break
|
||||
|
||||
conn.trace("Trying to reuse connection")
|
||||
try:
|
||||
conn.set_keepalive()
|
||||
except ProtocolError as exc:
|
||||
conn.trace("Connection is not reusable, bailing out: %s ", exc)
|
||||
await maybe_send_error_response(conn)
|
||||
break
|
||||
else:
|
||||
await keepalive.reset()
|
||||
await keepalive.schedule()
|
||||
conn.debug("Connection kept alive")
|
||||
|
||||
await keepalive.aclose()
|
||||
await conn.shutdown_and_clean_up()
|
||||
state.connections.discard(conn)
|
||||
conn.debug("Connection closed")
|
||||
|
||||
|
||||
async def send_h11_response(
|
||||
conn: HTTP11Connection,
|
||||
*,
|
||||
config: Config,
|
||||
method: bytes,
|
||||
path: bytes,
|
||||
headers: List[Tuple[bytes, bytes]],
|
||||
request_body: AsyncIterator[bytes],
|
||||
) -> None:
|
||||
conn.trace("Sending response")
|
||||
|
||||
raw_path, _, query_string = path.partition(b"?")
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"asgi": {
|
||||
"version": "3.0",
|
||||
"spec_version": "2.1",
|
||||
},
|
||||
"http_version": "1.1",
|
||||
"server": conn.server,
|
||||
"client": conn.client,
|
||||
"scheme": conn.scheme,
|
||||
"method": method.decode("ascii"),
|
||||
"root_path": config.root_path,
|
||||
"path": unquote(raw_path.decode("ascii")),
|
||||
"raw_path": raw_path,
|
||||
"query_string": query_string,
|
||||
"headers": headers,
|
||||
}
|
||||
|
||||
send_response_body = conn.send_response_body
|
||||
|
||||
cycle = ASGIRequestResponseCycle(
|
||||
conn,
|
||||
scope=scope,
|
||||
request_body=request_body,
|
||||
send_response_body=send_response_body,
|
||||
access_log=config.access_log,
|
||||
)
|
||||
|
||||
app = config.loaded_app
|
||||
|
||||
await cycle.run_asgi(app)
|
||||
|
||||
|
||||
async def maybe_send_error_response(conn: HTTP11Connection) -> None:
|
||||
states = conn.states()
|
||||
if states["server"] not in {"IDLE", "ACTIVE"}:
|
||||
return
|
||||
|
||||
conn.trace("send error response")
|
||||
status_code = 500
|
||||
content_type = "text/plain; charset=utf-8"
|
||||
body = b"Internal Server Error"
|
||||
try:
|
||||
await conn.send_simple_response(status_code, content_type, body)
|
||||
except Exception as exc:
|
||||
conn.trace("error while sending error response: %s", exc)
|
||||
52
uvicorn/_async_agnostic/http11/keepalive.py
Normal file
52
uvicorn/_async_agnostic/http11/keepalive.py
Normal file
@ -0,0 +1,52 @@
|
||||
from typing import Optional
|
||||
|
||||
from uvicorn.config import Config
|
||||
|
||||
from ..backends.auto import AutoBackend
|
||||
from ..backends.base import TaskHandle, TaskStatus
|
||||
from ..compat import AsyncExitStack
|
||||
from .connection import HTTP11Connection
|
||||
|
||||
|
||||
class KeepAlive:
|
||||
def __init__(self, conn: HTTP11Connection, config: Config) -> None:
|
||||
self._conn = conn
|
||||
self._config = config
|
||||
self._backend = AutoBackend()
|
||||
self._task_handle: Optional[TaskHandle] = None
|
||||
self._exit_stack = AsyncExitStack()
|
||||
|
||||
async def schedule(self) -> None:
|
||||
assert self._task_handle is None
|
||||
|
||||
self._task_handle = await self._exit_stack.enter_async_context(
|
||||
self._backend.start(
|
||||
self._trigger_shutdown_after_expiry,
|
||||
cancel_on_exit=True,
|
||||
),
|
||||
)
|
||||
|
||||
async def _trigger_shutdown_after_expiry(
|
||||
self, *, task_status: TaskStatus = TaskStatus.IGNORED
|
||||
) -> None:
|
||||
timeout = self._config.timeout_keep_alive
|
||||
self._conn.trace("keep-alive expiry scheduled in %d seconds", timeout)
|
||||
await self._backend.wait_then_call(
|
||||
timeout,
|
||||
async_fn=self._trigger_shutdown,
|
||||
task_status=task_status,
|
||||
)
|
||||
|
||||
async def _trigger_shutdown(self) -> None:
|
||||
self._conn.trace("keep-alive expired")
|
||||
await self._conn.trigger_shutdown()
|
||||
|
||||
async def reset(self) -> None:
|
||||
if self._task_handle is not None:
|
||||
self._conn.trace("keep-alive reset")
|
||||
await self._task_handle.cancel()
|
||||
self._task_handle = None
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._exit_stack.aclose()
|
||||
self._conn.trace("keep-alive expiry cancelled")
|
||||
0
uvicorn/_async_agnostic/http11/parsers/__init__.py
Normal file
0
uvicorn/_async_agnostic/http11/parsers/__init__.py
Normal file
73
uvicorn/_async_agnostic/http11/parsers/base.py
Normal file
73
uvicorn/_async_agnostic/http11/parsers/base.py
Normal file
@ -0,0 +1,73 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
# We use the same state machine approach than `h11` because it's very convenient
|
||||
# to program against.
|
||||
#
|
||||
# NOTE: Everything below is in the `h11` state machine docs:
|
||||
# https://h11.readthedocs.io/en/latest/api.html#the-state-machine
|
||||
# We describe the expected state machine to explicitly state the expected contract.
|
||||
#
|
||||
# Possible states
|
||||
# ---------------
|
||||
# IDLE
|
||||
# SEND_RESPONSE
|
||||
# SEND_BODY
|
||||
# DONE
|
||||
# MUST_CLOSE
|
||||
# CLOSED
|
||||
# ERROR
|
||||
#
|
||||
# Happy path
|
||||
# ----------
|
||||
# Client:
|
||||
# IDLE --(Request)-> SEND_BODY [--(Data)-> SEND_BODY] --(EndOfMessage)-> DONE
|
||||
# Server:
|
||||
# IDLE --(Request)-> SEND_RESPONSE [--(InformationalResponse)-> SEND_RESPONSE] \
|
||||
# --(Response)-> SEND_BODY [--(Data)-> SEND_BODY] --(EndOfMessage)-> DONE
|
||||
#
|
||||
# Keep-Alive
|
||||
# ----------
|
||||
# ENABLED --(HTTP/1.0 | Connection: close)-> DISABLED
|
||||
# Client + Server: DONE --(KeepAlive ENABLED)-> IDLE
|
||||
#
|
||||
# Closing paths
|
||||
# -------------
|
||||
# IDLE --(ConnectionClosed)-> CLOSED
|
||||
# IDLE --(Peer CLOSED)-> MUST_CLOSE
|
||||
# DONE --(KeepAlive DISABLED | Peer CLOSED | Peer ERROR)-> MUST_CLOSE
|
||||
# MUST_CLOSE --(ConnectionClosed)-> CLOSED
|
||||
# DONE --(ConnectionClosed)-> CLOSED
|
||||
# CLOSED --(ConnectionClosed)-> CLOSED
|
||||
#
|
||||
# Error paths
|
||||
# -----------
|
||||
# * --> ERROR
|
||||
# NOTE: this path is managed by `h11` internally. For other libraries, the ERROR state
|
||||
# should be entered anytime the parser raises a parsing exception.
|
||||
|
||||
Event = Dict[str, Any]
|
||||
|
||||
|
||||
class HTTP11Parser:
|
||||
"""
|
||||
An event-based HTTP/1.1 parser interface, inspired by `h11`.
|
||||
"""
|
||||
|
||||
def states(self) -> dict:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
@property
|
||||
def they_are_waiting_for_100_continue(self) -> bool:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def receive_data(self, data: bytes) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def next_event(self) -> Event:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def send(self, event: Event) -> Optional[bytes]:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def start_next_cycle(self) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
98
uvicorn/_async_agnostic/http11/parsers/h11.py
Normal file
98
uvicorn/_async_agnostic/http11/parsers/h11.py
Normal file
@ -0,0 +1,98 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import h11
|
||||
|
||||
from ...exceptions import ProtocolError
|
||||
from .base import Event, HTTP11Parser
|
||||
|
||||
H11_STATES_MAP = {
|
||||
h11.IDLE: "IDLE",
|
||||
h11.SEND_RESPONSE: "SEND_RESPONSE",
|
||||
h11.SEND_BODY: "SEND_BODY",
|
||||
h11.DONE: "DONE",
|
||||
h11.MUST_CLOSE: "MUST_CLOSE",
|
||||
h11.CLOSED: "CLOSED",
|
||||
h11.ERROR: "ERROR",
|
||||
}
|
||||
|
||||
|
||||
class H11Parser(HTTP11Parser):
|
||||
"""
|
||||
An HTTP/1.1 parser backed by the `h11` library.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._h11_state = h11.Connection(h11.SERVER)
|
||||
|
||||
def states(self) -> dict:
|
||||
return {
|
||||
"client": H11_STATES_MAP[self._h11_state.their_state],
|
||||
"server": H11_STATES_MAP[self._h11_state.our_state],
|
||||
}
|
||||
|
||||
@property
|
||||
def they_are_waiting_for_100_continue(self) -> bool:
|
||||
return self._h11_state.they_are_waiting_for_100_continue
|
||||
|
||||
def receive_data(self, data: bytes) -> None:
|
||||
self._h11_state.receive_data(data)
|
||||
|
||||
def next_event(self) -> Event:
|
||||
event = self._h11_state.next_event()
|
||||
return from_h11_event(event)
|
||||
|
||||
def send(self, event: Event) -> Optional[bytes]:
|
||||
h11_event = to_h11_event(event)
|
||||
return self._h11_state.send(h11_event)
|
||||
|
||||
def start_next_cycle(self) -> None:
|
||||
try:
|
||||
self._h11_state.start_next_cycle()
|
||||
except h11.ProtocolError as exc:
|
||||
raise ProtocolError(exc)
|
||||
|
||||
|
||||
def from_h11_event(event: Any) -> Event:
|
||||
if event is h11.NEED_DATA:
|
||||
return {"type": "NEED_DATA"}
|
||||
|
||||
if isinstance(event, h11.Request):
|
||||
return {
|
||||
"type": "Request",
|
||||
"http_version": event.http_version,
|
||||
"method": event.method,
|
||||
"target": event.target,
|
||||
"headers": event.headers,
|
||||
}
|
||||
|
||||
if isinstance(event, h11.ConnectionClosed):
|
||||
return {"type": "ConnectionClosed"}
|
||||
|
||||
if isinstance(event, h11.Data):
|
||||
return {"type": "Data", "data": event.data}
|
||||
|
||||
if isinstance(event, h11.EndOfMessage):
|
||||
return {"type": "EndOfMessage"}
|
||||
|
||||
raise RuntimeError(f"Unknown event type: {type(event)}")
|
||||
|
||||
|
||||
def to_h11_event(event: Event) -> Any:
|
||||
if event["type"] == "InformationalResponse":
|
||||
return h11.InformationalResponse(status_code=100)
|
||||
|
||||
if event["type"] == "Response":
|
||||
return h11.Response(
|
||||
status_code=event["status_code"],
|
||||
headers=event["headers"],
|
||||
reason=event["reason"],
|
||||
)
|
||||
|
||||
if event["type"] == "Data":
|
||||
return h11.Data(data=event["data"])
|
||||
|
||||
if event["type"] == "EndOfMessage":
|
||||
return h11.EndOfMessage()
|
||||
|
||||
assert event["type"] == "ConnectionClosed"
|
||||
return h11.ConnectionClosed()
|
||||
345
uvicorn/_async_agnostic/http11/parsers/httptools.py
Normal file
345
uvicorn/_async_agnostic/http11/parsers/httptools.py
Normal file
@ -0,0 +1,345 @@
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httptools
|
||||
|
||||
from ...exceptions import ProtocolError
|
||||
from .base import Event, HTTP11Parser
|
||||
|
||||
HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]')
|
||||
HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]")
|
||||
|
||||
|
||||
class HttpToolsParser(HTTP11Parser):
|
||||
"""
|
||||
An HTTP/1.1 parser backed by the `httptools` library.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._parser = httptools.HttpRequestParser(self)
|
||||
self._state = State()
|
||||
# Parser-specific state.
|
||||
self._parsed_url: Optional[Any] = None
|
||||
self._headers: List[Tuple[bytes, bytes]] = []
|
||||
self._chunked_encoding: Optional[bool] = None
|
||||
self._expected_content_length: Optional[int] = None
|
||||
|
||||
# Parser API.
|
||||
|
||||
def states(self) -> dict:
|
||||
return self._state.states()
|
||||
|
||||
@property
|
||||
def they_are_waiting_for_100_continue(self) -> bool:
|
||||
return self._state.client_waiting_for_100_continue
|
||||
|
||||
def receive_data(self, data: bytes) -> None:
|
||||
try:
|
||||
self._parser.feed_data(data)
|
||||
except (
|
||||
httptools.HttpParserInvalidMethodError,
|
||||
httptools.HttpParserInvalidURLError,
|
||||
httptools.HttpParserError,
|
||||
) as exc:
|
||||
self._state.process_error("client")
|
||||
raise ProtocolError(exc)
|
||||
except (
|
||||
httptools.HttpParserInvalidStatusError,
|
||||
httptools.HttpParserCallbackError,
|
||||
) as exc:
|
||||
self._state.process_error("server")
|
||||
raise ProtocolError(exc)
|
||||
|
||||
if not data:
|
||||
self._state.process_client_event({"type": "ConnectionClosed"})
|
||||
|
||||
def next_event(self) -> Event:
|
||||
return self._state.next_event()
|
||||
|
||||
def send(self, event: Event) -> Optional[bytes]:
|
||||
self._state.process_server_event(event)
|
||||
|
||||
if event["type"] == "InformationalResponse":
|
||||
return self._render_informational_response()
|
||||
|
||||
if event["type"] == "Response":
|
||||
if self._parser.get_method() == b"HEAD":
|
||||
self._expected_content_length = 0
|
||||
status_code = event["status_code"]
|
||||
headers = event["headers"]
|
||||
reason = event["reason"]
|
||||
return self._render_response(status_code, headers, reason)
|
||||
|
||||
if event["type"] == "Data":
|
||||
body = event["data"]
|
||||
return self._render_response_body(body)
|
||||
|
||||
if event["type"] == "EndOfMessage":
|
||||
num_bytes_remaining = self._expected_content_length or 0
|
||||
if num_bytes_remaining != 0:
|
||||
raise ProtocolError(
|
||||
"Too little data for declared Content-Length: "
|
||||
f"{num_bytes_remaining} remaining"
|
||||
)
|
||||
return b""
|
||||
|
||||
assert event["type"] == "ConnectionClosed"
|
||||
return None
|
||||
|
||||
def start_next_cycle(self) -> None:
|
||||
try:
|
||||
self._state.start_next_cycle()
|
||||
except ProtocolError:
|
||||
raise
|
||||
|
||||
# Reset.
|
||||
self._parsed_url = None
|
||||
self._headers.clear()
|
||||
self._chunked_encoding = None
|
||||
self._expected_content_length = None
|
||||
|
||||
# Response rendering helpers.
|
||||
|
||||
def _render_informational_response(self) -> bytes:
|
||||
return b"HTTP/1.1 100 Continue\r\n\r\n"
|
||||
|
||||
def _render_response(
|
||||
self, status_code: int, headers: List[Tuple[bytes, bytes]], reason: bytes
|
||||
) -> bytes:
|
||||
status_line = b"".join(
|
||||
[b"HTTP/1.1", b" ", str(status_code).encode("utf-8"), b" ", reason, b"\r\n"]
|
||||
)
|
||||
|
||||
content = [status_line]
|
||||
|
||||
for name, value in headers:
|
||||
if HEADER_RE.search(name):
|
||||
raise RuntimeError("Invalid HTTP header name")
|
||||
if HEADER_VALUE_RE.search(value):
|
||||
raise RuntimeError("Invalid HTTP header value")
|
||||
|
||||
name = name.lower()
|
||||
if name == b"content-length" and self._chunked_encoding is None:
|
||||
self._expected_content_length = int(value.decode("ascii"))
|
||||
self._chunked_encoding = False
|
||||
elif name == b"transfer-encoding" and value.lower() == b"chunked":
|
||||
self._expected_content_length = 0
|
||||
self._chunked_encoding = True
|
||||
elif name == b"connection" and value.lower() == b"close":
|
||||
self._state.process_keep_alive_disabled()
|
||||
content.extend([name, b": ", value, b"\r\n"])
|
||||
|
||||
if (
|
||||
self._chunked_encoding is None
|
||||
and self._parser.get_method() != b"HEAD"
|
||||
and status_code not in (204, 304)
|
||||
):
|
||||
# Neither content-length nor transfer-encoding specified
|
||||
self._chunked_encoding = True
|
||||
content.append(b"transfer-encoding: chunked\r\n")
|
||||
|
||||
content.append(b"\r\n")
|
||||
|
||||
if self._chunked_encoding:
|
||||
content.append(b"0\r\n\r\n")
|
||||
|
||||
return b"".join(content)
|
||||
|
||||
def _render_response_body(self, body: bytes) -> bytes:
|
||||
if self._chunked_encoding:
|
||||
content = [b"%x\r\n" % len(body), body, b"\r\n"] if body else []
|
||||
content.append(b"0\r\n\r\n")
|
||||
return b"".join(content)
|
||||
|
||||
assert self._expected_content_length is not None
|
||||
if len(body) > self._expected_content_length:
|
||||
raise RuntimeError("Response content longer than Content-Length")
|
||||
self._expected_content_length -= len(body)
|
||||
|
||||
return body
|
||||
|
||||
# HttpTools callbacks.
|
||||
|
||||
def on_message_begin(self) -> None:
|
||||
if self._parser.get_http_version() == "1.0":
|
||||
self._state.process_keep_alive_disabled()
|
||||
|
||||
def on_url(self, url: str) -> None:
|
||||
assert self._parsed_url is None
|
||||
self._parsed_url = httptools.parse_url(url)
|
||||
|
||||
def on_header(self, name: bytes, value: bytes) -> None:
|
||||
name = name.lower()
|
||||
if name == b"expect" and value.lower() == b"100-continue":
|
||||
self._state.process_expect_100_continue()
|
||||
if name == b"connection" and value.lower() == b"close":
|
||||
self._state.process_keep_alive_disabled()
|
||||
self._headers.append((name, value))
|
||||
|
||||
def on_headers_complete(self) -> None:
|
||||
assert self._parsed_url is not None
|
||||
|
||||
target = self._parsed_url.path
|
||||
if self._parsed_url.query:
|
||||
target += b"?%s" % self._parsed_url.query
|
||||
|
||||
event = {
|
||||
"type": "Request",
|
||||
"http_version": self._parser.get_http_version().encode("ascii"),
|
||||
"method": self._parser.get_method(),
|
||||
"target": target,
|
||||
"headers": self._headers,
|
||||
}
|
||||
|
||||
self._state.process_client_event(event)
|
||||
|
||||
def on_body(self, body: bytes) -> None:
|
||||
event = {"type": "Data", "data": body}
|
||||
self._state.process_client_event(event)
|
||||
|
||||
def on_message_complete(self) -> None:
|
||||
event = {"type": "EndOfMessage"}
|
||||
self._state.process_client_event(event)
|
||||
|
||||
|
||||
# Code below is adapted from h11's state machine code.
|
||||
|
||||
EVENT_TRIGGERED_TRANSITIONS: Dict[
|
||||
str, Dict[str, Dict[Union[str, Tuple[str, str]], str]]
|
||||
] = {
|
||||
# Role -> State -> Event -> New state
|
||||
"client": {
|
||||
"IDLE": {"Request": "SEND_BODY", "ConnectionClosed": "CLOSED"},
|
||||
"SEND_BODY": {"Data": "SEND_BODY", "EndOfMessage": "DONE"},
|
||||
"DONE": {"ConnectionClosed": "CLOSED"},
|
||||
"MUST_CLOSE": {"ConnectionClosed": "CLOSED"},
|
||||
"CLOSED": {"ConnectionClosed": "CLOSED"},
|
||||
"ERROR": {},
|
||||
},
|
||||
"server": {
|
||||
"IDLE": {
|
||||
"ConnectionClosed": "CLOSED",
|
||||
"Response": "SEND_BODY",
|
||||
# Special case: server sees client Request events.
|
||||
("Request", "client"): "SEND_RESPONSE",
|
||||
},
|
||||
"SEND_RESPONSE": {
|
||||
"InformationalResponse": "SEND_RESPONSE",
|
||||
"Response": "SEND_BODY",
|
||||
},
|
||||
"SEND_BODY": {"Data": "SEND_BODY", "EndOfMessage": "DONE"},
|
||||
"DONE": {"ConnectionClosed": "CLOSED"},
|
||||
"MUST_CLOSE": {"ConnectionClosed": "CLOSED"},
|
||||
"CLOSED": {"ConnectionClosed": "CLOSED"},
|
||||
"ERROR": {},
|
||||
},
|
||||
}
|
||||
|
||||
STATE_TRIGGERED_TRANSITIONS: Dict[Tuple[str, str], Dict[str, str]] = {
|
||||
# (Client state, Server state) -> New states
|
||||
# Socket shutdown
|
||||
("CLOSED", "DONE"): {"server": "MUST_CLOSE"},
|
||||
("CLOSED", "IDLE"): {"server": "MUST_CLOSE"},
|
||||
("ERROR", "DONE"): {"server": "MUST_CLOSE"},
|
||||
("DONE", "CLOSED"): {"client": "MUST_CLOSE"},
|
||||
("IDLE", "CLOSED"): {"client": "MUST_CLOSE"},
|
||||
("DONE", "ERROR"): {"clientt": "MUST_CLOSE"},
|
||||
}
|
||||
|
||||
|
||||
class State:
|
||||
def __init__(self) -> None:
|
||||
self._states = {"client": "IDLE", "server": "IDLE"}
|
||||
self._client_events: List[Event] = []
|
||||
self._keep_alive_enabled = True
|
||||
self._expect_100_continue = False
|
||||
|
||||
@property
|
||||
def client_waiting_for_100_continue(self) -> bool:
|
||||
return self._expect_100_continue
|
||||
|
||||
def states(self) -> dict:
|
||||
return dict(self._states)
|
||||
|
||||
def _process_event(self, role: str, event: Event) -> None:
|
||||
self._fire_event_triggered_transitions(role, event["type"])
|
||||
if event["type"] == "Request":
|
||||
# Special case: the server state does get to see Request events.
|
||||
self._fire_event_triggered_transitions("server", ("Request", "client"))
|
||||
self._fire_state_triggered_transitions()
|
||||
|
||||
def process_client_event(self, event: Event) -> None:
|
||||
self._process_event("client", event)
|
||||
self._client_events.append(event)
|
||||
|
||||
def process_server_event(self, event: Event) -> None:
|
||||
self._process_event("server", event)
|
||||
|
||||
def process_keep_alive_disabled(self) -> None:
|
||||
self._keep_alive_enabled = False
|
||||
|
||||
def process_expect_100_continue(self) -> None:
|
||||
assert not self._expect_100_continue
|
||||
self._expect_100_continue = True
|
||||
|
||||
def process_error(self, role: str) -> None:
|
||||
self._states[role] = "ERROR"
|
||||
self._fire_state_triggered_transitions() # Peer may have to close.
|
||||
|
||||
def _fire_event_triggered_transitions(
|
||||
self, role: str, event_type: Union[str, Tuple[str, str]]
|
||||
) -> None:
|
||||
state = self._states[role]
|
||||
|
||||
try:
|
||||
new_state = EVENT_TRIGGERED_TRANSITIONS[role][state][event_type]
|
||||
except KeyError:
|
||||
raise ProtocolError(
|
||||
f"can't handle event type {event_type} when "
|
||||
f"role={role.upper()} and state={state}"
|
||||
)
|
||||
|
||||
self._states[role] = new_state
|
||||
|
||||
if role == "server" and event_type in {"Response", "InformationalResponse"}:
|
||||
self._expect_100_continue = False
|
||||
|
||||
def _fire_state_triggered_transitions(self) -> None:
|
||||
# Apply transitions until we converge to a fixed point.
|
||||
while True:
|
||||
states = dict(self._states)
|
||||
|
||||
if not self._keep_alive_enabled:
|
||||
for role in ("client", "server"):
|
||||
if self._states[role] == "DONE":
|
||||
self._states[role] = "MUST_CLOSE"
|
||||
|
||||
joint_state = (self._states["client"], self._states["server"])
|
||||
changes = STATE_TRIGGERED_TRANSITIONS.get(joint_state, {})
|
||||
self._states.update(changes)
|
||||
|
||||
if self._states == states:
|
||||
break
|
||||
|
||||
def next_event(self) -> Event:
|
||||
if self._states["client"] == "ERROR":
|
||||
raise ProtocolError("Can't receive data when peer state is ERROR")
|
||||
|
||||
if self._states["client"] == "CLOSED":
|
||||
return {"type": "ConnectionClosed"}
|
||||
|
||||
try:
|
||||
return self._client_events.pop(0)
|
||||
except IndexError:
|
||||
return {"type": "NEED_DATA"}
|
||||
|
||||
def start_next_cycle(self) -> None:
|
||||
if self._states != {"client": "DONE", "server": "DONE"}:
|
||||
raise ProtocolError(f"Not in a reusable state: {self._states}")
|
||||
|
||||
assert self._keep_alive_enabled
|
||||
assert not self._expect_100_continue
|
||||
# Reset.
|
||||
self._states = {"client": "IDLE", "server": "IDLE"}
|
||||
self._keep_alive_enabled = True
|
||||
self._client_events.clear()
|
||||
103
uvicorn/_async_agnostic/lifespan.py
Normal file
103
uvicorn/_async_agnostic/lifespan.py
Normal file
@ -0,0 +1,103 @@
|
||||
import logging
|
||||
|
||||
from ..config import Config
|
||||
from .backends.auto import AutoBackend
|
||||
from .exceptions import LifespanFailure
|
||||
|
||||
STATE_TRANSITION_ERROR = "Got invalid state transition on lifespan protocol."
|
||||
TRACE_LOG_LEVEL = 5
|
||||
|
||||
|
||||
class Lifespan:
|
||||
def __init__(self, config: Config) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
self._config = config
|
||||
self._logger = logging.getLogger("uvicorn.error")
|
||||
self._backend = AutoBackend()
|
||||
self._startup_event = self._backend.create_event()
|
||||
self._shutdown_event = self._backend.create_event()
|
||||
self._receive_queue = self._backend.create_queue(0)
|
||||
self._mode = config.lifespan
|
||||
self._supported = self._mode in ("on", "auto")
|
||||
|
||||
async def startup(self) -> None:
|
||||
if not self._supported:
|
||||
self._logger.log(TRACE_LOG_LEVEL, "lifespan startup skipped")
|
||||
return
|
||||
|
||||
self._logger.info("Waiting for application startup.")
|
||||
await self._receive_queue.put({"type": "lifespan.startup"})
|
||||
await self._startup_event.wait()
|
||||
self._logger.info("Application startup complete.")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._supported:
|
||||
self._logger.log(TRACE_LOG_LEVEL, "lifespan shutdown skipped")
|
||||
return
|
||||
|
||||
self._logger.info("Waiting for application shutdown.")
|
||||
await self._receive_queue.put({"type": "lifespan.shutdown"})
|
||||
await self._shutdown_event.wait()
|
||||
self._logger.info("Application shutdown complete.")
|
||||
|
||||
async def main(self) -> None:
|
||||
if not self._supported:
|
||||
self._logger.log(TRACE_LOG_LEVEL, "lifespan main skipped")
|
||||
return
|
||||
|
||||
scope = {
|
||||
"type": "lifespan",
|
||||
"asgi": {"version": self._config.asgi_version, "spec_version": "2.0"},
|
||||
}
|
||||
|
||||
app = self._config.loaded_app
|
||||
|
||||
try:
|
||||
await app(scope, self._asgi_receive, self._asgi_send)
|
||||
except LifespanFailure:
|
||||
self._logger.error("Exception in 'lifespan' protocol")
|
||||
raise # Lifespan failures should stop the server.
|
||||
except Exception:
|
||||
self._logger.info("ASGI 'lifespan' protocol appears unsupported.")
|
||||
self._supported = False
|
||||
if self._mode == "on":
|
||||
raise
|
||||
finally:
|
||||
await self._startup_event.set()
|
||||
await self._shutdown_event.set()
|
||||
await self._receive_queue.aclose()
|
||||
|
||||
async def _asgi_receive(self) -> dict:
|
||||
return await self._receive_queue.get()
|
||||
|
||||
async def _asgi_send(self, message: dict) -> None:
|
||||
assert message["type"] in (
|
||||
"lifespan.startup.complete",
|
||||
"lifespan.startup.failed",
|
||||
"lifespan.shutdown.complete",
|
||||
"lifespan.shutdown.failed",
|
||||
), message["type"]
|
||||
|
||||
if message["type"] == "lifespan.startup.complete":
|
||||
assert not self._startup_event.is_set(), STATE_TRANSITION_ERROR
|
||||
assert not self._shutdown_event.is_set(), STATE_TRANSITION_ERROR
|
||||
await self._startup_event.set()
|
||||
|
||||
elif message["type"] == "lifespan.startup.failed":
|
||||
assert not self._startup_event.is_set(), STATE_TRANSITION_ERROR
|
||||
assert not self._shutdown_event.is_set(), STATE_TRANSITION_ERROR
|
||||
await self._startup_event.set()
|
||||
raise LifespanFailure(message.get("message", ""))
|
||||
|
||||
elif message["type"] == "lifespan.shutdown.complete":
|
||||
assert self._startup_event.is_set(), STATE_TRANSITION_ERROR
|
||||
assert not self._shutdown_event.is_set(), STATE_TRANSITION_ERROR
|
||||
await self._shutdown_event.set()
|
||||
|
||||
else:
|
||||
assert message["type"] == "lifespan.shutdown.failed"
|
||||
assert self._startup_event.is_set(), STATE_TRANSITION_ERROR
|
||||
assert not self._shutdown_event.is_set(), STATE_TRANSITION_ERROR
|
||||
await self._shutdown_event.set()
|
||||
raise LifespanFailure(message.get("message", ""))
|
||||
281
uvicorn/_async_agnostic/server.py
Normal file
281
uvicorn/_async_agnostic/server.py
Normal file
@ -0,0 +1,281 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Iterator, List, Optional
|
||||
|
||||
import click
|
||||
|
||||
from ..config import Config
|
||||
from .backends.auto import select_async_backend
|
||||
from .backends.base import AsyncListener, Event, TaskStatus
|
||||
from .http11.handler import handle_http11
|
||||
from .lifespan import Lifespan
|
||||
from .state import ServerState
|
||||
from .utils import to_internet_date
|
||||
|
||||
HANDLED_SIGNALS = (
|
||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
||||
)
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, config: Config) -> None:
|
||||
self._config = config
|
||||
self._state = ServerState()
|
||||
self._backend = select_async_backend(config.async_library)
|
||||
self._logger = logging.getLogger("uvicorn.error")
|
||||
self._last_notified = time.time()
|
||||
self._force_exit = False
|
||||
|
||||
@property
|
||||
def _shutdown_event(self) -> Event:
|
||||
if not hasattr(self, "_shutdown_event_obj"):
|
||||
# Can't be created on init due to a limitation of multiprocessing.
|
||||
self._shutdown_event_obj = self._backend.create_event()
|
||||
return self._shutdown_event_obj
|
||||
|
||||
@property
|
||||
def _closed_event(self) -> Event:
|
||||
if not hasattr(self, "_closed_event_obj"):
|
||||
# Can't be created on init due to a limitation of multiprocessing.
|
||||
self._closed_event_obj = self._backend.create_event()
|
||||
return self._closed_event_obj
|
||||
|
||||
def run(
|
||||
self,
|
||||
sockets: List[socket.SocketType] = None,
|
||||
started: Callable[[], None] = None,
|
||||
) -> None:
|
||||
if self._config.async_library == "asyncio":
|
||||
self._config.setup_event_loop()
|
||||
self._backend.run(self.serve, sockets, started)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def run_in_thread(self, sockets: List[socket.SocketType] = None) -> Iterator[None]:
|
||||
thread_exc: Optional[Exception] = None
|
||||
|
||||
# Cannot be stored as an instance attribute because Event objects are not
|
||||
# picklable, so this would cause issues with multiprocessing support.
|
||||
started_event = threading.Event()
|
||||
|
||||
def target() -> None:
|
||||
nonlocal thread_exc
|
||||
try:
|
||||
self.run(sockets=sockets, started=started_event.set)
|
||||
except Exception as exc:
|
||||
self._logger.exception(exc)
|
||||
started_event.set()
|
||||
thread_exc = exc
|
||||
|
||||
thread = threading.Thread(target=target)
|
||||
thread.start()
|
||||
try:
|
||||
started_event.wait()
|
||||
if thread_exc is not None:
|
||||
raise thread_exc
|
||||
yield
|
||||
finally:
|
||||
thread.join()
|
||||
|
||||
async def serve(
|
||||
self,
|
||||
sockets: List[socket.SocketType] = None,
|
||||
started: Callable[[], None] = None,
|
||||
) -> None:
|
||||
started = (lambda: None) if started is None else started
|
||||
|
||||
process_id = os.getpid()
|
||||
message = "Started server process [%d]"
|
||||
color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
|
||||
self._logger.info(message, process_id, extra={"color_message": color_message})
|
||||
|
||||
config = self._config
|
||||
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
lifespan = Lifespan(config)
|
||||
shutdown_trigger = (
|
||||
self._shutdown_event.wait
|
||||
if config.shutdown_trigger is None
|
||||
else config.shutdown_trigger
|
||||
)
|
||||
|
||||
async with self._backend.start_soon(lifespan.main, cancel_on_exit=True):
|
||||
await lifespan.startup()
|
||||
|
||||
try:
|
||||
async with self._backend.start(self._serve, sockets) as listeners:
|
||||
async with (
|
||||
self._backend.start_soon(self._main_loop),
|
||||
self._backend.start_soon(self._listen_signals),
|
||||
):
|
||||
# Server has started.
|
||||
started()
|
||||
self._log_started_message(listeners, sockets=sockets)
|
||||
# Let the server run until exit is requested.
|
||||
await shutdown_trigger()
|
||||
await self._shutdown_event.set()
|
||||
finally:
|
||||
if not self._force_exit:
|
||||
await lifespan.shutdown()
|
||||
|
||||
message = "Finished server process [%d]"
|
||||
color_message = "Finished server process [" + click.style("%d", fg="cyan") + "]"
|
||||
self._logger.info(
|
||||
"Finished server process [%d]",
|
||||
process_id,
|
||||
extra={"color_message": color_message},
|
||||
)
|
||||
|
||||
async def _serve(
|
||||
self,
|
||||
sockets: List[socket.SocketType] = None,
|
||||
*,
|
||||
task_status: TaskStatus = TaskStatus.IGNORED,
|
||||
) -> None:
|
||||
await self._backend.serve_tcp(
|
||||
handle_http11,
|
||||
self._state,
|
||||
self._config,
|
||||
sockets=sockets,
|
||||
wait_close=self._shutdown_event.wait,
|
||||
on_close=self._on_close,
|
||||
task_status=task_status,
|
||||
)
|
||||
|
||||
async def _main_loop(self) -> None:
|
||||
counter = 0
|
||||
should_exit = await self._tick(counter)
|
||||
while not should_exit:
|
||||
counter += 1
|
||||
counter = counter % 864000
|
||||
await self._backend.sleep(0.1)
|
||||
should_exit = await self._tick(counter)
|
||||
await self._shutdown_event.set()
|
||||
|
||||
async def _tick(self, counter: int) -> bool:
|
||||
state = self._state
|
||||
config = self._config
|
||||
|
||||
# Update the default headers, once per second.
|
||||
if counter % 10 == 0:
|
||||
current_time = time.time()
|
||||
current_date = to_internet_date(current_time).encode()
|
||||
state.default_headers = [(b"date", current_date)] + config.encoded_headers
|
||||
|
||||
# Callback to `callback_notify` once every `timeout_notify` seconds.
|
||||
if config.callback_notify is not None:
|
||||
if current_time - self._last_notified > config.timeout_notify:
|
||||
self._last_notified = current_time
|
||||
await config.callback_notify()
|
||||
|
||||
# Determine if we should exit.
|
||||
if self._shutdown_event.is_set():
|
||||
return True
|
||||
if config.limit_max_requests is not None:
|
||||
return state.total_requests >= config.limit_max_requests
|
||||
return False
|
||||
|
||||
async def _listen_signals(self) -> None:
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
await self._closed_event.wait()
|
||||
return
|
||||
|
||||
async def handle_signal_exit() -> None:
|
||||
if self._shutdown_event.is_set():
|
||||
self._logger.info("Shutting down forcibly")
|
||||
self._force_exit = True
|
||||
else:
|
||||
await self._shutdown_event.set()
|
||||
|
||||
async def listen_signals() -> None:
|
||||
await self._backend.listen_signals(
|
||||
*HANDLED_SIGNALS, handler=handle_signal_exit
|
||||
)
|
||||
|
||||
async with self._backend.start_soon(listen_signals, cancel_on_exit=True):
|
||||
await self._closed_event.wait()
|
||||
|
||||
async def _on_close(self) -> None:
|
||||
assert self._shutdown_event.is_set()
|
||||
|
||||
self._logger.info("Shutting down")
|
||||
state = self._state
|
||||
|
||||
# Request shutdown on all existing connections.
|
||||
for conn in list(state.connections):
|
||||
await conn.trigger_shutdown()
|
||||
|
||||
# Wait for existing connections to finish sending responses.
|
||||
if state.connections and not self._force_exit:
|
||||
self._logger.info(
|
||||
"Waiting for connections to close. (CTRL+C to force quit)"
|
||||
)
|
||||
while state.connections and not self._force_exit:
|
||||
await self._backend.sleep(0.1)
|
||||
|
||||
# Wait for existing tasks to complete.
|
||||
if state.tasks and not self._force_exit:
|
||||
self._logger.info(
|
||||
"Waiting for background tasks to complete. (CTRL+C to force quit)"
|
||||
)
|
||||
while state.tasks and not self._force_exit:
|
||||
await self._backend.sleep(0.1)
|
||||
|
||||
await self._closed_event.set()
|
||||
|
||||
def _log_started_message(
|
||||
self, listeners: List[AsyncListener], sockets: List[socket.SocketType] = None
|
||||
) -> None:
|
||||
if sockets is not None:
|
||||
# We're running multiple workers, and a message has already been
|
||||
# logged by `config.bind_socket()``.
|
||||
return
|
||||
|
||||
config = self._config
|
||||
|
||||
if config.fd is not None:
|
||||
sock = listeners[0].socket
|
||||
self._logger.info(
|
||||
"Uvicorn running on socket %s (Press CTRL+C to quit)",
|
||||
sock.getsockname(),
|
||||
)
|
||||
return
|
||||
|
||||
if config.uds is not None:
|
||||
self._logger.info(
|
||||
"Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds
|
||||
)
|
||||
return
|
||||
|
||||
addr_format = "%s://%s:%d"
|
||||
host = "0.0.0.0" if config.host is None else config.host
|
||||
if ":" in host:
|
||||
# It's an IPv6 address.
|
||||
addr_format = "%s://[%s]:%d"
|
||||
|
||||
port = config.port
|
||||
if port == 0:
|
||||
sock = listeners[0].socket
|
||||
_, port = sock.getpeername()
|
||||
|
||||
protocol_name = "https" if config.ssl else "http"
|
||||
message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
|
||||
color_message = (
|
||||
"Uvicorn running on "
|
||||
+ click.style(addr_format, bold=True)
|
||||
+ " (Press CTRL+C to quit)"
|
||||
)
|
||||
self._logger.info(
|
||||
message,
|
||||
protocol_name,
|
||||
host,
|
||||
port,
|
||||
extra={"color_message": color_message},
|
||||
)
|
||||
9
uvicorn/_async_agnostic/state.py
Normal file
9
uvicorn/_async_agnostic/state.py
Normal file
@ -0,0 +1,9 @@
|
||||
from typing import Any, List, Set, Tuple
|
||||
|
||||
|
||||
class ServerState:
|
||||
def __init__(self) -> None:
|
||||
self.default_headers: List[Tuple[bytes, bytes]] = [] # Set by the server.
|
||||
self.total_requests = 0 # Updated by server handlers.
|
||||
self.connections: Set[Any] = set()
|
||||
self.tasks: Set[Any] = set() # May be used by async backends (not all do).
|
||||
65
uvicorn/_async_agnostic/utils.py
Normal file
65
uvicorn/_async_agnostic/utils.py
Normal file
@ -0,0 +1,65 @@
|
||||
import http
|
||||
import socket
|
||||
from email.utils import formatdate
|
||||
from typing import List, Optional, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
|
||||
def _get_status_phrase(status_code: int) -> bytes:
|
||||
try:
|
||||
return http.HTTPStatus(status_code).phrase.encode()
|
||||
except ValueError:
|
||||
return b""
|
||||
|
||||
|
||||
STATUS_PHRASES = {
|
||||
status_code: _get_status_phrase(status_code) for status_code in range(100, 600)
|
||||
}
|
||||
|
||||
RECV_CHUNK_SIZE = 2 ** 16
|
||||
|
||||
TRACE_LOG_LEVEL = 5
|
||||
|
||||
|
||||
def get_sock_remote_addr(sock: socket.SocketType) -> Optional[Tuple[str, int]]:
|
||||
try:
|
||||
info = sock.getpeername()
|
||||
except OSError:
|
||||
# This case appears to inconsistently occur with uvloop
|
||||
# bound to a unix domain socket.
|
||||
return None
|
||||
else:
|
||||
return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
|
||||
|
||||
|
||||
def get_sock_local_addr(sock: socket.SocketType) -> Optional[Tuple[str, int]]:
|
||||
info = sock.getsockname()
|
||||
if isinstance(info, tuple):
|
||||
return (str(info[0]), int(info[1]))
|
||||
return None
|
||||
|
||||
|
||||
def get_path_with_query_string(scope: dict) -> str:
|
||||
path = quote(scope.get("root_path", "") + scope["path"])
|
||||
qs = scope["query_string"]
|
||||
if qs:
|
||||
path += "?{}".format(qs.decode("ascii"))
|
||||
return path
|
||||
|
||||
|
||||
def to_internet_date(value: float) -> str:
|
||||
return formatdate(value, usegmt=True)
|
||||
|
||||
|
||||
def find_upgrade_header(headers: List[Tuple[bytes, bytes]]) -> Optional[bytes]:
|
||||
connection = next((value for name, value in headers if name == b"connection"), None)
|
||||
|
||||
if connection is None:
|
||||
return None
|
||||
|
||||
tokens = [token.lower().strip() for token in connection.split(b",")]
|
||||
|
||||
if b"upgrade" not in tokens:
|
||||
return None
|
||||
|
||||
return next(value.lower() for name, value in headers if name == b"upgrade")
|
||||
@ -123,6 +123,7 @@ class Config:
|
||||
port=8000,
|
||||
uds=None,
|
||||
fd=None,
|
||||
async_library=None,
|
||||
loop="auto",
|
||||
http="auto",
|
||||
ws="auto",
|
||||
@ -147,6 +148,7 @@ class Config:
|
||||
timeout_keep_alive=5,
|
||||
timeout_notify=30,
|
||||
callback_notify=None,
|
||||
shutdown_trigger=None,
|
||||
ssl_keyfile=None,
|
||||
ssl_certfile=None,
|
||||
ssl_keyfile_password=None,
|
||||
@ -161,6 +163,7 @@ class Config:
|
||||
self.port = port
|
||||
self.uds = uds
|
||||
self.fd = fd
|
||||
self.async_library = async_library
|
||||
self.loop = loop
|
||||
self.http = http
|
||||
self.ws = ws
|
||||
@ -182,6 +185,7 @@ class Config:
|
||||
self.timeout_keep_alive = timeout_keep_alive
|
||||
self.timeout_notify = timeout_notify
|
||||
self.callback_notify = callback_notify
|
||||
self.shutdown_trigger = shutdown_trigger
|
||||
self.ssl_keyfile = ssl_keyfile
|
||||
self.ssl_certfile = ssl_certfile
|
||||
self.ssl_keyfile_password = ssl_keyfile_password
|
||||
|
||||
@ -21,10 +21,13 @@ from uvicorn.config import (
|
||||
)
|
||||
from uvicorn.supervisors import ChangeReload, Multiprocess
|
||||
|
||||
from ._async_agnostic.server import Server as AsyncAgnosticServer
|
||||
|
||||
LEVEL_CHOICES = click.Choice(LOG_LEVELS.keys())
|
||||
HTTP_CHOICES = click.Choice(HTTP_PROTOCOLS.keys())
|
||||
WS_CHOICES = click.Choice(WS_PROTOCOLS.keys())
|
||||
LIFESPAN_CHOICES = click.Choice(LIFESPAN.keys())
|
||||
LIBRARY_CHOICES = click.Choice(["asyncio", "trio", "curio"])
|
||||
LOOP_CHOICES = click.Choice([key for key in LOOP_SETUPS.keys() if key != "none"])
|
||||
INTERFACE_CHOICES = click.Choice(INTERFACES)
|
||||
|
||||
@ -96,6 +99,13 @@ def print_version(ctx, param, value):
|
||||
help="Number of worker processes. Defaults to the $WEB_CONCURRENCY environment"
|
||||
" variable if available. Not valid with --reload.",
|
||||
)
|
||||
@click.option(
|
||||
"--async-library",
|
||||
type=LIBRARY_CHOICES,
|
||||
default=None,
|
||||
help="Concurrency library implementation.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--loop",
|
||||
type=LOOP_CHOICES,
|
||||
@ -283,6 +293,7 @@ def main(
|
||||
port: int,
|
||||
uds: str,
|
||||
fd: int,
|
||||
async_library: typing.Optional[str],
|
||||
loop: str,
|
||||
http: str,
|
||||
ws: str,
|
||||
@ -323,6 +334,7 @@ def main(
|
||||
"port": port,
|
||||
"uds": uds,
|
||||
"fd": fd,
|
||||
"async_library": async_library,
|
||||
"loop": loop,
|
||||
"http": http,
|
||||
"ws": ws,
|
||||
@ -359,7 +371,11 @@ def main(
|
||||
|
||||
def run(app, **kwargs):
|
||||
config = Config(app, **kwargs)
|
||||
server = Server(config=config)
|
||||
|
||||
if config.async_library is not None:
|
||||
server = AsyncAgnosticServer(config)
|
||||
else:
|
||||
server = Server(config=config)
|
||||
|
||||
if (config.reload or config.workers > 1) and not isinstance(app, str):
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user