Compare commits

...

9 Commits

Author SHA1 Message Date
florimondmanca
46555b5bed
Tweak run_in_thread impl so no Event is stored on Server instances 2020-11-23 22:25:48 +01:00
florimondmanca
57534e605d
Windows compat: ensure an event loop is provisioned even if anyio shut it down 2020-11-23 22:25:34 +01:00
florimondmanca
c4f5c76462
Python 3.6 compat 2020-11-23 21:10:23 +01:00
florimondmanca
bf17e875e7
Expose async Server.serve() method 2020-11-23 21:07:39 +01:00
florimondmanca
ed4ffac972
Curio fixes 2020-11-22 19:50:19 +01:00
florimondmanca
2a01073135
call_later -> wait_then_call 2020-11-22 14:54:06 +01:00
florimondmanca
77f48d1351
Refactor: parsers w/ state machine 2020-11-22 13:17:38 +01:00
florimondmanca
a14dd4b38b
Add httptools 2020-11-22 01:32:10 +01:00
florimondmanca
34fd9cd460
HTTP/1.1 + h11 + keep-alive + asyncio + trio + curio + tests 2020-11-21 19:53:43 +01:00
34 changed files with 3461 additions and 1 deletions

View File

@ -8,8 +8,10 @@ twine
wheel
# Testing
anyio==2.0.2
autoflake
black
curio==1.4
flake8
isort
pytest
@ -17,6 +19,7 @@ pytest-mock
requests
mypy
trustme
trio==0.17.*
cryptography
coverage

View File

@ -47,6 +47,8 @@ minimal_requirements = [
"click==7.*",
"h11>=0.8",
"typing-extensions;" + env_marker_below_38,
"async_generator; python_version<'3.7'",
"async_exit_stack; python_version<'3.7'",
]
extra_requirements = [

View File

View File

@ -0,0 +1,32 @@
import asyncio
from typing import Any
import pytest
@pytest.fixture(
params=[
pytest.param(("asyncio", {"use_uvloop": True}), id="asyncio+uvloop"),
pytest.param(("asyncio", {"use_uvloop": False}), id="asyncio"),
pytest.param(
("trio", {"restrict_keyboard_interrupt_to_checkpoints": True}), id="trio"
),
pytest.param(("curio", {}), id="curio"),
],
)
def anyio_backend(request: Any) -> str:
return request.param
@pytest.fixture(autouse=True)
def ensure_event_loop():
try:
asyncio.get_event_loop()
except RuntimeError:
# We use anyio; its pytest plugin shuts down and unsets the asyncio event loop
# after each test cases. For some reason on Windows when the loop is unset
# asyncio won't create one again by itself. So we give it a hand.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield

View File

@ -0,0 +1,85 @@
from typing import Callable
import pytest
import requests
from uvicorn import Config
from uvicorn._async_agnostic import Server
from .utils import HTTP11_IMPLEMENTATIONS
async def app(scope: dict, receive: Callable, send: Callable) -> None:
assert scope["type"] == "http"
await send({"type": "http.response.start", "status": 200, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})
@pytest.mark.parametrize("async_library", ["asyncio", "trio", "curio"])
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
def test_default_headers(async_library: str, http: str) -> None:
config = Config(
app=app,
async_library=async_library,
http=http,
limit_max_requests=1,
)
with Server(config=config).run_in_thread():
response = requests.get("http://localhost:8000")
assert response.headers["server"] == "uvicorn" and response.headers["date"]
@pytest.mark.parametrize("async_library", ["asyncio", "trio", "curio"])
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
def test_override_server_header(async_library: str, http: str) -> None:
config = Config(
app=app,
async_library=async_library,
http=http,
limit_max_requests=1,
headers=[("Server", "overridden")],
)
with Server(config=config).run_in_thread():
response = requests.get("http://localhost:8000")
assert response.headers["server"] == "overridden" and response.headers["date"]
@pytest.mark.parametrize("async_library", ["asyncio", "trio", "curio"])
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
def test_override_server_header_multiple_times(async_library: str, http: str) -> None:
config = Config(
app=app,
async_library=async_library,
http=http,
limit_max_requests=1,
headers=[("Server", "overridden"), ("Server", "another-value")],
)
with Server(config=config).run_in_thread():
response = requests.get("http://localhost:8000")
assert (
response.headers["server"] == "overridden, another-value"
and response.headers["date"]
)
@pytest.mark.parametrize("async_library", ["asyncio", "trio", "curio"])
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
def test_add_additional_header(async_library: str, http: str) -> None:
config = Config(
app=app,
async_library=async_library,
http=http,
limit_max_requests=1,
headers=[("X-Additional", "new-value")],
)
with Server(config=config).run_in_thread():
response = requests.get("http://localhost:8000")
assert (
response.headers["x-additional"] == "new-value"
and response.headers["server"] == "uvicorn"
and response.headers["date"]
)

View File

@ -0,0 +1,235 @@
import logging
from typing import Any, Callable, Tuple
import pytest
from uvicorn._async_agnostic.backends.auto import AutoBackend
from uvicorn._async_agnostic.backends.base import AsyncSocket
from uvicorn._async_agnostic.http11.handler import handle_http11
from uvicorn._async_agnostic.state import ServerState
from uvicorn.config import Config
from ..response import Response
from .utils import HTTP11_IMPLEMENTATIONS
SIMPLE_GET_REQUEST = b"\r\n".join(
[
b"GET / HTTP/1.1",
b"Host: example.org",
b"",
b"",
]
)
SIMPLE_HEAD_REQUEST = b"\r\n".join(
[
b"HEAD / HTTP/1.1",
b"Host: example.org",
b"",
b"",
]
)
SIMPLE_POST_REQUEST = b"\r\n".join(
[
b"POST / HTTP/1.1",
b"Host: example.org",
b"Content-Type: application/json",
b"Content-Length: 18",
b"",
b'{"hello": "world"}',
]
)
class MockSocket(AsyncSocket):
"""
An in-memory socket for testing purposes.
"""
def __init__(self, request: bytes, prevent_keepalive_loop: bool = True) -> None:
self._backend = AutoBackend()
self._request = request
self._prevent_keepalive_loop = prevent_keepalive_loop
self._response = b""
self._readable = False
self._is_closed = False
self._response_received = self._backend.create_event()
def get_remote_addr(self) -> Tuple[str, int]:
return ("127.0.0.1", 8000)
def get_local_addr(self) -> Tuple[str, int]:
return ("127.0.0.1", 42424)
is_ssl = False
@property
def response(self) -> bytes:
return self._response
async def read(self, n: int) -> bytes:
if self._readable:
return b""
if not self._request:
while not self._readable:
await self._backend.sleep(0.01)
return b""
data, self._request = self._request[:n], self._request[n:]
if not self._request and self._prevent_keepalive_loop:
await self._response_received.set()
self._readable = True
return data
async def write(self, data: bytes) -> None:
if data == b"":
await self._response_received.set()
if self._prevent_keepalive_loop:
# Simulate a client disconnect right after having received
# the response so that the HTTP handler doesn't run a keep-alive
# cycle for nothing.
self.simulate_client_disconnect()
return
self._response += data
async def wait_response_received(self) -> None:
await self._response_received.wait()
def simulate_client_disconnect(self) -> None:
self._readable = True
async def send_eof(self) -> None:
# Simulate instantaneous acknowledgement by client.
self._readable = True
async def aclose(self) -> None:
self._is_closed = True
@property
def is_closed(self) -> bool:
return self._is_closed
@pytest.mark.anyio
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
async def test_get_request(http: str) -> None:
app = Response("Hello, world", media_type="text/plain")
sock = MockSocket(SIMPLE_GET_REQUEST)
await handle_http11(sock, ServerState(), Config(app=app, http=http))
assert b"HTTP/1.1 200 OK" in sock.response
assert b"Hello, world" in sock.response
@pytest.mark.anyio
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
async def test_head_request(http: str) -> None:
app = Response("Hello, world", media_type="text/plain")
sock = MockSocket(SIMPLE_HEAD_REQUEST)
await handle_http11(sock, ServerState(), Config(app=app, http=http))
assert b"HTTP/1.1 200 OK" in sock.response
assert b"Hello, world" not in sock.response
@pytest.mark.anyio
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
async def test_post_request(http: str) -> None:
async def app(scope: dict, receive: Callable, send: Callable) -> None:
body = b""
while True:
message = await receive()
body += message.get("body", b"")
if not message.get("more_body", False):
break
response = Response(b"Body: " + body, media_type="text/plain")
await response(scope, receive, send)
sock = MockSocket(SIMPLE_POST_REQUEST)
await handle_http11(sock, ServerState(), Config(app=app, http=http))
assert b"HTTP/1.1 200 OK" in sock.response
assert b'Body: {"hello": "world"}' in sock.response
@pytest.mark.anyio
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
@pytest.mark.parametrize("path", ["/", "/?foo", "/?foo=bar", "/?foo=bar&baz=1"])
async def test_request_logging(http: str, path: str, caplog: Any) -> None:
app = Response("Hello, world", media_type="text/plain")
get_request_with_query_string = b"\r\n".join(
[
"GET {} HTTP/1.1".format(path).encode("ascii"),
b"Host: example.org",
b"",
b"",
]
)
sock = MockSocket(get_request_with_query_string)
state = ServerState()
config = Config(app=app, http=http) # Keep this here -- configures initial logging.
with caplog.at_level(logging.INFO, logger="uvicorn.access"):
logging.getLogger("uvicorn.access").propagate = True
await handle_http11(sock, state, config)
assert '"GET {} HTTP/1.1" 200'.format(path) in caplog.records[0].message
@pytest.mark.anyio
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
async def test_keepalive(http: str) -> None:
app = Response(b"", status_code=204)
sock = MockSocket(SIMPLE_GET_REQUEST, prevent_keepalive_loop=False)
config = Config(app=app, http=http)
async with AutoBackend().start_soon(handle_http11, sock, ServerState(), config):
await sock.wait_response_received()
assert b"HTTP/1.1 204 No Content" in sock.response
assert not sock.is_closed
sock.simulate_client_disconnect()
assert sock.is_closed
@pytest.mark.anyio
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
async def test_keepalive_timeout(http: str) -> None:
app = Response(b"", status_code=204)
sock = MockSocket(SIMPLE_GET_REQUEST, prevent_keepalive_loop=False)
backend = AutoBackend()
config = Config(app=app, http=http, timeout_keep_alive=0.1)
async with backend.start_soon(handle_http11, sock, ServerState(), config):
await sock.wait_response_received()
assert b"HTTP/1.1 204 No Content" in sock.response
assert not sock.is_closed
await backend.sleep(0.01)
assert not sock.is_closed
await backend.sleep(0.2)
assert sock.is_closed
@pytest.mark.anyio
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
async def test_close(http: str) -> None:
app = Response(b"", status_code=204, headers={"connection": "close"})
sock = MockSocket(SIMPLE_GET_REQUEST, prevent_keepalive_loop=False)
await handle_http11(sock, state=ServerState(), config=Config(app=app, http=http))
assert b"HTTP/1.1 204 No Content" in sock.response
assert sock.is_closed

View File

@ -0,0 +1,167 @@
from typing import Callable
import pytest
from uvicorn._async_agnostic.backends.auto import AutoBackend
from uvicorn._async_agnostic.exceptions import LifespanFailure
from uvicorn._async_agnostic.lifespan import Lifespan
from uvicorn.config import Config
@pytest.mark.anyio
async def test_lifespan_on() -> None:
startup_complete = False
shutdown_complete = False
async def app(scope: dict, receive: Callable, send: Callable) -> None:
nonlocal startup_complete, shutdown_complete
message = await receive()
assert message["type"] == "lifespan.startup"
startup_complete = True
await send({"type": "lifespan.startup.complete"})
message = await receive()
assert message["type"] == "lifespan.shutdown"
shutdown_complete = True
await send({"type": "lifespan.shutdown.complete"})
config = Config(app=app, lifespan="on")
lifespan = Lifespan(config)
async with AutoBackend().start_soon(lifespan.main):
assert not startup_complete
assert not shutdown_complete
await lifespan.startup()
assert startup_complete
assert not shutdown_complete
await lifespan.shutdown()
assert startup_complete
assert shutdown_complete
@pytest.mark.anyio
async def test_lifespan_off() -> None:
async def app(scope: dict, receive: Callable, send: Callable) -> None:
pass # pragma: no cover
config = Config(app=app, lifespan="off")
lifespan = Lifespan(config)
async with AutoBackend().start_soon(lifespan.main):
await lifespan.startup()
await lifespan.shutdown()
@pytest.mark.anyio
async def test_lifespan_auto() -> None:
startup_complete = False
shutdown_complete = False
async def app(scope: dict, receive: Callable, send: Callable) -> None:
nonlocal startup_complete, shutdown_complete
message = await receive()
assert message["type"] == "lifespan.startup"
startup_complete = True
await send({"type": "lifespan.startup.complete"})
message = await receive()
assert message["type"] == "lifespan.shutdown"
shutdown_complete = True
await send({"type": "lifespan.shutdown.complete"})
config = Config(app=app, lifespan="auto")
lifespan = Lifespan(config)
async with AutoBackend().start_soon(lifespan.main):
assert not startup_complete
assert not shutdown_complete
await lifespan.startup()
assert startup_complete
assert not shutdown_complete
await lifespan.shutdown()
assert startup_complete
assert shutdown_complete
@pytest.mark.anyio
async def test_lifespan_auto_with_error() -> None:
async def app(scope: dict, receive: Callable, send: Callable) -> None:
assert scope["type"] == "http"
config = Config(app=app, lifespan="auto")
lifespan = Lifespan(config)
async with AutoBackend().start_soon(lifespan.main):
await lifespan.startup()
await lifespan.shutdown()
@pytest.mark.anyio
async def test_lifespan_on_with_error() -> None:
async def app(scope: dict, receive: Callable, send: Callable) -> None:
assert scope["type"] == "lifespan"
raise RuntimeError("Oops")
config = Config(app=app, lifespan="on")
lifespan = Lifespan(config)
with pytest.raises(RuntimeError, match="Oops"):
async with AutoBackend().start_soon(lifespan.main):
await lifespan.startup()
@pytest.mark.anyio
@pytest.mark.parametrize("mode", ("auto", "on"))
@pytest.mark.parametrize("raise_exception", (True, False))
async def test_lifespan_with_failed_startup(mode: str, raise_exception: bool) -> None:
async def app(scope: dict, receive: Callable, send: Callable) -> None:
message = await receive()
assert message["type"] == "lifespan.startup"
await send({"type": "lifespan.startup.failed"})
if raise_exception:
# App should be able to re-raise an exception if startup failed.
raise RuntimeError()
config = Config(app=app, lifespan=mode)
lifespan = Lifespan(config)
with pytest.raises(LifespanFailure):
async with AutoBackend().start_soon(lifespan.main):
await lifespan.startup()
@pytest.mark.anyio
@pytest.mark.parametrize("mode", ("auto", "on"))
async def test_lifespan_scope_asgi3app(mode: str) -> None:
async def asgi3app(scope: dict, receive: Callable, send: Callable) -> None:
assert scope == {
"type": "lifespan",
"asgi": {"version": "3.0", "spec_version": "2.0"},
}
config = Config(app=asgi3app, lifespan=mode)
lifespan = Lifespan(config)
async with AutoBackend().start_soon(lifespan.main):
await lifespan.startup()
await lifespan.shutdown()
@pytest.mark.anyio
@pytest.mark.parametrize("mode", ("auto", "on"))
async def test_lifespan_scope_asgi2app(mode: str) -> None:
def asgi2app(scope: dict) -> Callable:
assert scope == {
"type": "lifespan",
"asgi": {"version": "2.0", "spec_version": "2.0"},
}
async def asgi(receive: Callable, send: Callable) -> None:
pass
return asgi
config = Config(app=asgi2app, lifespan=mode)
lifespan = Lifespan(config)
async with AutoBackend().start_soon(lifespan.main):
await lifespan.startup()
await lifespan.shutdown()

View File

@ -0,0 +1,88 @@
from typing import Callable
import pytest
import requests
from uvicorn._async_agnostic import Server
from uvicorn.config import Config
from .utils import HTTP11_IMPLEMENTATIONS
async def app(scope: dict, receive: Callable, send: Callable) -> None:
assert scope["type"] == "http"
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})
@pytest.mark.parametrize(
"host, url",
[
pytest.param(None, "http://127.0.0.1:8000", id="default"),
pytest.param("localhost", "http://127.0.0.1:8000", id="hostname"),
pytest.param("::1", "http://[::1]:8000", id="ipv6"),
],
)
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
def test_run(host: str, url: str, async_library: str, http: str) -> None:
config = Config(
app=app,
host=host,
async_library=async_library,
http=http,
limit_max_requests=1,
)
with Server(config).run_in_thread():
response = requests.get(url)
assert response.status_code == 204
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
def test_run_multiprocess(async_library: str) -> None:
config = Config(
app=app,
workers=2,
async_library=async_library,
limit_max_requests=1,
)
with Server(config).run_in_thread():
response = requests.get("http://127.0.0.1:8000")
assert response.status_code == 204
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
def test_run_reload(async_library: str) -> None:
config = Config(
app=app,
reload=True,
async_library=async_library,
limit_max_requests=1,
)
with Server(config).run_in_thread():
response = requests.get("http://127.0.0.1:8000")
assert response.status_code == 204
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
@pytest.mark.parametrize("http", HTTP11_IMPLEMENTATIONS)
def test_run_with_shutdown(async_library: str, http: str) -> None:
async def app(scope: dict, receive: Callable, send: Callable) -> None:
pass
async def shutdown_immediately() -> None:
pass
config = Config(
app=app,
workers=2,
shutdown_trigger=shutdown_immediately,
async_library=async_library,
http=http,
)
with Server(config).run_in_thread():
pass

View File

@ -0,0 +1,94 @@
import contextlib
import sys
import warnings
from functools import partialmethod
from typing import Callable
import pytest
import requests
from urllib3.exceptions import InsecureRequestWarning
from uvicorn._async_agnostic import Server
from uvicorn.config import Config
@contextlib.contextmanager
def no_ssl_verification(session=requests.Session): # type: ignore
old_request = session.request
session.request = partialmethod(old_request, verify=False)
with warnings.catch_warnings():
warnings.simplefilter("ignore", InsecureRequestWarning)
yield
session.request = old_request
async def app(scope: dict, receive: Callable, send: Callable) -> None:
assert scope["type"] == "http"
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Skipping SSL test on Windows"
)
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
def test_run(
async_library: str,
tls_ca_certificate_pem_path: str,
tls_ca_certificate_private_key_path: str,
) -> None:
config = Config(
app=app,
async_library=async_library,
limit_max_requests=1,
ssl_keyfile=tls_ca_certificate_private_key_path,
ssl_certfile=tls_ca_certificate_pem_path,
)
with Server(config).run_in_thread():
with no_ssl_verification():
response = requests.get("https://127.0.0.1:8000")
assert response.status_code == 204
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Skipping SSL test on Windows"
)
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
def test_run_chain(async_library: str, tls_certificate_pem_path: str) -> None:
config = Config(
app=app,
async_library=async_library,
limit_max_requests=1,
ssl_certfile=tls_certificate_pem_path,
)
with Server(config).run_in_thread():
with no_ssl_verification():
response = requests.get("https://127.0.0.1:8000")
assert response.status_code == 204
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Skipping SSL test on Windows"
)
@pytest.mark.parametrize("async_library", ["asyncio", "trio"])
def test_run_password(
async_library: str,
tls_ca_certificate_pem_path: str,
tls_ca_certificate_private_key_encrypted_path: str,
) -> None:
config = Config(
app=app,
async_library=async_library,
limit_max_requests=1,
ssl_keyfile=tls_ca_certificate_private_key_encrypted_path,
ssl_certfile=tls_ca_certificate_pem_path,
ssl_keyfile_password="uvicorn password for the win",
)
with Server(config).run_in_thread():
with no_ssl_verification():
response = requests.get("https://127.0.0.1:8000")
assert response.status_code == 204

View File

@ -0,0 +1,8 @@
HTTP11_IMPLEMENTATIONS = ["h11"]
try:
import httptools # noqa
except ImportError: # pragma: no cover
pass
else:
HTTP11_IMPLEMENTATIONS.append("httptools")

View File

@ -0,0 +1,5 @@
from .server import Server
__all__ = [
"Server",
]

View File

@ -0,0 +1,124 @@
import logging
from typing import AsyncIterator, Awaitable, Callable
from .backends.auto import AutoBackend
from .http11.connection import HTTP11Connection
from .utils import STATUS_PHRASES, get_path_with_query_string
class ASGIRequestResponseCycle:
def __init__(
self,
conn: HTTP11Connection,
scope: dict,
request_body: AsyncIterator[bytes],
send_response_body: Callable[[bytes], Awaitable[None]],
access_log: bool,
) -> None:
self._conn = conn
self._scope = scope
self._request_body = request_body
self._do_send_response_body = send_response_body
self._access_log = access_log
self._backend = AutoBackend()
self._logger = logging.getLogger("uvicorn.error")
self._access_logger = logging.getLogger("uvicorn.access")
self._response_started = False
self._response_complete = False
# ASGI exception wrapper
async def run_asgi(self, app: Callable) -> None:
try:
result = await app(self._scope, self._receive, self._send)
except Exception:
raise
if result is not None:
msg = "ASGI callable should return None, but returned '%s'."
raise RuntimeError(msg % result)
if not self._response_started:
msg = "ASGI callable returned without starting response."
raise RuntimeError(msg)
if not self._response_complete:
msg = "ASGI callable returned without completing response."
raise RuntimeError(msg)
async def _send_response(self, message: dict) -> None:
if message["type"] != "http.response.start":
msg = "Expected ASGI message 'http.response.start', but got '%s'."
raise RuntimeError(msg % message["type"])
self._response_started = True
self._waiting_for_100_continue = False
status_code = message["status"]
headers = self._conn.basic_headers() + message.get("headers", [])
reason = STATUS_PHRASES[status_code]
if self._access_log:
self._access_logger.info(
'%s - "%s %s HTTP/%s" %d',
self._conn.client,
self._scope["method"],
get_path_with_query_string(self._scope),
self._scope["http_version"],
status_code,
extra={"status_code": status_code, "scope": self._scope},
)
await self._conn.send_response(
status_code=status_code, headers=headers, reason=reason
)
async def _send_response_body(self, message: dict) -> None:
if message["type"] != "http.response.body":
msg = "Expected ASGI message 'http.response.body', but got '%s'."
raise RuntimeError(msg % message["type"])
body = message.get("body", b"")
more_body = message.get("more_body", False)
if self._scope["method"] == "HEAD":
body = b""
await self._do_send_response_body(body)
if not more_body:
if body != b"":
await self._do_send_response_body(b"")
self._response_complete = True
# ASGI interface
async def _send(self, message: dict) -> None:
if not self._response_started:
await self._send_response(message)
elif not self._response_complete:
await self._send_response_body(message)
else:
# Response already sent
msg = "Unexpected ASGI message '%s' sent, after response already completed."
raise RuntimeError(msg % message["type"])
async def _receive(self) -> dict:
if self._response_complete:
return {"type": "http.disconnect"}
try:
chunk = await self._request_body.__anext__()
except StopAsyncIteration:
chunk = b""
more_body = False
else:
more_body = True
return {
"type": "http.request",
"body": chunk,
"more_body": more_body,
}

View File

@ -0,0 +1,355 @@
import asyncio
import os
import platform
import signal
import socket
from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Tuple
from ...config import Config
from ..compat import asynccontextmanager
from ..state import ServerState
from ..utils import get_sock_local_addr, get_sock_remote_addr
from .base import (
AsyncBackend,
AsyncListener,
AsyncSocket,
Event,
Queue,
TaskHandle,
TaskStatus,
)
HIGH_WATER_LIMIT = 2 ** 16
class AsyncioEvent(Event):
def __init__(self) -> None:
self._event = asyncio.Event()
async def set(self) -> None:
self._event.set()
def is_set(self) -> bool:
return self._event.is_set()
async def wait(self) -> None:
await self._event.wait()
def clear(self) -> None:
self._event.clear()
class AsyncioQueue(Queue):
def __init__(self, size: int) -> None:
self._queue: "asyncio.Queue[Any]" = asyncio.Queue(size)
self._closed = False
async def get(self) -> Any:
if self._closed:
raise RuntimeError("Queue closed")
return await self._queue.get()
async def put(self, item: Any) -> None:
await self._queue.put(item)
async def aclose(self) -> None:
self._closed = True
class AsyncioSocket(AsyncSocket):
def __init__(
self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter
) -> None:
self._stream_reader = stream_reader
self._stream_writer = stream_writer
def get_local_addr(self) -> Optional[Tuple[str, int]]:
sock = self._stream_writer.get_extra_info("socket")
if sock is not None:
return get_sock_local_addr(sock)
info = self._stream_writer.get_extra_info("peername")
try:
host, port = info
except ValueError:
return None
else:
return str(host), int(port)
def get_remote_addr(self) -> Optional[Tuple[str, int]]:
sock = self._stream_writer.get_extra_info("socket")
if sock is not None:
return get_sock_remote_addr(sock)
info = self._stream_writer.get_extra_info("peername")
try:
host, port = info
except ValueError:
return None
else:
return str(host), int(port)
@property
def is_ssl(self) -> bool:
transport = self._stream_writer.transport
return bool(transport.get_extra_info("sslcontext"))
async def read(self, n: int) -> bytes:
return await self._stream_reader.read(n)
async def write(self, data: bytes) -> None:
self._stream_writer.write(data)
await self._stream_writer.drain()
async def send_eof(self) -> None:
try:
self._stream_writer.write_eof()
except (NotImplementedError, OSError, RuntimeError):
pass # Likely SSL connection
async def aclose(self) -> None:
try:
self._stream_writer.close()
await self._stream_writer.wait_closed()
except (BrokenPipeError, ConnectionResetError):
pass # Already closed
@property
def is_closed(self) -> bool:
return self._stream_writer.is_closing()
class AsyncioListener(AsyncListener):
def __init__(self, sock: socket.SocketType) -> None:
self._sock = sock
@property
def socket(self) -> socket.SocketType:
return self._sock
class AsyncioTaskHandle(TaskHandle):
def __init__(self, cancel_event: asyncio.Event) -> None:
self._cancel_event = cancel_event
async def cancel(self) -> None:
self._cancel_event.set()
class AsyncioBackend(AsyncBackend):
def create_event(self) -> Event:
return AsyncioEvent()
def create_queue(self, size: int) -> Queue:
return AsyncioQueue(size)
async def sleep(self, seconds: float) -> None:
await asyncio.sleep(seconds)
def _get_event_loop(self) -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
# We're probably in a new thread.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
def run(self, async_fn: Callable, *args: Any) -> None:
loop = self._get_event_loop()
loop.run_until_complete(async_fn(*args))
async def move_on_after(
self, seconds: float, async_fn: Callable, *args: Any
) -> None:
try:
await asyncio.wait_for(async_fn(*args), seconds)
except asyncio.TimeoutError:
return
@asynccontextmanager
async def start_soon(
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
) -> AsyncIterator[None]:
loop = self._get_event_loop()
task = loop.create_task(async_fn(*args))
try:
yield
finally:
if cancel_on_exit:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
@asynccontextmanager
async def start(
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
) -> AsyncIterator[Any]:
loop = self._get_event_loop()
task_status = self.create_task_status()
task = loop.create_task(async_fn(*args, task_status=task_status))
try:
value = await task_status.get_value()
yield value
finally:
if cancel_on_exit:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def wait_then_call(
self,
seconds: float,
async_fn: Callable,
*args: Any,
task_status: TaskStatus = TaskStatus.IGNORED,
) -> None:
cancel_event = asyncio.Event()
await task_status.started(AsyncioTaskHandle(cancel_event))
# Wait for the user to cancel the callback, or for the timeout to expire.
# Be sure to clean up after asyncio.
tasks: set = {
asyncio.create_task(asyncio.sleep(seconds)),
asyncio.create_task(cancel_event.wait()),
}
try:
_, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
except asyncio.CancelledError:
for task in tasks:
task.cancel()
else:
for task in pending:
task.cancel()
if cancel_event.is_set():
return
await async_fn(*args)
async def serve_tcp(
self,
handler: Callable[[AsyncSocket, ServerState, Config], Awaitable[None]],
state: ServerState,
config: Config,
*,
sockets: List[socket.SocketType] = None,
wait_close: Callable,
on_close: Callable = None,
task_status: TaskStatus = TaskStatus.IGNORED,
) -> None:
async def asyncio_handler(
stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter
) -> None:
sock = AsyncioSocket(stream_reader, stream_writer)
await handler(sock, state, config)
servers = []
if sockets is not None:
# Explicitly passed a list of open sockets.
def _win32_share_socket(sock: socket.SocketType) -> socket.SocketType:
# Windows requires the socket be explicitly shared across
# multiple workers (processes).
from socket import fromshare # type: ignore
sock_data = sock.share(os.getpid()) # type: ignore
return fromshare(sock_data)
for sock in sockets:
if config.workers > 1 and platform.system() == "Windows":
sock = _win32_share_socket(sock)
server = await asyncio.start_server(
asyncio_handler,
sock=sock,
ssl=config.ssl,
backlog=config.backlog,
)
servers.append(server)
listener_sockets = sockets
elif config.fd is not None:
# Use an existing socket, from a file descriptor.
sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
server = await asyncio.start_server(
asyncio_handler, sock=sock, ssl=config.ssl, backlog=config.backlog
)
assert server.sockets is not None
listener_sockets = server.sockets
servers.append(server)
elif config.uds is not None:
# Create a socket using UNIX domain socket.
uds_perms = 0o666
if os.path.exists(config.uds):
uds_perms = os.stat(config.uds).st_mode
server = await asyncio.start_unix_server(
asyncio_handler,
path=config.uds,
ssl=config.ssl,
backlog=config.backlog,
)
os.chmod(config.uds, uds_perms)
assert server.sockets is not None
listener_sockets = server.sockets
servers.append(server)
else:
# Standard case. Create a socket from a host/port pair.
server = await asyncio.start_server(
asyncio_handler,
host=config.host,
port=config.port,
ssl=config.ssl,
backlog=config.backlog,
)
assert server.sockets is not None
listener_sockets = server.sockets
servers.append(server)
listeners = [AsyncioListener(sock) for sock in listener_sockets]
await task_status.started(listeners)
await wait_close()
# Stop accepting new connections.
for server in servers:
server.close()
for sock in sockets or []:
sock.close()
# Run any custom shutdown behavior.
if on_close is not None:
await on_close()
for server in servers:
await server.wait_closed()
async def listen_signals(
self, *signals: signal.Signals, handler: Callable[[], Awaitable[None]]
) -> None:
signal_event = asyncio.Event()
def wrapped_handler(*args: Any) -> None:
signal_event.set()
loop = self._get_event_loop()
try:
for sig in signals:
loop.add_signal_handler(sig, wrapped_handler, sig, None)
except NotImplementedError:
# Windows
for sig in signals:
signal.signal(sig, wrapped_handler)
while True:
await signal_event.wait()
signal_event.clear()
await handler()

View File

@ -0,0 +1,101 @@
import signal
import socket
from typing import Any, AsyncContextManager, Awaitable, Callable, List
import sniffio
from ...config import Config
from ..state import ServerState
from .base import AsyncBackend, AsyncSocket, Event, Queue, TaskStatus
def select_async_backend(library: str) -> AsyncBackend:
if library == "asyncio":
from .asyncio import AsyncioBackend
return AsyncioBackend()
if library == "trio":
from .trio import TrioBackend
return TrioBackend()
if library == "curio":
from .curio import CurioBackend
return CurioBackend()
raise NotImplementedError(library)
class AutoBackend(AsyncBackend):
@property
def _backend(self) -> AsyncBackend:
if not hasattr(self, "_backend_impl"):
library = sniffio.current_async_library()
self._backend_impl = select_async_backend(library)
return self._backend_impl
def create_event(self) -> Event:
return self._backend.create_event()
def create_queue(self, size: int) -> Queue:
return self._backend.create_queue(size)
async def sleep(self, seconds: float) -> None:
await self._backend.sleep(seconds)
def run(self, async_fn: Callable, *args: Any) -> None:
self._backend.run(async_fn)
async def move_on_after(
self, seconds: float, async_fn: Callable, *args: Any
) -> None:
await self._backend.move_on_after(seconds, async_fn, *args)
def start_soon(
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
) -> AsyncContextManager[None]:
return self._backend.start_soon(async_fn, *args, cancel_on_exit=cancel_on_exit)
def start(
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
) -> AsyncContextManager[Any]:
return self._backend.start(async_fn, *args, cancel_on_exit=cancel_on_exit)
async def wait_then_call(
self,
seconds: float,
async_fn: Callable,
*args: Any,
task_status: TaskStatus = TaskStatus.IGNORED,
) -> None:
await self._backend.wait_then_call(
seconds, async_fn, *args, task_status=task_status
)
async def serve_tcp(
self,
handler: Callable[[AsyncSocket, ServerState, Config], Awaitable[None]],
state: ServerState,
config: Config,
*,
sockets: List[socket.SocketType] = None,
wait_close: Callable,
on_close: Callable = None,
task_status: TaskStatus = TaskStatus.IGNORED,
) -> None:
await self._backend.serve_tcp(
handler,
state,
config,
sockets=sockets,
wait_close=wait_close,
on_close=on_close,
task_status=task_status,
)
async def listen_signals(
self, *signals: signal.Signals, handler: Callable[[], Awaitable[None]]
) -> None:
await self._backend.listen_signals(*signals, handler=handler)

View File

@ -0,0 +1,160 @@
import signal
import socket
from typing import Any, AsyncContextManager, Awaitable, Callable, List, Optional, Tuple
from ...config import Config
from ..state import ServerState
class Event:
async def set(self) -> None:
raise NotImplementedError # pragma: no cover
def is_set(self) -> bool:
raise NotImplementedError # pragma: no cover
async def wait(self) -> None:
raise NotImplementedError # pragma: no cover
def clear(self) -> None:
raise NotImplementedError # pragma: no cover
class Queue:
async def get(self) -> Any:
raise NotImplementedError # pragma: no cover
async def put(self, item: Any) -> None:
raise NotImplementedError # pragma: no cover
async def aclose(self) -> None:
raise NotImplementedError # pragma: no cover
class AsyncSocket:
def get_local_addr(self) -> Optional[Tuple[str, int]]:
raise NotImplementedError # pragma: no cover
def get_remote_addr(self) -> Optional[Tuple[str, int]]:
raise NotImplementedError # pragma: no cover
@property
def is_ssl(self) -> bool:
raise NotImplementedError # pragma: no cover
async def read(self, n: int) -> bytes:
raise NotImplementedError # pragma: no cover
async def write(self, data: bytes) -> None:
raise NotImplementedError # pragma: no cover
async def send_eof(self) -> None:
raise NotImplementedError # pragma: no cover
async def aclose(self) -> None:
raise NotImplementedError # pragma: no cover
@property
def is_closed(self) -> bool:
raise NotImplementedError # pragma: no cover
class AsyncListener:
@property
def socket(self) -> socket.SocketType:
raise NotImplementedError # pragma: no cover
class TaskStatus:
IGNORED: "IgnoredTaskStatus"
def __init__(self, event: Event) -> None:
self._value_event = event
async def started(self, value: Any = None) -> None:
self._value = value
await self._value_event.set()
async def get_value(self) -> Any:
await self._value_event.wait()
assert hasattr(self, "_value")
return self._value
class TaskHandle:
async def cancel(self) -> None:
raise NotImplementedError # pragma: no cover
class IgnoredTaskStatus(TaskStatus):
def __init__(self) -> None:
super().__init__(None) # type: ignore
async def started(self, value: Any = None) -> None:
pass
async def get_value(self) -> Any:
return None
TaskStatus.IGNORED = IgnoredTaskStatus()
class AsyncBackend:
def create_event(self) -> Event:
raise NotImplementedError # pragma: no cover
def create_queue(self, size: int) -> Queue:
raise NotImplementedError # pragma: no cover
def create_task_status(self) -> TaskStatus:
event = self.create_event()
return TaskStatus(event)
async def sleep(self, seconds: float) -> None:
raise NotImplementedError # pragma: no cover
def run(self, async_fn: Callable, *args: Any) -> None:
raise NotImplementedError # pragma: no cover
async def move_on_after(
self, seconds: float, async_fn: Callable, *args: Any
) -> None:
raise NotImplementedError # pragma: no cover
def start_soon(
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
) -> AsyncContextManager[None]:
raise NotImplementedError # pragma: no cover
def start(
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
) -> AsyncContextManager[Any]:
raise NotImplementedError # pragma: no cover
async def wait_then_call(
self,
seconds: float,
async_fn: Callable,
*args: Any,
task_status: TaskStatus = TaskStatus.IGNORED,
) -> None:
raise NotImplementedError # pragma: no cover
async def serve_tcp(
self,
handler: Callable[[AsyncSocket, ServerState, Config], Awaitable[None]],
state: ServerState,
config: Config,
*,
sockets: List[socket.SocketType] = None,
wait_close: Callable,
on_close: Callable = None,
task_status: TaskStatus = TaskStatus.IGNORED,
) -> None:
raise NotImplementedError # pragma: no cover
async def listen_signals(
self, *signals: signal.Signals, handler: Callable[[], Awaitable[None]]
) -> None:
raise NotImplementedError # pragma: no cover

View File

@ -0,0 +1,259 @@
import functools
import signal
import socket
from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Tuple
import curio
from ...config import Config
from ..compat import asynccontextmanager
from ..state import ServerState
from ..utils import get_sock_local_addr, get_sock_remote_addr
from .base import (
AsyncBackend,
AsyncListener,
AsyncSocket,
Event,
Queue,
TaskHandle,
TaskStatus,
)
class CurioEvent(Event):
def __init__(self) -> None:
# TODO: consider curio.UniversalEvent
self._event = curio.Event()
async def set(self) -> None:
await self._event.set()
def is_set(self) -> bool:
return self._event.is_set()
async def wait(self) -> None:
await self._event.wait()
def clear(self) -> None:
self._event.clear()
class CurioSocket(AsyncSocket):
def __init__(self, sock: curio.io.Socket, is_ssl: bool) -> None:
self._sock = sock
self._stream = sock.as_stream()
self._is_closed = False
self._is_ssl = is_ssl
def get_local_addr(self) -> Optional[Tuple[str, int]]:
return get_sock_local_addr(self._sock)
def get_remote_addr(self) -> Optional[Tuple[str, int]]:
return get_sock_remote_addr(self._sock)
@property
def is_ssl(self) -> bool:
return self._is_ssl
async def read(self, n: int) -> bytes:
try:
return await self._stream.read(n)
except ValueError: # TODO
return b""
async def write(self, data: bytes) -> None:
try:
await self._stream.write(data)
except ValueError: # TODO
pass
async def send_eof(self) -> None:
pass
async def aclose(self) -> None:
await self._sock.close()
self._is_closed = True
@property
def is_closed(self) -> bool:
return self._is_closed
class CurioQueue(Queue):
def __init__(self, size: int) -> None:
self._queue = curio.Queue(size)
async def get(self) -> Any:
return await self._queue.get()
async def put(self, item: Any) -> None:
await self._queue.put(item)
async def aclose(self) -> None:
pass
class CurioListener(AsyncListener):
def __init__(self, sock: socket.SocketType) -> None:
self._sock = sock
@property
def socket(self) -> socket.SocketType:
return self._sock
class CurioTaskHandle(TaskHandle):
def __init__(self, cancel_event: curio.Event) -> None:
self._cancel_event = cancel_event
async def cancel(self) -> None:
await self._cancel_event.set()
class CurioBackend(AsyncBackend):
def create_event(self) -> Event:
return CurioEvent()
def create_queue(self, size: int) -> Queue:
return CurioQueue(size)
async def sleep(self, seconds: float) -> None:
await curio.sleep(seconds)
def run(self, async_fn: Callable, *args: Any) -> None:
curio.run(async_fn, *args)
async def move_on_after(
self, seconds: float, async_fn: Callable, *args: Any
) -> None:
async with curio.ignore_after(seconds):
await async_fn(*args)
@asynccontextmanager
async def start_soon(
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
) -> AsyncIterator[None]:
async with curio.TaskGroup() as g:
task = await g.spawn(async_fn, *args)
yield
if cancel_on_exit:
await task.cancel()
else:
await task.wait()
try:
task.result
except curio.TaskCancelled:
pass
@asynccontextmanager
async def start(
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
) -> AsyncIterator[Any]:
async with curio.TaskGroup() as g:
task_status = self.create_task_status()
async_fn = functools.partial(async_fn, task_status=task_status)
task = await g.spawn(async_fn, *args)
value = await task_status.get_value()
yield value
if cancel_on_exit:
await task.cancel()
else:
await task.wait()
try:
task.result
except curio.TaskCancelled:
pass
async def wait_then_call(
self,
seconds: float,
async_fn: Callable,
*args: Any,
task_status: TaskStatus = TaskStatus.IGNORED,
) -> None:
cancel_event = curio.Event()
await task_status.started(CurioTaskHandle(cancel_event))
async with curio.TaskGroup(wait=any) as g:
await g.spawn(curio.sleep, seconds)
await g.spawn(cancel_event.wait)
if cancel_event.is_set():
return
await async_fn()
async def serve_tcp(
self,
handler: Callable[[AsyncSocket, ServerState, Config], Awaitable[None]],
state: ServerState,
config: Config,
*,
sockets: List[socket.SocketType] = None,
wait_close: Callable,
on_close: Callable = None,
task_status: TaskStatus = TaskStatus.IGNORED,
) -> None:
async def client_connected_task(csock: curio.io.Socket, addr: Any) -> None:
sock = CurioSocket(csock, is_ssl=bool(config.ssl))
await handler(sock, state, config)
async with curio.TaskGroup() as g:
if sockets is not None:
# Explicitly passed a list of open sockets.
listener_sockets = sockets
elif config.fd is not None:
# Use an existing socket, from a file descriptor.
sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
listener_sockets = [sock]
elif config.uds is not None:
# Create a socket using UNIX domain socket.
sock = curio.unix_server_socket(config.uds, backlog=config.backlog)
listener_sockets = [sock]
else:
# Standard case. Create a socket from a host/port pair.
sock = curio.tcp_server_socket(
config.host, config.port, backlog=config.backlog
)
listener_sockets = [sock]
for sock in listener_sockets:
run_server = functools.partial(curio.network.run_server, ssl=config.ssl)
await g.spawn(run_server, sock, client_connected_task)
value = [CurioListener(sock) for sock in listener_sockets]
await task_status.started(value)
await wait_close()
# Run any custom shutdown behavior.
if on_close is not None:
await on_close()
# Connections are properly closed, we can go ahead and hard-stop
# the servers.
await g.cancel_remaining()
async def listen_signals(
self, *signals: signal.Signals, handler: Callable[[], Awaitable[None]]
) -> None:
# https://curio.readthedocs.io/en/latest/howto.html#how-do-you-catch-signals
signal_event = curio.UniversalEvent()
def wrapped_handler(*args: Any) -> None:
signal_event.set()
for sig in signals:
signal.signal(sig, wrapped_handler)
while True:
await signal_event.wait()
signal_event.clear()
await handler()

View File

@ -0,0 +1,289 @@
import functools
import signal
import socket
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
List,
Optional,
Sequence,
Tuple,
Union,
)
import trio
from ...config import Config
from ..compat import asynccontextmanager
from ..exceptions import BrokenSocket
from ..state import ServerState
from ..utils import get_sock_local_addr, get_sock_remote_addr
from .base import (
AsyncBackend,
AsyncListener,
AsyncSocket,
Event,
Queue,
TaskHandle,
TaskStatus,
)
class TrioEvent(Event):
def __init__(self) -> None:
self._event = trio.Event()
async def set(self) -> None:
self._event.set()
def is_set(self) -> bool:
return self._event.is_set()
async def wait(self) -> None:
await self._event.wait()
def clear(self) -> None:
self._event = trio.Event()
class TrioSocket(AsyncSocket):
def __init__(self, stream: Union[trio.SocketStream, trio.SSLStream]) -> None:
self._stream = stream
self._is_closed = False
def _unwrap_ssl(self) -> trio.SocketStream:
stream: trio.abc.Stream = self._stream
while isinstance(stream, trio.SSLStream):
stream = stream.transport_stream
assert isinstance(stream, trio.SocketStream)
return stream
def get_local_addr(self) -> Optional[Tuple[str, int]]:
stream = self._unwrap_ssl()
return get_sock_local_addr(stream.socket)
def get_remote_addr(self) -> Optional[Tuple[str, int]]:
stream = self._unwrap_ssl()
return get_sock_remote_addr(stream.socket)
@property
def is_ssl(self) -> bool:
return isinstance(self._stream, trio.SSLStream)
async def read(self, n: int) -> bytes:
try:
return await self._stream.receive_some(n)
except (trio.BrokenResourceError, trio.ClosedResourceError):
return b""
async def write(self, data: bytes) -> None:
try:
await self._stream.send_all(data)
except trio.BrokenResourceError:
pass
async def send_eof(self) -> None:
stream = self._unwrap_ssl()
try:
await stream.send_eof()
except trio.BrokenResourceError:
raise BrokenSocket()
async def aclose(self) -> None:
await self._stream.aclose()
self._is_closed = True
@property
def is_closed(self) -> bool:
return self._is_closed
class TrioQueue(Queue):
def __init__(self, size: int) -> None:
self._send_channel, self._receive_channel = trio.open_memory_channel[Any](size)
async def get(self) -> Any:
return await self._receive_channel.receive()
async def put(self, item: Any) -> None:
try:
await self._send_channel.send(item)
except (trio.BrokenResourceError, trio.ClosedResourceError):
pass # Already closed.
async def aclose(self) -> None:
try:
await self._receive_channel.aclose()
except trio.ClosedResourceError:
pass # Already closed.
try:
await self._send_channel.aclose()
except trio.ClosedResourceError:
pass # Already closed.
class TrioListener(AsyncListener):
def __init__(self, listener: trio.abc.Listener) -> None:
self._listener = listener
def _unwrap_ssl(self) -> trio.SocketListener:
listener = self._listener
# Unwrap SSL.
while isinstance(listener, trio.SSLListener):
listener = listener.transport_listener
assert isinstance(listener, trio.SocketListener)
return listener
@property
def socket(self) -> socket.SocketType:
listener = self._unwrap_ssl()
return listener.socket
class TrioTaskHandle(TaskHandle):
def __init__(self, cancel_scope: trio.CancelScope) -> None:
self._cancel_scope = cancel_scope
async def cancel(self) -> None:
self._cancel_scope.cancel()
class TrioBackend(AsyncBackend):
def create_event(self) -> Event:
return TrioEvent()
def create_queue(self, size: int) -> Queue:
return TrioQueue(size)
async def sleep(self, seconds: float) -> None:
await trio.sleep(seconds)
def run(self, async_fn: Callable, *args: Any) -> None:
trio.run(async_fn, *args)
async def move_on_after(
self, seconds: float, async_fn: Callable, *args: Any
) -> None:
with trio.move_on_after(seconds):
await async_fn(*args)
@asynccontextmanager
async def start_soon(
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
) -> AsyncIterator[None]:
async with trio.open_nursery() as nursery:
nursery.start_soon(async_fn, *args)
yield
if cancel_on_exit:
nursery.cancel_scope.cancel()
@asynccontextmanager
async def start(
self, async_fn: Callable, *args: Any, cancel_on_exit: bool = False
) -> AsyncIterator[Any]:
async with trio.open_nursery() as nursery:
task_status = self.create_task_status()
async_fn = functools.partial(async_fn, task_status=task_status)
nursery.start_soon(async_fn, *args)
value = await task_status.get_value()
yield value
if cancel_on_exit:
nursery.cancel_scope.cancel()
async def wait_then_call(
self,
seconds: float,
async_fn: Callable,
*args: Any,
task_status: TaskStatus = TaskStatus.IGNORED,
) -> None:
cancel_scope = trio.CancelScope()
await task_status.started(TrioTaskHandle(cancel_scope))
with cancel_scope:
await trio.sleep(seconds)
cancel_scope.shield = True
await async_fn()
async def serve_tcp(
self,
handler: Callable[[AsyncSocket, ServerState, Config], Awaitable[None]],
state: ServerState,
config: Config,
*,
sockets: List[socket.SocketType] = None,
wait_close: Callable,
on_close: Callable = None,
task_status: TaskStatus = TaskStatus.IGNORED,
) -> None:
async def trio_handler(stream: trio.SocketStream) -> None:
sock = TrioSocket(stream)
await handler(sock, state, config)
async with trio.open_nursery() as nursery:
listeners: Sequence[trio.abc.Listener] = []
if sockets is not None:
# Explicitly passed a list of open sockets.
# (We need to wrap them around the trio equivalents.)
trio_sockets = []
for sock in sockets:
sock.listen(config.backlog)
trio_socket = trio.socket.fromfd(
sock.fileno(), sock.family, sock.type
)
trio_sockets.append(trio_socket)
listeners = [trio.SocketListener(sock) for sock in trio_sockets]
elif config.fd is not None:
# Use an existing socket, from a file descriptor.
sock = trio.socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
listeners = [trio.SocketListener(sock)]
elif config.uds is not None:
# Create a socket using UNIX domain socket.
# XXX: trio does not provide highlevel support for UDS servers yet.
# See: https://github.com/python-trio/trio/issues/279
# This may land soon-ish, though.
# See: https://github.com/python-trio/trio/pull/1433
# Use the API proposed in the RFC above as a draft.
listeners = await trio.open_unix_listeners( # type: ignore
config.uds, backlog=config.backlog
)
else:
# Standard case. Create a socket from a host/port pair.
listeners = await trio.open_tcp_listeners(
config.port, host=config.host, backlog=config.backlog
)
if config.ssl:
listeners = [
trio.SSLListener(listener, config.ssl, https_compatible=False)
for listener in listeners
]
listeners = await nursery.start(
trio.serve_listeners, trio_handler, listeners
)
value = [TrioListener(listener) for listener in listeners]
await task_status.started(value)
await wait_close()
# Run any custom shutdown behavior.
if on_close is not None:
await on_close()
# Connections are properly closed, we can go ahead and hard-stop
# the server.
nursery.cancel_scope.cancel()
async def listen_signals(
self, *signals: signal.Signals, handler: Callable[[], Awaitable[None]]
) -> None:
with trio.open_signal_receiver(*signals) as signal_receiver:
async for _ in signal_receiver:
await handler()

View File

@ -0,0 +1,12 @@
try:
from contextlib import AsyncExitStack
except ImportError: # pragma: no cover
from async_exit_stack import AsyncExitStack # type: ignore
try:
from contextlib import asynccontextmanager
except ImportError: # pragma: no cover
from async_generator import asynccontextmanager # type: ignore
__all__ = ["AsyncExitStack", "asynccontextmanager"]

View File

@ -0,0 +1,10 @@
class BrokenSocket(Exception):
pass
class ProtocolError(Exception):
pass
class LifespanFailure(Exception):
pass

View File

@ -0,0 +1,230 @@
import itertools
import logging
import time
from typing import Any, AsyncIterator, List, Optional, Tuple
from ..backends.auto import AutoBackend
from ..backends.base import AsyncSocket
from ..exceptions import BrokenSocket, ProtocolError
from ..utils import STATUS_PHRASES, find_upgrade_header, to_internet_date
from .parsers.base import Event, HTTP11Parser
TRACE_LOG_LEVEL = 5
NEXT_ID = itertools.count()
class HTTP11Connection:
MAX_RECV = 2 ** 16
def __init__(
self,
sock: AsyncSocket,
default_headers: List[Tuple[bytes, bytes]],
parser: HTTP11Parser,
) -> None:
self._sock = sock
self._default_headers = default_headers
self._parser = parser
self._obj_id = next(NEXT_ID)
self._logger = logging.getLogger("uvicorn.error")
self._backend = AutoBackend()
def trace(self, msg: str, *args: Any) -> None:
self._logger.log(TRACE_LOG_LEVEL, f"conn(%s): {msg}", self._obj_id, *args)
def debug(self, msg: str, *args: Any) -> None:
self._logger.debug(f"conn(%s): {msg}", self._obj_id, *args)
@property
def scheme(self) -> str:
return "https" if self._sock.is_ssl else "http"
@property
def server(self) -> Optional[Tuple[str, int]]:
return self._sock.get_local_addr()
@property
def client(self) -> Optional[Tuple[str, int]]:
return self._sock.get_remote_addr()
def basic_headers(self) -> List[Tuple[bytes, bytes]]:
return [
(b"date", to_internet_date(time.time()).encode("utf-8")),
] + self._default_headers
# State machine helpers
def states(self) -> dict:
return self._parser.states()
async def _send_event(self, event: Event) -> None:
if event["type"] == "Response":
self.trace(
"send_event event=Response("
"status_code=%d, headers=<Headers(...)>, reason=%s)",
event["status_code"],
event["reason"],
)
elif event["type"] == "Data":
self.trace("send_event event=Data(<%d bytes>)", len(event["data"]))
elif event["type"] == "EndOfMessage":
self.trace("send_event event=EndOfMessage(headers=<Headers(...)>")
else:
assert event["type"] == "ConnectionClosed"
self.trace("send_event event=ConnectionClosed()")
data = self._parser.send(event)
if data is None:
assert event["type"] == "ConnectionClosed", event
await self._sock.write(b"")
await self.shutdown_and_clean_up()
else:
await self._sock.write(data)
async def _read_from_peer(self) -> None:
if self._parser.they_are_waiting_for_100_continue:
self.trace("Sending 100 Continue")
await self._send_event({"type": "InformationalResponse"})
data = await self._sock.read(self.MAX_RECV)
self.trace("read_data Data(<%d bytes>)", len(data))
self._parser.receive_data(data)
async def _receive_event(self) -> Any:
while True:
try:
event = self._parser.next_event()
except ProtocolError as exc:
raise ProtocolError(f"Invalid HTTP request received: {exc}")
if event["type"] == "NEED_DATA":
await self._read_from_peer()
continue
if event["type"] == "Request":
self.trace(
"receive_event event=Request("
"http_version=%s, method=%s, target=%s, headers=...)",
event["http_version"],
event["method"],
event["target"],
)
elif event["type"] == "Data":
self.trace("receive_event event=Data(<%d bytes>)", len(event["data"]))
elif event["type"] == "EndOfMessage":
self.trace("receive_event event=EndOfMessage(headers=<Headers(...)>")
else:
assert event["type"] == "ConnectionClosed"
self.trace("receive_event event=ConnectionClosed()")
return event
async def read_request(
self,
) -> Tuple[bytes, bytes, bytes, List[Tuple[bytes, bytes]], Optional[bytes]]:
event = await self._receive_event()
if event["type"] == "ConnectionClosed":
raise BrokenSocket("Client has disconnected")
assert event["type"] == "Request"
http_version: bytes = event["http_version"]
method: bytes = event["method"]
path: bytes = event["target"]
headers = [(key.lower(), value) for key, value in event["headers"]]
upgrade = find_upgrade_header(headers)
return (http_version, method, path, headers, upgrade)
async def aiter_request_body(self) -> AsyncIterator[bytes]:
async def receive_data() -> bytes:
event = await self._receive_event()
if event["type"] == "EndOfMessage":
return b""
assert event["type"] == "Data"
return event["data"]
async def request_body(data: bytes) -> AsyncIterator[bytes]:
while data:
yield data
data = await receive_data()
# Read at least one event so that we get a chance of seeing `EndOfMessage`
# right away in case the client does not send a body (eg HEAD or GET requests).
initial = await receive_data()
return request_body(initial)
async def send_response(
self, status_code: int, headers: List[Tuple[bytes, bytes]], reason: bytes = b""
) -> None:
if not reason:
reason = STATUS_PHRASES[status_code]
event = {
"type": "Response",
"status_code": status_code,
"headers": headers,
"reason": reason,
}
await self._send_event(event)
async def send_simple_response(
self, status_code: int, content_type: str, body: bytes
) -> None:
self.trace("send_simple_response %d (%d bytes)", status_code, len(body))
headers = self.basic_headers() + [
(b"Content-Type", content_type.encode("utf-8")),
(b"Content-Length", str(len(body)).encode("utf-8")),
]
await self.send_response(status_code=status_code, headers=headers)
await self._send_event({"type": "Data", "data": body})
await self._send_event({"type": "EndOfMessage"})
async def send_response_body(self, chunk: bytes) -> None:
if chunk:
event = {"type": "Data", "data": chunk}
else:
event = {"type": "EndOfMessage"}
await self._send_event(event)
def set_keepalive(self) -> None:
try:
self._parser.start_next_cycle()
except ProtocolError:
raise
async def trigger_shutdown(self) -> None:
self.trace("triggering shutdown")
states = self._parser.states()
if states["server"] in {"IDLE", "DONE"}:
await self._send_event({"type": "ConnectionClosed"})
async def shutdown_and_clean_up(self) -> None:
self.trace("shutting down")
try:
await self._sock.send_eof()
except BrokenSocket:
self.trace("failed to send EOF: client is already gone")
return
self.trace("EOF sent")
# Wait and read for a bit to give them a chance to see that we closed
# things, but eventually give up and just close the socket.
async def attempt_read_until_eof() -> None:
try:
while True:
data = await self._sock.read(self.MAX_RECV)
if not data:
self.trace("EOF acknowledged by peer")
break
except Exception:
pass # It broke.
try:
await self._backend.move_on_after(5, attempt_read_until_eof)
finally:
await self._sock.aclose()

View File

@ -0,0 +1,155 @@
import logging
from typing import AsyncIterator, List, Tuple
from urllib.parse import unquote
from uvicorn.config import Config
from ..asgi import ASGIRequestResponseCycle
from ..backends.base import AsyncSocket
from ..exceptions import BrokenSocket, ProtocolError
from ..state import ServerState
from .connection import HTTP11Connection
from .keepalive import KeepAlive
from .parsers.h11 import H11Parser
try:
import httptools
from .parsers.httptools import HttpToolsParser
except ImportError: # pragma: no cover
httptools = None # type: ignore
HttpToolsParser = None # type: ignore
def create_http11_connection(
sock: AsyncSocket, state: ServerState, config: Config
) -> HTTP11Connection:
use_httptools = config.http == "httptools" or (
config.http == "auto" and httptools is not None
)
parser = HttpToolsParser() if use_httptools else H11Parser()
return HTTP11Connection(sock, default_headers=state.default_headers, parser=parser)
async def handle_http11(sock: AsyncSocket, state: ServerState, config: Config) -> None:
if not config.loaded:
config.load()
logger = logging.getLogger("uvicorn.error")
conn = create_http11_connection(sock, state, config)
keepalive = KeepAlive(conn, config)
state.connections.add(conn)
conn.debug("Connection made")
while True:
assert conn.states() == {"client": "IDLE", "server": "IDLE"}
try:
(
http_version,
method,
path,
headers,
upgrade,
) = await conn.read_request()
assert http_version == b"1.1", http_version
assert upgrade is None, "WebSocket not supported yet"
request_body = await conn.aiter_request_body()
await send_h11_response(
conn,
config=config,
method=method,
path=path,
headers=headers,
request_body=request_body,
)
except BrokenSocket:
break
except Exception as exc:
logger.error("Error while responding to request: %s", exc, exc_info=exc)
await maybe_send_error_response(conn)
else:
state.total_requests += 1
states = conn.states()
if states["server"] == "MUST_CLOSE":
conn.trace("Connection is not reusable, shutting down")
break
conn.trace("Trying to reuse connection")
try:
conn.set_keepalive()
except ProtocolError as exc:
conn.trace("Connection is not reusable, bailing out: %s ", exc)
await maybe_send_error_response(conn)
break
else:
await keepalive.reset()
await keepalive.schedule()
conn.debug("Connection kept alive")
await keepalive.aclose()
await conn.shutdown_and_clean_up()
state.connections.discard(conn)
conn.debug("Connection closed")
async def send_h11_response(
conn: HTTP11Connection,
*,
config: Config,
method: bytes,
path: bytes,
headers: List[Tuple[bytes, bytes]],
request_body: AsyncIterator[bytes],
) -> None:
conn.trace("Sending response")
raw_path, _, query_string = path.partition(b"?")
scope = {
"type": "http",
"asgi": {
"version": "3.0",
"spec_version": "2.1",
},
"http_version": "1.1",
"server": conn.server,
"client": conn.client,
"scheme": conn.scheme,
"method": method.decode("ascii"),
"root_path": config.root_path,
"path": unquote(raw_path.decode("ascii")),
"raw_path": raw_path,
"query_string": query_string,
"headers": headers,
}
send_response_body = conn.send_response_body
cycle = ASGIRequestResponseCycle(
conn,
scope=scope,
request_body=request_body,
send_response_body=send_response_body,
access_log=config.access_log,
)
app = config.loaded_app
await cycle.run_asgi(app)
async def maybe_send_error_response(conn: HTTP11Connection) -> None:
states = conn.states()
if states["server"] not in {"IDLE", "ACTIVE"}:
return
conn.trace("send error response")
status_code = 500
content_type = "text/plain; charset=utf-8"
body = b"Internal Server Error"
try:
await conn.send_simple_response(status_code, content_type, body)
except Exception as exc:
conn.trace("error while sending error response: %s", exc)

View File

@ -0,0 +1,52 @@
from typing import Optional
from uvicorn.config import Config
from ..backends.auto import AutoBackend
from ..backends.base import TaskHandle, TaskStatus
from ..compat import AsyncExitStack
from .connection import HTTP11Connection
class KeepAlive:
def __init__(self, conn: HTTP11Connection, config: Config) -> None:
self._conn = conn
self._config = config
self._backend = AutoBackend()
self._task_handle: Optional[TaskHandle] = None
self._exit_stack = AsyncExitStack()
async def schedule(self) -> None:
assert self._task_handle is None
self._task_handle = await self._exit_stack.enter_async_context(
self._backend.start(
self._trigger_shutdown_after_expiry,
cancel_on_exit=True,
),
)
async def _trigger_shutdown_after_expiry(
self, *, task_status: TaskStatus = TaskStatus.IGNORED
) -> None:
timeout = self._config.timeout_keep_alive
self._conn.trace("keep-alive expiry scheduled in %d seconds", timeout)
await self._backend.wait_then_call(
timeout,
async_fn=self._trigger_shutdown,
task_status=task_status,
)
async def _trigger_shutdown(self) -> None:
self._conn.trace("keep-alive expired")
await self._conn.trigger_shutdown()
async def reset(self) -> None:
if self._task_handle is not None:
self._conn.trace("keep-alive reset")
await self._task_handle.cancel()
self._task_handle = None
async def aclose(self) -> None:
await self._exit_stack.aclose()
self._conn.trace("keep-alive expiry cancelled")

View File

@ -0,0 +1,73 @@
from typing import Any, Dict, Optional
# We use the same state machine approach than `h11` because it's very convenient
# to program against.
#
# NOTE: Everything below is in the `h11` state machine docs:
# https://h11.readthedocs.io/en/latest/api.html#the-state-machine
# We describe the expected state machine to explicitly state the expected contract.
#
# Possible states
# ---------------
# IDLE
# SEND_RESPONSE
# SEND_BODY
# DONE
# MUST_CLOSE
# CLOSED
# ERROR
#
# Happy path
# ----------
# Client:
# IDLE --(Request)-> SEND_BODY [--(Data)-> SEND_BODY] --(EndOfMessage)-> DONE
# Server:
# IDLE --(Request)-> SEND_RESPONSE [--(InformationalResponse)-> SEND_RESPONSE] \
# --(Response)-> SEND_BODY [--(Data)-> SEND_BODY] --(EndOfMessage)-> DONE
#
# Keep-Alive
# ----------
# ENABLED --(HTTP/1.0 | Connection: close)-> DISABLED
# Client + Server: DONE --(KeepAlive ENABLED)-> IDLE
#
# Closing paths
# -------------
# IDLE --(ConnectionClosed)-> CLOSED
# IDLE --(Peer CLOSED)-> MUST_CLOSE
# DONE --(KeepAlive DISABLED | Peer CLOSED | Peer ERROR)-> MUST_CLOSE
# MUST_CLOSE --(ConnectionClosed)-> CLOSED
# DONE --(ConnectionClosed)-> CLOSED
# CLOSED --(ConnectionClosed)-> CLOSED
#
# Error paths
# -----------
# * --> ERROR
# NOTE: this path is managed by `h11` internally. For other libraries, the ERROR state
# should be entered anytime the parser raises a parsing exception.
Event = Dict[str, Any]
class HTTP11Parser:
"""
An event-based HTTP/1.1 parser interface, inspired by `h11`.
"""
def states(self) -> dict:
raise NotImplementedError # pragma: no cover
@property
def they_are_waiting_for_100_continue(self) -> bool:
raise NotImplementedError # pragma: no cover
def receive_data(self, data: bytes) -> None:
raise NotImplementedError # pragma: no cover
def next_event(self) -> Event:
raise NotImplementedError # pragma: no cover
def send(self, event: Event) -> Optional[bytes]:
raise NotImplementedError # pragma: no cover
def start_next_cycle(self) -> None:
raise NotImplementedError # pragma: no cover

View File

@ -0,0 +1,98 @@
from typing import Any, Optional
import h11
from ...exceptions import ProtocolError
from .base import Event, HTTP11Parser
H11_STATES_MAP = {
h11.IDLE: "IDLE",
h11.SEND_RESPONSE: "SEND_RESPONSE",
h11.SEND_BODY: "SEND_BODY",
h11.DONE: "DONE",
h11.MUST_CLOSE: "MUST_CLOSE",
h11.CLOSED: "CLOSED",
h11.ERROR: "ERROR",
}
class H11Parser(HTTP11Parser):
"""
An HTTP/1.1 parser backed by the `h11` library.
"""
def __init__(self) -> None:
self._h11_state = h11.Connection(h11.SERVER)
def states(self) -> dict:
return {
"client": H11_STATES_MAP[self._h11_state.their_state],
"server": H11_STATES_MAP[self._h11_state.our_state],
}
@property
def they_are_waiting_for_100_continue(self) -> bool:
return self._h11_state.they_are_waiting_for_100_continue
def receive_data(self, data: bytes) -> None:
self._h11_state.receive_data(data)
def next_event(self) -> Event:
event = self._h11_state.next_event()
return from_h11_event(event)
def send(self, event: Event) -> Optional[bytes]:
h11_event = to_h11_event(event)
return self._h11_state.send(h11_event)
def start_next_cycle(self) -> None:
try:
self._h11_state.start_next_cycle()
except h11.ProtocolError as exc:
raise ProtocolError(exc)
def from_h11_event(event: Any) -> Event:
if event is h11.NEED_DATA:
return {"type": "NEED_DATA"}
if isinstance(event, h11.Request):
return {
"type": "Request",
"http_version": event.http_version,
"method": event.method,
"target": event.target,
"headers": event.headers,
}
if isinstance(event, h11.ConnectionClosed):
return {"type": "ConnectionClosed"}
if isinstance(event, h11.Data):
return {"type": "Data", "data": event.data}
if isinstance(event, h11.EndOfMessage):
return {"type": "EndOfMessage"}
raise RuntimeError(f"Unknown event type: {type(event)}")
def to_h11_event(event: Event) -> Any:
if event["type"] == "InformationalResponse":
return h11.InformationalResponse(status_code=100)
if event["type"] == "Response":
return h11.Response(
status_code=event["status_code"],
headers=event["headers"],
reason=event["reason"],
)
if event["type"] == "Data":
return h11.Data(data=event["data"])
if event["type"] == "EndOfMessage":
return h11.EndOfMessage()
assert event["type"] == "ConnectionClosed"
return h11.ConnectionClosed()

View File

@ -0,0 +1,345 @@
import re
from typing import Any, Dict, List, Optional, Tuple, Union
import httptools
from ...exceptions import ProtocolError
from .base import Event, HTTP11Parser
HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]')
HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]")
class HttpToolsParser(HTTP11Parser):
"""
An HTTP/1.1 parser backed by the `httptools` library.
"""
def __init__(self) -> None:
self._parser = httptools.HttpRequestParser(self)
self._state = State()
# Parser-specific state.
self._parsed_url: Optional[Any] = None
self._headers: List[Tuple[bytes, bytes]] = []
self._chunked_encoding: Optional[bool] = None
self._expected_content_length: Optional[int] = None
# Parser API.
def states(self) -> dict:
return self._state.states()
@property
def they_are_waiting_for_100_continue(self) -> bool:
return self._state.client_waiting_for_100_continue
def receive_data(self, data: bytes) -> None:
try:
self._parser.feed_data(data)
except (
httptools.HttpParserInvalidMethodError,
httptools.HttpParserInvalidURLError,
httptools.HttpParserError,
) as exc:
self._state.process_error("client")
raise ProtocolError(exc)
except (
httptools.HttpParserInvalidStatusError,
httptools.HttpParserCallbackError,
) as exc:
self._state.process_error("server")
raise ProtocolError(exc)
if not data:
self._state.process_client_event({"type": "ConnectionClosed"})
def next_event(self) -> Event:
return self._state.next_event()
def send(self, event: Event) -> Optional[bytes]:
self._state.process_server_event(event)
if event["type"] == "InformationalResponse":
return self._render_informational_response()
if event["type"] == "Response":
if self._parser.get_method() == b"HEAD":
self._expected_content_length = 0
status_code = event["status_code"]
headers = event["headers"]
reason = event["reason"]
return self._render_response(status_code, headers, reason)
if event["type"] == "Data":
body = event["data"]
return self._render_response_body(body)
if event["type"] == "EndOfMessage":
num_bytes_remaining = self._expected_content_length or 0
if num_bytes_remaining != 0:
raise ProtocolError(
"Too little data for declared Content-Length: "
f"{num_bytes_remaining} remaining"
)
return b""
assert event["type"] == "ConnectionClosed"
return None
def start_next_cycle(self) -> None:
try:
self._state.start_next_cycle()
except ProtocolError:
raise
# Reset.
self._parsed_url = None
self._headers.clear()
self._chunked_encoding = None
self._expected_content_length = None
# Response rendering helpers.
def _render_informational_response(self) -> bytes:
return b"HTTP/1.1 100 Continue\r\n\r\n"
def _render_response(
self, status_code: int, headers: List[Tuple[bytes, bytes]], reason: bytes
) -> bytes:
status_line = b"".join(
[b"HTTP/1.1", b" ", str(status_code).encode("utf-8"), b" ", reason, b"\r\n"]
)
content = [status_line]
for name, value in headers:
if HEADER_RE.search(name):
raise RuntimeError("Invalid HTTP header name")
if HEADER_VALUE_RE.search(value):
raise RuntimeError("Invalid HTTP header value")
name = name.lower()
if name == b"content-length" and self._chunked_encoding is None:
self._expected_content_length = int(value.decode("ascii"))
self._chunked_encoding = False
elif name == b"transfer-encoding" and value.lower() == b"chunked":
self._expected_content_length = 0
self._chunked_encoding = True
elif name == b"connection" and value.lower() == b"close":
self._state.process_keep_alive_disabled()
content.extend([name, b": ", value, b"\r\n"])
if (
self._chunked_encoding is None
and self._parser.get_method() != b"HEAD"
and status_code not in (204, 304)
):
# Neither content-length nor transfer-encoding specified
self._chunked_encoding = True
content.append(b"transfer-encoding: chunked\r\n")
content.append(b"\r\n")
if self._chunked_encoding:
content.append(b"0\r\n\r\n")
return b"".join(content)
def _render_response_body(self, body: bytes) -> bytes:
if self._chunked_encoding:
content = [b"%x\r\n" % len(body), body, b"\r\n"] if body else []
content.append(b"0\r\n\r\n")
return b"".join(content)
assert self._expected_content_length is not None
if len(body) > self._expected_content_length:
raise RuntimeError("Response content longer than Content-Length")
self._expected_content_length -= len(body)
return body
# HttpTools callbacks.
def on_message_begin(self) -> None:
if self._parser.get_http_version() == "1.0":
self._state.process_keep_alive_disabled()
def on_url(self, url: str) -> None:
assert self._parsed_url is None
self._parsed_url = httptools.parse_url(url)
def on_header(self, name: bytes, value: bytes) -> None:
name = name.lower()
if name == b"expect" and value.lower() == b"100-continue":
self._state.process_expect_100_continue()
if name == b"connection" and value.lower() == b"close":
self._state.process_keep_alive_disabled()
self._headers.append((name, value))
def on_headers_complete(self) -> None:
assert self._parsed_url is not None
target = self._parsed_url.path
if self._parsed_url.query:
target += b"?%s" % self._parsed_url.query
event = {
"type": "Request",
"http_version": self._parser.get_http_version().encode("ascii"),
"method": self._parser.get_method(),
"target": target,
"headers": self._headers,
}
self._state.process_client_event(event)
def on_body(self, body: bytes) -> None:
event = {"type": "Data", "data": body}
self._state.process_client_event(event)
def on_message_complete(self) -> None:
event = {"type": "EndOfMessage"}
self._state.process_client_event(event)
# Code below is adapted from h11's state machine code.
EVENT_TRIGGERED_TRANSITIONS: Dict[
str, Dict[str, Dict[Union[str, Tuple[str, str]], str]]
] = {
# Role -> State -> Event -> New state
"client": {
"IDLE": {"Request": "SEND_BODY", "ConnectionClosed": "CLOSED"},
"SEND_BODY": {"Data": "SEND_BODY", "EndOfMessage": "DONE"},
"DONE": {"ConnectionClosed": "CLOSED"},
"MUST_CLOSE": {"ConnectionClosed": "CLOSED"},
"CLOSED": {"ConnectionClosed": "CLOSED"},
"ERROR": {},
},
"server": {
"IDLE": {
"ConnectionClosed": "CLOSED",
"Response": "SEND_BODY",
# Special case: server sees client Request events.
("Request", "client"): "SEND_RESPONSE",
},
"SEND_RESPONSE": {
"InformationalResponse": "SEND_RESPONSE",
"Response": "SEND_BODY",
},
"SEND_BODY": {"Data": "SEND_BODY", "EndOfMessage": "DONE"},
"DONE": {"ConnectionClosed": "CLOSED"},
"MUST_CLOSE": {"ConnectionClosed": "CLOSED"},
"CLOSED": {"ConnectionClosed": "CLOSED"},
"ERROR": {},
},
}
STATE_TRIGGERED_TRANSITIONS: Dict[Tuple[str, str], Dict[str, str]] = {
# (Client state, Server state) -> New states
# Socket shutdown
("CLOSED", "DONE"): {"server": "MUST_CLOSE"},
("CLOSED", "IDLE"): {"server": "MUST_CLOSE"},
("ERROR", "DONE"): {"server": "MUST_CLOSE"},
("DONE", "CLOSED"): {"client": "MUST_CLOSE"},
("IDLE", "CLOSED"): {"client": "MUST_CLOSE"},
("DONE", "ERROR"): {"clientt": "MUST_CLOSE"},
}
class State:
def __init__(self) -> None:
self._states = {"client": "IDLE", "server": "IDLE"}
self._client_events: List[Event] = []
self._keep_alive_enabled = True
self._expect_100_continue = False
@property
def client_waiting_for_100_continue(self) -> bool:
return self._expect_100_continue
def states(self) -> dict:
return dict(self._states)
def _process_event(self, role: str, event: Event) -> None:
self._fire_event_triggered_transitions(role, event["type"])
if event["type"] == "Request":
# Special case: the server state does get to see Request events.
self._fire_event_triggered_transitions("server", ("Request", "client"))
self._fire_state_triggered_transitions()
def process_client_event(self, event: Event) -> None:
self._process_event("client", event)
self._client_events.append(event)
def process_server_event(self, event: Event) -> None:
self._process_event("server", event)
def process_keep_alive_disabled(self) -> None:
self._keep_alive_enabled = False
def process_expect_100_continue(self) -> None:
assert not self._expect_100_continue
self._expect_100_continue = True
def process_error(self, role: str) -> None:
self._states[role] = "ERROR"
self._fire_state_triggered_transitions() # Peer may have to close.
def _fire_event_triggered_transitions(
self, role: str, event_type: Union[str, Tuple[str, str]]
) -> None:
state = self._states[role]
try:
new_state = EVENT_TRIGGERED_TRANSITIONS[role][state][event_type]
except KeyError:
raise ProtocolError(
f"can't handle event type {event_type} when "
f"role={role.upper()} and state={state}"
)
self._states[role] = new_state
if role == "server" and event_type in {"Response", "InformationalResponse"}:
self._expect_100_continue = False
def _fire_state_triggered_transitions(self) -> None:
# Apply transitions until we converge to a fixed point.
while True:
states = dict(self._states)
if not self._keep_alive_enabled:
for role in ("client", "server"):
if self._states[role] == "DONE":
self._states[role] = "MUST_CLOSE"
joint_state = (self._states["client"], self._states["server"])
changes = STATE_TRIGGERED_TRANSITIONS.get(joint_state, {})
self._states.update(changes)
if self._states == states:
break
def next_event(self) -> Event:
if self._states["client"] == "ERROR":
raise ProtocolError("Can't receive data when peer state is ERROR")
if self._states["client"] == "CLOSED":
return {"type": "ConnectionClosed"}
try:
return self._client_events.pop(0)
except IndexError:
return {"type": "NEED_DATA"}
def start_next_cycle(self) -> None:
if self._states != {"client": "DONE", "server": "DONE"}:
raise ProtocolError(f"Not in a reusable state: {self._states}")
assert self._keep_alive_enabled
assert not self._expect_100_continue
# Reset.
self._states = {"client": "IDLE", "server": "IDLE"}
self._keep_alive_enabled = True
self._client_events.clear()

View File

@ -0,0 +1,103 @@
import logging
from ..config import Config
from .backends.auto import AutoBackend
from .exceptions import LifespanFailure
STATE_TRANSITION_ERROR = "Got invalid state transition on lifespan protocol."
TRACE_LOG_LEVEL = 5
class Lifespan:
def __init__(self, config: Config) -> None:
if not config.loaded:
config.load()
self._config = config
self._logger = logging.getLogger("uvicorn.error")
self._backend = AutoBackend()
self._startup_event = self._backend.create_event()
self._shutdown_event = self._backend.create_event()
self._receive_queue = self._backend.create_queue(0)
self._mode = config.lifespan
self._supported = self._mode in ("on", "auto")
async def startup(self) -> None:
if not self._supported:
self._logger.log(TRACE_LOG_LEVEL, "lifespan startup skipped")
return
self._logger.info("Waiting for application startup.")
await self._receive_queue.put({"type": "lifespan.startup"})
await self._startup_event.wait()
self._logger.info("Application startup complete.")
async def shutdown(self) -> None:
if not self._supported:
self._logger.log(TRACE_LOG_LEVEL, "lifespan shutdown skipped")
return
self._logger.info("Waiting for application shutdown.")
await self._receive_queue.put({"type": "lifespan.shutdown"})
await self._shutdown_event.wait()
self._logger.info("Application shutdown complete.")
async def main(self) -> None:
if not self._supported:
self._logger.log(TRACE_LOG_LEVEL, "lifespan main skipped")
return
scope = {
"type": "lifespan",
"asgi": {"version": self._config.asgi_version, "spec_version": "2.0"},
}
app = self._config.loaded_app
try:
await app(scope, self._asgi_receive, self._asgi_send)
except LifespanFailure:
self._logger.error("Exception in 'lifespan' protocol")
raise # Lifespan failures should stop the server.
except Exception:
self._logger.info("ASGI 'lifespan' protocol appears unsupported.")
self._supported = False
if self._mode == "on":
raise
finally:
await self._startup_event.set()
await self._shutdown_event.set()
await self._receive_queue.aclose()
async def _asgi_receive(self) -> dict:
return await self._receive_queue.get()
async def _asgi_send(self, message: dict) -> None:
assert message["type"] in (
"lifespan.startup.complete",
"lifespan.startup.failed",
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
), message["type"]
if message["type"] == "lifespan.startup.complete":
assert not self._startup_event.is_set(), STATE_TRANSITION_ERROR
assert not self._shutdown_event.is_set(), STATE_TRANSITION_ERROR
await self._startup_event.set()
elif message["type"] == "lifespan.startup.failed":
assert not self._startup_event.is_set(), STATE_TRANSITION_ERROR
assert not self._shutdown_event.is_set(), STATE_TRANSITION_ERROR
await self._startup_event.set()
raise LifespanFailure(message.get("message", ""))
elif message["type"] == "lifespan.shutdown.complete":
assert self._startup_event.is_set(), STATE_TRANSITION_ERROR
assert not self._shutdown_event.is_set(), STATE_TRANSITION_ERROR
await self._shutdown_event.set()
else:
assert message["type"] == "lifespan.shutdown.failed"
assert self._startup_event.is_set(), STATE_TRANSITION_ERROR
assert not self._shutdown_event.is_set(), STATE_TRANSITION_ERROR
await self._shutdown_event.set()
raise LifespanFailure(message.get("message", ""))

View File

@ -0,0 +1,281 @@
import contextlib
import logging
import os
import signal
import socket
import threading
import time
from typing import Callable, Iterator, List, Optional
import click
from ..config import Config
from .backends.auto import select_async_backend
from .backends.base import AsyncListener, Event, TaskStatus
from .http11.handler import handle_http11
from .lifespan import Lifespan
from .state import ServerState
from .utils import to_internet_date
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
)
class Server:
def __init__(self, config: Config) -> None:
self._config = config
self._state = ServerState()
self._backend = select_async_backend(config.async_library)
self._logger = logging.getLogger("uvicorn.error")
self._last_notified = time.time()
self._force_exit = False
@property
def _shutdown_event(self) -> Event:
if not hasattr(self, "_shutdown_event_obj"):
# Can't be created on init due to a limitation of multiprocessing.
self._shutdown_event_obj = self._backend.create_event()
return self._shutdown_event_obj
@property
def _closed_event(self) -> Event:
if not hasattr(self, "_closed_event_obj"):
# Can't be created on init due to a limitation of multiprocessing.
self._closed_event_obj = self._backend.create_event()
return self._closed_event_obj
def run(
self,
sockets: List[socket.SocketType] = None,
started: Callable[[], None] = None,
) -> None:
if self._config.async_library == "asyncio":
self._config.setup_event_loop()
self._backend.run(self.serve, sockets, started)
@contextlib.contextmanager
def run_in_thread(self, sockets: List[socket.SocketType] = None) -> Iterator[None]:
thread_exc: Optional[Exception] = None
# Cannot be stored as an instance attribute because Event objects are not
# picklable, so this would cause issues with multiprocessing support.
started_event = threading.Event()
def target() -> None:
nonlocal thread_exc
try:
self.run(sockets=sockets, started=started_event.set)
except Exception as exc:
self._logger.exception(exc)
started_event.set()
thread_exc = exc
thread = threading.Thread(target=target)
thread.start()
try:
started_event.wait()
if thread_exc is not None:
raise thread_exc
yield
finally:
thread.join()
async def serve(
self,
sockets: List[socket.SocketType] = None,
started: Callable[[], None] = None,
) -> None:
started = (lambda: None) if started is None else started
process_id = os.getpid()
message = "Started server process [%d]"
color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
self._logger.info(message, process_id, extra={"color_message": color_message})
config = self._config
if not config.loaded:
config.load()
lifespan = Lifespan(config)
shutdown_trigger = (
self._shutdown_event.wait
if config.shutdown_trigger is None
else config.shutdown_trigger
)
async with self._backend.start_soon(lifespan.main, cancel_on_exit=True):
await lifespan.startup()
try:
async with self._backend.start(self._serve, sockets) as listeners:
async with (
self._backend.start_soon(self._main_loop),
self._backend.start_soon(self._listen_signals),
):
# Server has started.
started()
self._log_started_message(listeners, sockets=sockets)
# Let the server run until exit is requested.
await shutdown_trigger()
await self._shutdown_event.set()
finally:
if not self._force_exit:
await lifespan.shutdown()
message = "Finished server process [%d]"
color_message = "Finished server process [" + click.style("%d", fg="cyan") + "]"
self._logger.info(
"Finished server process [%d]",
process_id,
extra={"color_message": color_message},
)
async def _serve(
self,
sockets: List[socket.SocketType] = None,
*,
task_status: TaskStatus = TaskStatus.IGNORED,
) -> None:
await self._backend.serve_tcp(
handle_http11,
self._state,
self._config,
sockets=sockets,
wait_close=self._shutdown_event.wait,
on_close=self._on_close,
task_status=task_status,
)
async def _main_loop(self) -> None:
counter = 0
should_exit = await self._tick(counter)
while not should_exit:
counter += 1
counter = counter % 864000
await self._backend.sleep(0.1)
should_exit = await self._tick(counter)
await self._shutdown_event.set()
async def _tick(self, counter: int) -> bool:
state = self._state
config = self._config
# Update the default headers, once per second.
if counter % 10 == 0:
current_time = time.time()
current_date = to_internet_date(current_time).encode()
state.default_headers = [(b"date", current_date)] + config.encoded_headers
# Callback to `callback_notify` once every `timeout_notify` seconds.
if config.callback_notify is not None:
if current_time - self._last_notified > config.timeout_notify:
self._last_notified = current_time
await config.callback_notify()
# Determine if we should exit.
if self._shutdown_event.is_set():
return True
if config.limit_max_requests is not None:
return state.total_requests >= config.limit_max_requests
return False
async def _listen_signals(self) -> None:
if threading.current_thread() is not threading.main_thread():
await self._closed_event.wait()
return
async def handle_signal_exit() -> None:
if self._shutdown_event.is_set():
self._logger.info("Shutting down forcibly")
self._force_exit = True
else:
await self._shutdown_event.set()
async def listen_signals() -> None:
await self._backend.listen_signals(
*HANDLED_SIGNALS, handler=handle_signal_exit
)
async with self._backend.start_soon(listen_signals, cancel_on_exit=True):
await self._closed_event.wait()
async def _on_close(self) -> None:
assert self._shutdown_event.is_set()
self._logger.info("Shutting down")
state = self._state
# Request shutdown on all existing connections.
for conn in list(state.connections):
await conn.trigger_shutdown()
# Wait for existing connections to finish sending responses.
if state.connections and not self._force_exit:
self._logger.info(
"Waiting for connections to close. (CTRL+C to force quit)"
)
while state.connections and not self._force_exit:
await self._backend.sleep(0.1)
# Wait for existing tasks to complete.
if state.tasks and not self._force_exit:
self._logger.info(
"Waiting for background tasks to complete. (CTRL+C to force quit)"
)
while state.tasks and not self._force_exit:
await self._backend.sleep(0.1)
await self._closed_event.set()
def _log_started_message(
self, listeners: List[AsyncListener], sockets: List[socket.SocketType] = None
) -> None:
if sockets is not None:
# We're running multiple workers, and a message has already been
# logged by `config.bind_socket()``.
return
config = self._config
if config.fd is not None:
sock = listeners[0].socket
self._logger.info(
"Uvicorn running on socket %s (Press CTRL+C to quit)",
sock.getsockname(),
)
return
if config.uds is not None:
self._logger.info(
"Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds
)
return
addr_format = "%s://%s:%d"
host = "0.0.0.0" if config.host is None else config.host
if ":" in host:
# It's an IPv6 address.
addr_format = "%s://[%s]:%d"
port = config.port
if port == 0:
sock = listeners[0].socket
_, port = sock.getpeername()
protocol_name = "https" if config.ssl else "http"
message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
color_message = (
"Uvicorn running on "
+ click.style(addr_format, bold=True)
+ " (Press CTRL+C to quit)"
)
self._logger.info(
message,
protocol_name,
host,
port,
extra={"color_message": color_message},
)

View File

@ -0,0 +1,9 @@
from typing import Any, List, Set, Tuple
class ServerState:
def __init__(self) -> None:
self.default_headers: List[Tuple[bytes, bytes]] = [] # Set by the server.
self.total_requests = 0 # Updated by server handlers.
self.connections: Set[Any] = set()
self.tasks: Set[Any] = set() # May be used by async backends (not all do).

View File

@ -0,0 +1,65 @@
import http
import socket
from email.utils import formatdate
from typing import List, Optional, Tuple
from urllib.parse import quote
def _get_status_phrase(status_code: int) -> bytes:
try:
return http.HTTPStatus(status_code).phrase.encode()
except ValueError:
return b""
STATUS_PHRASES = {
status_code: _get_status_phrase(status_code) for status_code in range(100, 600)
}
RECV_CHUNK_SIZE = 2 ** 16
TRACE_LOG_LEVEL = 5
def get_sock_remote_addr(sock: socket.SocketType) -> Optional[Tuple[str, int]]:
try:
info = sock.getpeername()
except OSError:
# This case appears to inconsistently occur with uvloop
# bound to a unix domain socket.
return None
else:
return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
def get_sock_local_addr(sock: socket.SocketType) -> Optional[Tuple[str, int]]:
info = sock.getsockname()
if isinstance(info, tuple):
return (str(info[0]), int(info[1]))
return None
def get_path_with_query_string(scope: dict) -> str:
path = quote(scope.get("root_path", "") + scope["path"])
qs = scope["query_string"]
if qs:
path += "?{}".format(qs.decode("ascii"))
return path
def to_internet_date(value: float) -> str:
return formatdate(value, usegmt=True)
def find_upgrade_header(headers: List[Tuple[bytes, bytes]]) -> Optional[bytes]:
connection = next((value for name, value in headers if name == b"connection"), None)
if connection is None:
return None
tokens = [token.lower().strip() for token in connection.split(b",")]
if b"upgrade" not in tokens:
return None
return next(value.lower() for name, value in headers if name == b"upgrade")

View File

@ -123,6 +123,7 @@ class Config:
port=8000,
uds=None,
fd=None,
async_library=None,
loop="auto",
http="auto",
ws="auto",
@ -147,6 +148,7 @@ class Config:
timeout_keep_alive=5,
timeout_notify=30,
callback_notify=None,
shutdown_trigger=None,
ssl_keyfile=None,
ssl_certfile=None,
ssl_keyfile_password=None,
@ -161,6 +163,7 @@ class Config:
self.port = port
self.uds = uds
self.fd = fd
self.async_library = async_library
self.loop = loop
self.http = http
self.ws = ws
@ -182,6 +185,7 @@ class Config:
self.timeout_keep_alive = timeout_keep_alive
self.timeout_notify = timeout_notify
self.callback_notify = callback_notify
self.shutdown_trigger = shutdown_trigger
self.ssl_keyfile = ssl_keyfile
self.ssl_certfile = ssl_certfile
self.ssl_keyfile_password = ssl_keyfile_password

View File

@ -21,10 +21,13 @@ from uvicorn.config import (
)
from uvicorn.supervisors import ChangeReload, Multiprocess
from ._async_agnostic.server import Server as AsyncAgnosticServer
LEVEL_CHOICES = click.Choice(LOG_LEVELS.keys())
HTTP_CHOICES = click.Choice(HTTP_PROTOCOLS.keys())
WS_CHOICES = click.Choice(WS_PROTOCOLS.keys())
LIFESPAN_CHOICES = click.Choice(LIFESPAN.keys())
LIBRARY_CHOICES = click.Choice(["asyncio", "trio", "curio"])
LOOP_CHOICES = click.Choice([key for key in LOOP_SETUPS.keys() if key != "none"])
INTERFACE_CHOICES = click.Choice(INTERFACES)
@ -96,6 +99,13 @@ def print_version(ctx, param, value):
help="Number of worker processes. Defaults to the $WEB_CONCURRENCY environment"
" variable if available. Not valid with --reload.",
)
@click.option(
"--async-library",
type=LIBRARY_CHOICES,
default=None,
help="Concurrency library implementation.",
show_default=True,
)
@click.option(
"--loop",
type=LOOP_CHOICES,
@ -283,6 +293,7 @@ def main(
port: int,
uds: str,
fd: int,
async_library: typing.Optional[str],
loop: str,
http: str,
ws: str,
@ -323,6 +334,7 @@ def main(
"port": port,
"uds": uds,
"fd": fd,
"async_library": async_library,
"loop": loop,
"http": http,
"ws": ws,
@ -359,7 +371,11 @@ def main(
def run(app, **kwargs):
config = Config(app, **kwargs)
server = Server(config=config)
if config.async_library is not None:
server = AsyncAgnosticServer(config)
else:
server = Server(config=config)
if (config.reload or config.workers > 1) and not isinstance(app, str):
logger = logging.getLogger("uvicorn.error")