parent
0d4747e602
commit
4a503d84fa
@ -61,7 +61,8 @@ path = "uvicorn/__init__.py"
|
||||
include = ["/uvicorn"]
|
||||
|
||||
[tool.ruff]
|
||||
select = ["E", "F", "I"]
|
||||
line-length = 120
|
||||
select = ["E", "F", "I", "FA", "UP"]
|
||||
ignore = ["B904", "B028"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import importlib.util
|
||||
import os
|
||||
@ -9,6 +11,7 @@ from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from threading import Thread
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@ -38,14 +41,14 @@ LOGGING_CONFIG["loggers"]["uvicorn"]["propagate"] = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tls_certificate_authority() -> "trustme.CA":
|
||||
def tls_certificate_authority() -> trustme.CA:
|
||||
if not HAVE_TRUSTME:
|
||||
pytest.skip("trustme not installed") # pragma: no cover
|
||||
return trustme.CA()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tls_certificate(tls_certificate_authority: "trustme.CA") -> "trustme.LeafCert":
|
||||
def tls_certificate(tls_certificate_authority: trustme.CA) -> trustme.LeafCert:
|
||||
return tls_certificate_authority.issue_cert(
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
@ -54,13 +57,13 @@ def tls_certificate(tls_certificate_authority: "trustme.CA") -> "trustme.LeafCer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tls_ca_certificate_pem_path(tls_certificate_authority: "trustme.CA"):
|
||||
def tls_ca_certificate_pem_path(tls_certificate_authority: trustme.CA):
|
||||
with tls_certificate_authority.cert_pem.tempfile() as ca_cert_pem:
|
||||
yield ca_cert_pem
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tls_ca_certificate_private_key_path(tls_certificate_authority: "trustme.CA"):
|
||||
def tls_ca_certificate_private_key_path(tls_certificate_authority: trustme.CA):
|
||||
with tls_certificate_authority.private_key_pem.tempfile() as private_key:
|
||||
yield private_key
|
||||
|
||||
@ -82,25 +85,25 @@ def tls_certificate_private_key_encrypted_path(tls_certificate):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tls_certificate_private_key_path(tls_certificate: "trustme.CA"):
|
||||
def tls_certificate_private_key_path(tls_certificate: trustme.CA):
|
||||
with tls_certificate.private_key_pem.tempfile() as private_key:
|
||||
yield private_key
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tls_certificate_key_and_chain_path(tls_certificate: "trustme.LeafCert"):
|
||||
def tls_certificate_key_and_chain_path(tls_certificate: trustme.LeafCert):
|
||||
with tls_certificate.private_key_and_cert_chain_pem.tempfile() as cert_pem:
|
||||
yield cert_pem
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tls_certificate_server_cert_path(tls_certificate: "trustme.LeafCert"):
|
||||
def tls_certificate_server_cert_path(tls_certificate: trustme.LeafCert):
|
||||
with tls_certificate.cert_chain_pems[0].tempfile() as cert_pem:
|
||||
yield cert_pem
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tls_ca_ssl_context(tls_certificate_authority: "trustme.CA") -> ssl.SSLContext:
|
||||
def tls_ca_ssl_context(tls_certificate_authority: trustme.CA) -> ssl.SSLContext:
|
||||
ssl_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
||||
tls_certificate_authority.configure_trust(ssl_ctx)
|
||||
return ssl_ctx
|
||||
@ -172,7 +175,7 @@ def anyio_backend() -> str:
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def logging_config() -> dict:
|
||||
def logging_config() -> dict[str, Any]:
|
||||
return deepcopy(LOGGING_CONFIG)
|
||||
|
||||
|
||||
@ -250,9 +253,7 @@ def unused_tcp_port() -> int:
|
||||
params=[
|
||||
pytest.param(
|
||||
"uvicorn.protocols.websockets.wsproto_impl:WSProtocol",
|
||||
marks=pytest.mark.skipif(
|
||||
not importlib.util.find_spec("wsproto"), reason="wsproto not installed."
|
||||
),
|
||||
marks=pytest.mark.skipif(not importlib.util.find_spec("wsproto"), reason="wsproto not installed."),
|
||||
id="wsproto",
|
||||
),
|
||||
pytest.param(
|
||||
|
||||
@ -60,9 +60,7 @@ async def test_trace_logging(caplog, logging_config, unused_tcp_port: int):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")
|
||||
assert response.status_code == 204
|
||||
messages = [
|
||||
record.message for record in caplog.records if record.name == "uvicorn.asgi"
|
||||
]
|
||||
messages = [record.message for record in caplog.records if record.name == "uvicorn.asgi"]
|
||||
assert "ASGI [1] Started scope=" in messages.pop(0)
|
||||
assert "ASGI [1] Raised exception" in messages.pop(0)
|
||||
assert "ASGI [2] Started scope=" in messages.pop(0)
|
||||
@ -72,9 +70,7 @@ async def test_trace_logging(caplog, logging_config, unused_tcp_port: int):
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_trace_logging_on_http_protocol(
|
||||
http_protocol_cls, caplog, logging_config, unused_tcp_port: int
|
||||
):
|
||||
async def test_trace_logging_on_http_protocol(http_protocol_cls, caplog, logging_config, unused_tcp_port: int):
|
||||
config = Config(
|
||||
app=app,
|
||||
log_level="trace",
|
||||
@ -87,11 +83,7 @@ async def test_trace_logging_on_http_protocol(
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")
|
||||
assert response.status_code == 204
|
||||
messages = [
|
||||
record.message
|
||||
for record in caplog.records
|
||||
if record.name == "uvicorn.error"
|
||||
]
|
||||
messages = [record.message for record in caplog.records if record.name == "uvicorn.error"]
|
||||
assert any(" - HTTP connection made" in message for message in messages)
|
||||
assert any(" - HTTP connection lost" in message for message in messages)
|
||||
|
||||
@ -127,11 +119,7 @@ async def test_trace_logging_on_ws_protocol(
|
||||
async with run_server(config):
|
||||
is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
|
||||
assert is_open
|
||||
messages = [
|
||||
record.message
|
||||
for record in caplog.records
|
||||
if record.name == "uvicorn.error"
|
||||
]
|
||||
messages = [record.message for record in caplog.records if record.name == "uvicorn.error"]
|
||||
assert any(" - Upgrading to WebSocket" in message for message in messages)
|
||||
assert any(" - WebSocket connection made" in message for message in messages)
|
||||
assert any(" - WebSocket connection lost" in message for message in messages)
|
||||
@ -140,39 +128,27 @@ async def test_trace_logging_on_ws_protocol(
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("use_colors", [(True), (False), (None)])
|
||||
async def test_access_logging(use_colors, caplog, logging_config, unused_tcp_port: int):
|
||||
config = Config(
|
||||
app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port
|
||||
)
|
||||
config = Config(app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port)
|
||||
with caplog_for_logger(caplog, "uvicorn.access"):
|
||||
async with run_server(config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")
|
||||
|
||||
assert response.status_code == 204
|
||||
messages = [
|
||||
record.message
|
||||
for record in caplog.records
|
||||
if record.name == "uvicorn.access"
|
||||
]
|
||||
messages = [record.message for record in caplog.records if record.name == "uvicorn.access"]
|
||||
assert '"GET / HTTP/1.1" 204' in messages.pop()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("use_colors", [(True), (False)])
|
||||
async def test_default_logging(
|
||||
use_colors, caplog, logging_config, unused_tcp_port: int
|
||||
):
|
||||
config = Config(
|
||||
app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port
|
||||
)
|
||||
async def test_default_logging(use_colors, caplog, logging_config, unused_tcp_port: int):
|
||||
config = Config(app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port)
|
||||
with caplog_for_logger(caplog, "uvicorn.access"):
|
||||
async with run_server(config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")
|
||||
assert response.status_code == 204
|
||||
messages = [
|
||||
record.message for record in caplog.records if "uvicorn" in record.name
|
||||
]
|
||||
messages = [record.message for record in caplog.records if "uvicorn" in record.name]
|
||||
assert "Started server process" in messages.pop(0)
|
||||
assert "Waiting for application startup" in messages.pop(0)
|
||||
assert "ASGI 'lifespan' protocol appears unsupported" in messages.pop(0)
|
||||
@ -184,19 +160,14 @@ async def test_default_logging(
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="require unix-like system")
|
||||
async def test_running_log_using_uds(
|
||||
caplog, short_socket_name, unused_tcp_port: int
|
||||
): # pragma: py-win32
|
||||
async def test_running_log_using_uds(caplog, short_socket_name, unused_tcp_port: int): # pragma: py-win32
|
||||
config = Config(app=app, uds=short_socket_name, port=unused_tcp_port)
|
||||
with caplog_for_logger(caplog, "uvicorn.access"):
|
||||
async with run_server(config):
|
||||
...
|
||||
|
||||
messages = [record.message for record in caplog.records if "uvicorn" in record.name]
|
||||
assert (
|
||||
f"Uvicorn running on unix socket {short_socket_name} (Press CTRL+C to quit)"
|
||||
in messages
|
||||
)
|
||||
assert f"Uvicorn running on unix socket {short_socket_name} (Press CTRL+C to quit)" in messages
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@ -227,11 +198,7 @@ async def test_unknown_status_code(caplog, unused_tcp_port: int):
|
||||
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")
|
||||
|
||||
assert response.status_code == 599
|
||||
messages = [
|
||||
record.message
|
||||
for record in caplog.records
|
||||
if record.name == "uvicorn.access"
|
||||
]
|
||||
messages = [record.message for record in caplog.records if record.name == "uvicorn.access"]
|
||||
assert '"GET / HTTP/1.1" 599' in messages.pop()
|
||||
|
||||
|
||||
|
||||
@ -26,9 +26,7 @@ async def test_message_logger(caplog):
|
||||
assert sum(["ASGI [1] Send" in message for message in messages]) == 2
|
||||
assert sum(["ASGI [1] Receive" in message for message in messages]) == 1
|
||||
assert sum(["ASGI [1] Completed" in message for message in messages]) == 1
|
||||
assert (
|
||||
sum(["ASGI [1] Raised exception" in message for message in messages]) == 0
|
||||
)
|
||||
assert sum(["ASGI [1] Raised exception" in message for message in messages]) == 0
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@ -48,6 +46,4 @@ async def test_message_logger_exc(caplog):
|
||||
assert sum(["ASGI [1] Send" in message for message in messages]) == 0
|
||||
assert sum(["ASGI [1] Receive" in message for message in messages]) == 0
|
||||
assert sum(["ASGI [1] Completed" in message for message in messages]) == 0
|
||||
assert (
|
||||
sum(["ASGI [1] Raised exception" in message for message in messages]) == 1
|
||||
)
|
||||
assert sum(["ASGI [1] Raised exception" in message for message in messages]) == 1
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List, Type, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@ -45,13 +47,9 @@ async def app(
|
||||
("192.168.0.1", "Remote: http://127.0.0.1:123"),
|
||||
],
|
||||
)
|
||||
async def test_proxy_headers_trusted_hosts(
|
||||
trusted_hosts: Union[List[str], str], response_text: str
|
||||
) -> None:
|
||||
async def test_proxy_headers_trusted_hosts(trusted_hosts: list[str] | str, response_text: str) -> None:
|
||||
app_with_middleware = ProxyHeadersMiddleware(app, trusted_hosts=trusted_hosts)
|
||||
async with httpx.AsyncClient(
|
||||
app=app_with_middleware, base_url="http://testserver"
|
||||
) as client:
|
||||
async with httpx.AsyncClient(app=app_with_middleware, base_url="http://testserver") as client:
|
||||
headers = {"X-Forwarded-Proto": "https", "X-Forwarded-For": "1.2.3.4"}
|
||||
response = await client.get("/", headers=headers)
|
||||
|
||||
@ -79,13 +77,9 @@ async def test_proxy_headers_trusted_hosts(
|
||||
(["192.168.0.2", "127.0.0.1"], "Remote: https://10.0.2.1:0"),
|
||||
],
|
||||
)
|
||||
async def test_proxy_headers_multiple_proxies(
|
||||
trusted_hosts: Union[List[str], str], response_text: str
|
||||
) -> None:
|
||||
async def test_proxy_headers_multiple_proxies(trusted_hosts: list[str] | str, response_text: str) -> None:
|
||||
app_with_middleware = ProxyHeadersMiddleware(app, trusted_hosts=trusted_hosts)
|
||||
async with httpx.AsyncClient(
|
||||
app=app_with_middleware, base_url="http://testserver"
|
||||
) as client:
|
||||
async with httpx.AsyncClient(app=app_with_middleware, base_url="http://testserver") as client:
|
||||
headers = {
|
||||
"X-Forwarded-Proto": "https",
|
||||
"X-Forwarded-For": "1.2.3.4, 10.0.2.1, 192.168.0.2",
|
||||
@ -99,9 +93,7 @@ async def test_proxy_headers_multiple_proxies(
|
||||
@pytest.mark.anyio
|
||||
async def test_proxy_headers_invalid_x_forwarded_for() -> None:
|
||||
app_with_middleware = ProxyHeadersMiddleware(app, trusted_hosts="*")
|
||||
async with httpx.AsyncClient(
|
||||
app=app_with_middleware, base_url="http://testserver"
|
||||
) as client:
|
||||
async with httpx.AsyncClient(app=app_with_middleware, base_url="http://testserver") as client:
|
||||
headers = httpx.Headers(
|
||||
{
|
||||
"X-Forwarded-Proto": "https",
|
||||
@ -127,12 +119,14 @@ async def test_proxy_headers_invalid_x_forwarded_for() -> None:
|
||||
async def test_proxy_headers_websocket_x_forwarded_proto(
|
||||
x_forwarded_proto: str,
|
||||
addr: str,
|
||||
ws_protocol_cls: "Type[WSProtocol | WebSocketProtocol]",
|
||||
http_protocol_cls: "Type[H11Protocol | HttpToolsProtocol]",
|
||||
ws_protocol_cls: type[WSProtocol | WebSocketProtocol],
|
||||
http_protocol_cls: type[H11Protocol | HttpToolsProtocol],
|
||||
unused_tcp_port: int,
|
||||
) -> None:
|
||||
async def websocket_app(scope, receive, send):
|
||||
async def websocket_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
assert scope["type"] == "websocket"
|
||||
scheme = scope["scheme"]
|
||||
assert scope["client"] is not None
|
||||
host, port = scope["client"]
|
||||
addr = "%s://%s:%d" % (scheme, host, port)
|
||||
await send({"type": "websocket.accept"})
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import sys
|
||||
from typing import AsyncGenerator, Callable, List
|
||||
from typing import AsyncGenerator, Callable
|
||||
|
||||
import a2wsgi
|
||||
import httpx
|
||||
@ -10,7 +12,7 @@ from uvicorn._types import Environ, HTTPRequestEvent, HTTPScope, StartResponse
|
||||
from uvicorn.middleware import wsgi
|
||||
|
||||
|
||||
def hello_world(environ: Environ, start_response: StartResponse) -> List[bytes]:
|
||||
def hello_world(environ: Environ, start_response: StartResponse) -> list[bytes]:
|
||||
status = "200 OK"
|
||||
output = b"Hello World!\n"
|
||||
headers = [
|
||||
@ -21,7 +23,7 @@ def hello_world(environ: Environ, start_response: StartResponse) -> List[bytes]:
|
||||
return [output]
|
||||
|
||||
|
||||
def echo_body(environ: Environ, start_response: StartResponse) -> List[bytes]:
|
||||
def echo_body(environ: Environ, start_response: StartResponse) -> list[bytes]:
|
||||
status = "200 OK"
|
||||
output = environ["wsgi.input"].read()
|
||||
headers = [
|
||||
@ -32,11 +34,11 @@ def echo_body(environ: Environ, start_response: StartResponse) -> List[bytes]:
|
||||
return [output]
|
||||
|
||||
|
||||
def raise_exception(environ: Environ, start_response: StartResponse) -> List[bytes]:
|
||||
def raise_exception(environ: Environ, start_response: StartResponse) -> list[bytes]:
|
||||
raise RuntimeError("Something went wrong")
|
||||
|
||||
|
||||
def return_exc_info(environ: Environ, start_response: StartResponse) -> List[bytes]:
|
||||
def return_exc_info(environ: Environ, start_response: StartResponse) -> list[bytes]:
|
||||
try:
|
||||
raise RuntimeError("Something went wrong")
|
||||
except RuntimeError:
|
||||
@ -110,16 +112,14 @@ async def test_wsgi_exc_info(wsgi_middleware: Callable) -> None:
|
||||
app=app,
|
||||
raise_app_exceptions=False,
|
||||
)
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport, base_url="http://testserver"
|
||||
) as client:
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
|
||||
response = await client.get("/")
|
||||
assert response.status_code == 500
|
||||
assert response.text == "Internal Server Error"
|
||||
|
||||
|
||||
def test_build_environ_encoding() -> None:
|
||||
scope: "HTTPScope" = {
|
||||
scope: HTTPScope = {
|
||||
"asgi": {"version": "3.0", "spec_version": "2.0"},
|
||||
"scheme": "http",
|
||||
"raw_path": b"/\xe6\x96\x87%2Fall",
|
||||
@ -134,12 +134,12 @@ def test_build_environ_encoding() -> None:
|
||||
"headers": [(b"key", b"value1"), (b"key", b"value2")],
|
||||
"extensions": {},
|
||||
}
|
||||
message: "HTTPRequestEvent" = {
|
||||
message: HTTPRequestEvent = {
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
}
|
||||
environ = wsgi.build_environ(scope, message, io.BytesIO(b""))
|
||||
assert environ["SCRIPT_NAME"] == "/文".encode("utf8").decode("latin-1")
|
||||
assert environ["PATH_INFO"] == "/all".encode("utf8").decode("latin-1")
|
||||
assert environ["SCRIPT_NAME"] == "/文".encode().decode("latin-1")
|
||||
assert environ["PATH_INFO"] == b"/all".decode("latin-1")
|
||||
assert environ["HTTP_KEY"] == "value1,value2"
|
||||
|
||||
@ -58,9 +58,7 @@ SIMPLE_POST_REQUEST = b"\r\n".join(
|
||||
]
|
||||
)
|
||||
|
||||
CONNECTION_CLOSE_REQUEST = b"\r\n".join(
|
||||
[b"GET / HTTP/1.1", b"Host: example.org", b"Connection: close", b"", b""]
|
||||
)
|
||||
CONNECTION_CLOSE_REQUEST = b"\r\n".join([b"GET / HTTP/1.1", b"Host: example.org", b"Connection: close", b"", b""])
|
||||
|
||||
LARGE_POST_REQUEST = b"\r\n".join(
|
||||
[
|
||||
@ -88,9 +86,7 @@ FINISH_POST_REQUEST = b'{"hello": "world"}'
|
||||
|
||||
HTTP10_GET_REQUEST = b"\r\n".join([b"GET / HTTP/1.0", b"Host: example.org", b"", b""])
|
||||
|
||||
GET_REQUEST_WITH_RAW_PATH = b"\r\n".join(
|
||||
[b"GET /one%2Ftwo HTTP/1.1", b"Host: example.org", b"", b""]
|
||||
)
|
||||
GET_REQUEST_WITH_RAW_PATH = b"\r\n".join([b"GET /one%2Ftwo HTTP/1.1", b"Host: example.org", b"", b""])
|
||||
|
||||
UPGRADE_REQUEST = b"\r\n".join(
|
||||
[
|
||||
@ -257,11 +253,9 @@ async def test_get_request(http_protocol_cls: HTTPProtocol):
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("path", ["/", "/?foo", "/?foo=bar", "/?foo=bar&baz=1"])
|
||||
async def test_request_logging(
|
||||
path: str, http_protocol_cls: HTTPProtocol, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
async def test_request_logging(path: str, http_protocol_cls: HTTPProtocol, caplog: pytest.LogCaptureFixture):
|
||||
get_request_with_query_string = b"\r\n".join(
|
||||
["GET {} HTTP/1.1".format(path).encode("ascii"), b"Host: example.org", b"", b""]
|
||||
[f"GET {path} HTTP/1.1".encode("ascii"), b"Host: example.org", b"", b""]
|
||||
)
|
||||
caplog.set_level(logging.INFO, logger="uvicorn.access")
|
||||
logging.getLogger("uvicorn.access").propagate = True
|
||||
@ -271,7 +265,7 @@ async def test_request_logging(
|
||||
protocol = get_connected_protocol(app, http_protocol_cls, log_config=None)
|
||||
protocol.data_received(get_request_with_query_string)
|
||||
await protocol.loop.run_one()
|
||||
assert '"GET {} HTTP/1.1" 200'.format(path) in caplog.records[0].message
|
||||
assert f'"GET {path} HTTP/1.1" 200' in caplog.records[0].message
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@ -371,9 +365,7 @@ async def test_close(http_protocol_cls: HTTPProtocol):
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_chunked_encoding(http_protocol_cls: HTTPProtocol):
|
||||
app = Response(
|
||||
b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}
|
||||
)
|
||||
app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"})
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_GET_REQUEST)
|
||||
@ -385,9 +377,7 @@ async def test_chunked_encoding(http_protocol_cls: HTTPProtocol):
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol):
|
||||
app = Response(
|
||||
b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}
|
||||
)
|
||||
app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"})
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_GET_REQUEST)
|
||||
@ -401,9 +391,7 @@ async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol):
|
||||
async def test_chunked_encoding_head_request(
|
||||
http_protocol_cls: HTTPProtocol,
|
||||
):
|
||||
app = Response(
|
||||
b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}
|
||||
)
|
||||
app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"})
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_HEAD_REQUEST)
|
||||
@ -669,9 +657,7 @@ async def test_root_path(http_protocol_cls: HTTPProtocol):
|
||||
assert scope["type"] == "http"
|
||||
root_path = scope.get("root_path", "")
|
||||
path = scope["path"]
|
||||
response = Response(
|
||||
f"root_path={root_path} path={path}", media_type="text/plain"
|
||||
)
|
||||
response = Response(f"root_path={root_path} path={path}", media_type="text/plain")
|
||||
await response(scope, receive, send)
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls, root_path="/app")
|
||||
@ -821,21 +807,14 @@ async def test_unsupported_ws_upgrade_request_warn_on_auto(
|
||||
await protocol.loop.run_one()
|
||||
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
|
||||
assert b"Hello, world" in protocol.transport.buffer
|
||||
warnings = [
|
||||
record.msg
|
||||
for record in filter(
|
||||
lambda record: record.levelname == "WARNING", caplog.records
|
||||
)
|
||||
]
|
||||
warnings = [record.msg for record in filter(lambda record: record.levelname == "WARNING", caplog.records)]
|
||||
assert "Unsupported upgrade request." in warnings
|
||||
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
|
||||
assert msg in warnings
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_http2_upgrade_request(
|
||||
http_protocol_cls: HTTPProtocol, ws_protocol_cls: WSProtocol
|
||||
):
|
||||
async def test_http2_upgrade_request(http_protocol_cls: HTTPProtocol, ws_protocol_cls: WSProtocol):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls, ws=ws_protocol_cls)
|
||||
@ -915,9 +894,7 @@ def test_fragmentation(unused_tcp_port: int):
|
||||
def send_fragmented_req(path: str):
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.connect(("127.0.0.1", unused_tcp_port))
|
||||
d = (
|
||||
f"GET {path} HTTP/1.1\r\n" "Host: localhost\r\n" "Connection: close\r\n\r\n"
|
||||
).encode()
|
||||
d = (f"GET {path} HTTP/1.1\r\n" "Host: localhost\r\n" "Connection: close\r\n\r\n").encode()
|
||||
split = len(path) // 2
|
||||
sock.sendall(d[:split])
|
||||
time.sleep(0.01)
|
||||
@ -979,9 +956,7 @@ async def test_huge_headers_httptools_will_pass():
|
||||
async def test_huge_headers_h11protocol_failure_with_setting():
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(
|
||||
app, H11Protocol, h11_max_incomplete_event_size=20 * 1024
|
||||
)
|
||||
protocol = get_connected_protocol(app, H11Protocol, h11_max_incomplete_event_size=20 * 1024)
|
||||
# Huge headers make h11 fail in it's default config
|
||||
# h11 sends back a 400 in this case
|
||||
protocol.data_received(GET_REQUEST_HUGE_HEADERS[0])
|
||||
@ -1009,9 +984,7 @@ async def test_huge_headers_httptools():
|
||||
async def test_huge_headers_h11_max_incomplete():
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(
|
||||
app, H11Protocol, h11_max_incomplete_event_size=64 * 1024
|
||||
)
|
||||
protocol = get_connected_protocol(app, H11Protocol, h11_max_incomplete_event_size=64 * 1024)
|
||||
protocol.data_received(GET_REQUEST_HUGE_HEADERS[0])
|
||||
protocol.data_received(GET_REQUEST_HUGE_HEADERS[1])
|
||||
await protocol.loop.run_one()
|
||||
|
||||
@ -31,20 +31,14 @@ def test_get_local_addr_with_socket():
|
||||
transport = MockTransport({"socket": MockSocket(family=socket.AF_IPX)})
|
||||
assert get_local_addr(transport) is None
|
||||
|
||||
transport = MockTransport(
|
||||
{"socket": MockSocket(family=socket.AF_INET6, sockname=("::1", 123))}
|
||||
)
|
||||
transport = MockTransport({"socket": MockSocket(family=socket.AF_INET6, sockname=("::1", 123))})
|
||||
assert get_local_addr(transport) == ("::1", 123)
|
||||
|
||||
transport = MockTransport(
|
||||
{"socket": MockSocket(family=socket.AF_INET, sockname=("123.45.6.7", 123))}
|
||||
)
|
||||
transport = MockTransport({"socket": MockSocket(family=socket.AF_INET, sockname=("123.45.6.7", 123))})
|
||||
assert get_local_addr(transport) == ("123.45.6.7", 123)
|
||||
|
||||
if hasattr(socket, "AF_UNIX"): # pragma: no cover
|
||||
transport = MockTransport(
|
||||
{"socket": MockSocket(family=socket.AF_UNIX, sockname=("127.0.0.1", 8000))}
|
||||
)
|
||||
transport = MockTransport({"socket": MockSocket(family=socket.AF_UNIX, sockname=("127.0.0.1", 8000))})
|
||||
assert get_local_addr(transport) == ("127.0.0.1", 8000)
|
||||
|
||||
|
||||
@ -52,20 +46,14 @@ def test_get_remote_addr_with_socket():
|
||||
transport = MockTransport({"socket": MockSocket(family=socket.AF_IPX)})
|
||||
assert get_remote_addr(transport) is None
|
||||
|
||||
transport = MockTransport(
|
||||
{"socket": MockSocket(family=socket.AF_INET6, peername=("::1", 123))}
|
||||
)
|
||||
transport = MockTransport({"socket": MockSocket(family=socket.AF_INET6, peername=("::1", 123))})
|
||||
assert get_remote_addr(transport) == ("::1", 123)
|
||||
|
||||
transport = MockTransport(
|
||||
{"socket": MockSocket(family=socket.AF_INET, peername=("123.45.6.7", 123))}
|
||||
)
|
||||
transport = MockTransport({"socket": MockSocket(family=socket.AF_INET, peername=("123.45.6.7", 123))})
|
||||
assert get_remote_addr(transport) == ("123.45.6.7", 123)
|
||||
|
||||
if hasattr(socket, "AF_UNIX"): # pragma: no cover
|
||||
transport = MockTransport(
|
||||
{"socket": MockSocket(family=socket.AF_UNIX, peername=("127.0.0.1", 8000))}
|
||||
)
|
||||
transport = MockTransport({"socket": MockSocket(family=socket.AF_UNIX, peername=("127.0.0.1", 8000))})
|
||||
assert get_remote_addr(transport) == ("127.0.0.1", 8000)
|
||||
|
||||
|
||||
|
||||
@ -50,9 +50,7 @@ if typing.TYPE_CHECKING:
|
||||
|
||||
|
||||
class WebSocketResponse:
|
||||
def __init__(
|
||||
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
|
||||
):
|
||||
def __init__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
self.scope = scope
|
||||
self.receive = receive
|
||||
self.send = send
|
||||
@ -87,15 +85,11 @@ async def wsresponse(url):
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_invalid_upgrade(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_invalid_upgrade(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
def app(scope: Scope):
|
||||
return None
|
||||
|
||||
config = Config(
|
||||
app=app, ws=ws_protocol_cls, http=http_protocol_cls, port=unused_tcp_port
|
||||
)
|
||||
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, port=unused_tcp_port)
|
||||
async with run_server(config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
@ -117,18 +111,14 @@ async def test_invalid_upgrade(
|
||||
"missing sec-websocket-key header",
|
||||
"missing sec-websocket-version header", # websockets
|
||||
"missing or empty sec-websocket-key header", # wsproto
|
||||
"failed to open a websocket connection: missing "
|
||||
"sec-websocket-key header",
|
||||
"failed to open a websocket connection: missing or empty "
|
||||
"sec-websocket-key header",
|
||||
"failed to open a websocket connection: missing " "sec-websocket-key header",
|
||||
"failed to open a websocket connection: missing or empty " "sec-websocket-key header",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_accept_connection(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_accept_connection(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
class App(WebSocketResponse):
|
||||
async def websocket_connect(self, message):
|
||||
await self.send({"type": "websocket.accept"})
|
||||
@ -150,9 +140,7 @@ async def test_accept_connection(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_shutdown(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_shutdown(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
class App(WebSocketResponse):
|
||||
async def websocket_connect(self, message):
|
||||
await self.send({"type": "websocket.accept"})
|
||||
@ -180,9 +168,7 @@ async def test_supports_permessage_deflate_extension(
|
||||
|
||||
async def open_connection(url):
|
||||
extension_factories = [ClientPerMessageDeflateFactory()]
|
||||
async with websockets.client.connect(
|
||||
url, extensions=extension_factories
|
||||
) as websocket:
|
||||
async with websockets.client.connect(url, extensions=extension_factories) as websocket:
|
||||
return [extension.name for extension in websocket.extensions]
|
||||
|
||||
config = Config(
|
||||
@ -209,9 +195,7 @@ async def test_can_disable_permessage_deflate_extension(
|
||||
# enable per-message deflate on the client, so that we can check the server
|
||||
# won't support it when it's disabled.
|
||||
extension_factories = [ClientPerMessageDeflateFactory()]
|
||||
async with websockets.client.connect(
|
||||
url, extensions=extension_factories
|
||||
) as websocket:
|
||||
async with websockets.client.connect(url, extensions=extension_factories) as websocket:
|
||||
return [extension.name for extension in websocket.extensions]
|
||||
|
||||
config = Config(
|
||||
@ -228,9 +212,7 @@ async def test_can_disable_permessage_deflate_extension(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_close_connection(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_close_connection(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
class App(WebSocketResponse):
|
||||
async def websocket_connect(self, message):
|
||||
await self.send({"type": "websocket.close"})
|
||||
@ -255,9 +237,7 @@ async def test_close_connection(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_headers(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
class App(WebSocketResponse):
|
||||
async def websocket_connect(self, message):
|
||||
headers = self.scope.get("headers")
|
||||
@ -267,9 +247,7 @@ async def test_headers(
|
||||
await self.send({"type": "websocket.accept"})
|
||||
|
||||
async def open_connection(url: str):
|
||||
async with websockets.client.connect(
|
||||
url, extra_headers=[("username", "abraão")]
|
||||
) as websocket:
|
||||
async with websockets.client.connect(url, extra_headers=[("username", "abraão")]) as websocket:
|
||||
return websocket.open
|
||||
|
||||
config = Config(
|
||||
@ -285,14 +263,10 @@ async def test_headers(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_extra_headers(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_extra_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
class App(WebSocketResponse):
|
||||
async def websocket_connect(self, message):
|
||||
await self.send(
|
||||
{"type": "websocket.accept", "headers": [(b"extra", b"header")]}
|
||||
)
|
||||
await self.send({"type": "websocket.accept", "headers": [(b"extra", b"header")]})
|
||||
|
||||
async def open_connection(url: str):
|
||||
async with websockets.client.connect(url) as websocket:
|
||||
@ -311,9 +285,7 @@ async def test_extra_headers(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_path_and_raw_path(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_path_and_raw_path(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
class App(WebSocketResponse):
|
||||
async def websocket_connect(self, message):
|
||||
path = self.scope.get("path")
|
||||
@ -515,9 +487,7 @@ async def test_send_after_protocol_close(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_missing_handshake(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_missing_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
pass
|
||||
|
||||
@ -561,9 +531,7 @@ async def test_send_before_handshake(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_duplicate_handshake(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_duplicate_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send({"type": "websocket.accept"})
|
||||
await send({"type": "websocket.accept"})
|
||||
@ -586,9 +554,7 @@ async def test_duplicate_handshake(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_asgi_return_value(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
"""
|
||||
The ASGI callable should return 'None'. If it doesn't, make sure that
|
||||
the connection is closed with an error condition.
|
||||
@ -668,9 +634,7 @@ async def test_app_close(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_close(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_client_close(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
while True:
|
||||
message = await receive()
|
||||
@ -723,9 +687,7 @@ async def test_client_connection_lost(
|
||||
port=unused_tcp_port,
|
||||
)
|
||||
async with run_server(config):
|
||||
async with websockets.client.connect(
|
||||
f"ws://127.0.0.1:{unused_tcp_port}"
|
||||
) as websocket:
|
||||
async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
|
||||
websocket.transport.close()
|
||||
await asyncio.sleep(0.1)
|
||||
got_disconnect_event_before_shutdown = got_disconnect_event
|
||||
@ -748,7 +710,7 @@ async def test_client_connection_lost_on_send(
|
||||
try:
|
||||
await disconnect.wait()
|
||||
await send({"type": "websocket.send", "text": "123"})
|
||||
except IOError:
|
||||
except OSError:
|
||||
got_disconnect_event = True
|
||||
|
||||
config = Config(
|
||||
@ -781,7 +743,7 @@ async def test_connection_lost_before_handshake_complete(
|
||||
await send_accept_task.wait()
|
||||
disconnect_message = await receive() # type: ignore
|
||||
|
||||
response: typing.Optional[httpx.Response] = None
|
||||
response: httpx.Response | None = None
|
||||
|
||||
async def websocket_session(uri: str):
|
||||
nonlocal response
|
||||
@ -804,9 +766,7 @@ async def test_connection_lost_before_handshake_complete(
|
||||
port=unused_tcp_port,
|
||||
)
|
||||
async with run_server(config):
|
||||
task = asyncio.create_task(
|
||||
websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
|
||||
)
|
||||
task = asyncio.create_task(websocket_session(f"ws://127.0.0.1:{unused_tcp_port}"))
|
||||
await asyncio.sleep(0.1)
|
||||
send_accept_task.set()
|
||||
await asyncio.sleep(0.1)
|
||||
@ -835,7 +795,7 @@ async def test_send_close_on_server_shutdown(
|
||||
disconnect_message = message
|
||||
break
|
||||
|
||||
websocket: typing.Optional[websockets.client.WebSocketClientProtocol] = None
|
||||
websocket: websockets.client.WebSocketClientProtocol | None = None
|
||||
|
||||
async def websocket_session(uri: str):
|
||||
nonlocal websocket
|
||||
@ -851,9 +811,7 @@ async def test_send_close_on_server_shutdown(
|
||||
port=unused_tcp_port,
|
||||
)
|
||||
async with run_server(config):
|
||||
task = asyncio.create_task(
|
||||
websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
|
||||
)
|
||||
task = asyncio.create_task(websocket_session(f"ws://127.0.0.1:{unused_tcp_port}"))
|
||||
await asyncio.sleep(0.1)
|
||||
disconnect_message_before_shutdown = disconnect_message
|
||||
server_shutdown_event.set()
|
||||
@ -891,9 +849,7 @@ async def test_subprotocols(
|
||||
port=unused_tcp_port,
|
||||
)
|
||||
async with run_server(config):
|
||||
accepted_subprotocol = await get_subprotocol(
|
||||
f"ws://127.0.0.1:{unused_tcp_port}"
|
||||
)
|
||||
accepted_subprotocol = await get_subprotocol(f"ws://127.0.0.1:{unused_tcp_port}")
|
||||
assert accepted_subprotocol == subprotocol
|
||||
|
||||
|
||||
@ -1257,7 +1213,7 @@ async def test_server_multiple_websocket_http_response_start_events(
|
||||
The server should raise an exception if it sends multiple
|
||||
websocket.http.response.start events.
|
||||
"""
|
||||
exception_message: typing.Optional[str] = None
|
||||
exception_message: str | None = None
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
nonlocal exception_message
|
||||
@ -1297,8 +1253,7 @@ async def test_server_multiple_websocket_http_response_start_events(
|
||||
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
|
||||
|
||||
assert exception_message == (
|
||||
"Expected ASGI message 'websocket.http.response.body' but got "
|
||||
"'websocket.http.response.start'."
|
||||
"Expected ASGI message 'websocket.http.response.body' but got " "'websocket.http.response.start'."
|
||||
)
|
||||
|
||||
|
||||
@ -1369,9 +1324,7 @@ async def test_default_server_headers(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_no_server_headers(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_no_server_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
class App(WebSocketResponse):
|
||||
async def websocket_connect(self, message):
|
||||
await self.send({"type": "websocket.accept"})
|
||||
@ -1395,9 +1348,7 @@ async def test_no_server_headers(
|
||||
|
||||
@pytest.mark.anyio
|
||||
@skip_if_no_wsproto
|
||||
async def test_no_date_header_on_wsproto(
|
||||
http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_no_date_header_on_wsproto(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
class App(WebSocketResponse):
|
||||
async def websocket_connect(self, message):
|
||||
await self.send({"type": "websocket.accept"})
|
||||
@ -1452,9 +1403,7 @@ async def test_multiple_server_header(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_lifespan_state(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
expected_states = [
|
||||
{"a": 123, "b": [1]},
|
||||
{"a": 123, "b": [1, 2]},
|
||||
@ -1462,9 +1411,7 @@ async def test_lifespan_state(
|
||||
|
||||
actual_states = []
|
||||
|
||||
async def lifespan_app(
|
||||
scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
|
||||
):
|
||||
async def lifespan_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
message = await receive()
|
||||
assert message["type"] == "lifespan.startup" and "state" in scope
|
||||
scope["state"]["a"] = 123
|
||||
@ -1485,9 +1432,7 @@ async def test_lifespan_state(
|
||||
async with websockets.client.connect(url) as websocket:
|
||||
return websocket.open
|
||||
|
||||
async def app_wrapper(
|
||||
scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
|
||||
):
|
||||
async def app_wrapper(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
if scope["type"] == "lifespan":
|
||||
return await lifespan_app(scope, receive, send)
|
||||
return await App(scope, receive, send)
|
||||
|
||||
@ -15,10 +15,7 @@ class Response:
|
||||
{
|
||||
"type": prefix + "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": [
|
||||
[key.encode(), value.encode()]
|
||||
for key, value in self.headers.items()
|
||||
],
|
||||
"headers": [[key.encode(), value.encode()] for key, value in self.headers.items()],
|
||||
}
|
||||
)
|
||||
await send({"type": prefix + "http.response.body", "body": self.body})
|
||||
|
||||
@ -1,19 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import signal
|
||||
import socket
|
||||
from typing import List, Optional
|
||||
|
||||
from uvicorn import Config
|
||||
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
|
||||
from uvicorn.supervisors import Multiprocess
|
||||
|
||||
|
||||
async def app(
|
||||
scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
|
||||
) -> None:
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
def run(sockets: Optional[List[socket.socket]]) -> None:
|
||||
def run(sockets: list[socket.socket] | None) -> None:
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import platform
|
||||
import signal
|
||||
@ -5,7 +7,6 @@ import socket
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import pytest
|
||||
|
||||
@ -41,7 +42,7 @@ class TestBaseReload:
|
||||
def setup(
|
||||
self,
|
||||
reload_directory_structure: Path,
|
||||
reloader_class: Optional[Type[BaseReload]],
|
||||
reloader_class: type[BaseReload] | None,
|
||||
):
|
||||
if reloader_class is None: # pragma: no cover
|
||||
pytest.skip("Needed dependency not installed")
|
||||
@ -61,9 +62,7 @@ class TestBaseReload:
|
||||
reloader.startup()
|
||||
return reloader
|
||||
|
||||
def _reload_tester(
|
||||
self, touch_soon, reloader: BaseReload, *files: Path
|
||||
) -> Optional[List[Path]]:
|
||||
def _reload_tester(self, touch_soon, reloader: BaseReload, *files: Path) -> list[Path] | None:
|
||||
reloader.restart()
|
||||
if WatchFilesReload is not None and isinstance(reloader, WatchFilesReload):
|
||||
touch_soon(*files)
|
||||
@ -74,9 +73,7 @@ class TestBaseReload:
|
||||
file.touch()
|
||||
return next(reloader)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reloader_class", [StatReload, WatchGodReload, WatchFilesReload]
|
||||
)
|
||||
@pytest.mark.parametrize("reloader_class", [StatReload, WatchGodReload, WatchFilesReload])
|
||||
def test_reloader_should_initialize(self) -> None:
|
||||
"""
|
||||
A basic sanity check.
|
||||
@ -89,9 +86,7 @@ class TestBaseReload:
|
||||
reloader = self._setup_reloader(config)
|
||||
reloader.shutdown()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reloader_class", [StatReload, WatchGodReload, WatchFilesReload]
|
||||
)
|
||||
@pytest.mark.parametrize("reloader_class", [StatReload, WatchGodReload, WatchFilesReload])
|
||||
def test_reload_when_python_file_is_changed(self, touch_soon) -> None:
|
||||
file = self.reload_path / "main.py"
|
||||
|
||||
@ -104,12 +99,8 @@ class TestBaseReload:
|
||||
|
||||
reloader.shutdown()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reloader_class", [StatReload, WatchGodReload, WatchFilesReload]
|
||||
)
|
||||
def test_should_reload_when_python_file_in_subdir_is_changed(
|
||||
self, touch_soon
|
||||
) -> None:
|
||||
@pytest.mark.parametrize("reloader_class", [StatReload, WatchGodReload, WatchFilesReload])
|
||||
def test_should_reload_when_python_file_in_subdir_is_changed(self, touch_soon) -> None:
|
||||
file = self.reload_path / "app" / "sub" / "sub.py"
|
||||
|
||||
with as_cwd(self.reload_path):
|
||||
@ -121,9 +112,7 @@ class TestBaseReload:
|
||||
reloader.shutdown()
|
||||
|
||||
@pytest.mark.parametrize("reloader_class", [WatchFilesReload, WatchGodReload])
|
||||
def test_should_not_reload_when_python_file_in_excluded_subdir_is_changed(
|
||||
self, touch_soon
|
||||
) -> None:
|
||||
def test_should_not_reload_when_python_file_in_excluded_subdir_is_changed(self, touch_soon) -> None:
|
||||
sub_dir = self.reload_path / "app" / "sub"
|
||||
sub_file = sub_dir / "sub.py"
|
||||
|
||||
@ -139,18 +128,12 @@ class TestBaseReload:
|
||||
|
||||
reloader.shutdown()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reloader_class, result", [(StatReload, False), (WatchFilesReload, True)]
|
||||
)
|
||||
def test_reload_when_pattern_matched_file_is_changed(
|
||||
self, result: bool, touch_soon
|
||||
) -> None:
|
||||
@pytest.mark.parametrize("reloader_class, result", [(StatReload, False), (WatchFilesReload, True)])
|
||||
def test_reload_when_pattern_matched_file_is_changed(self, result: bool, touch_soon) -> None:
|
||||
file = self.reload_path / "app" / "js" / "main.js"
|
||||
|
||||
with as_cwd(self.reload_path):
|
||||
config = Config(
|
||||
app="tests.test_config:asgi_app", reload=True, reload_includes=["*.js"]
|
||||
)
|
||||
config = Config(app="tests.test_config:asgi_app", reload=True, reload_includes=["*.js"])
|
||||
reloader = self._setup_reloader(config)
|
||||
|
||||
assert bool(self._reload_tester(touch_soon, reloader, file)) == result
|
||||
@ -164,9 +147,7 @@ class TestBaseReload:
|
||||
WatchGodReload,
|
||||
],
|
||||
)
|
||||
def test_should_not_reload_when_exclude_pattern_match_file_is_changed(
|
||||
self, touch_soon
|
||||
) -> None:
|
||||
def test_should_not_reload_when_exclude_pattern_match_file_is_changed(self, touch_soon) -> None:
|
||||
python_file = self.reload_path / "app" / "src" / "main.py"
|
||||
css_file = self.reload_path / "app" / "css" / "main.css"
|
||||
js_file = self.reload_path / "app" / "js" / "main.js"
|
||||
@ -186,9 +167,7 @@ class TestBaseReload:
|
||||
|
||||
reloader.shutdown()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reloader_class", [StatReload, WatchGodReload, WatchFilesReload]
|
||||
)
|
||||
@pytest.mark.parametrize("reloader_class", [StatReload, WatchGodReload, WatchFilesReload])
|
||||
def test_should_not_reload_when_dot_file_is_changed(self, touch_soon) -> None:
|
||||
file = self.reload_path / ".dotted"
|
||||
|
||||
@ -200,9 +179,7 @@ class TestBaseReload:
|
||||
|
||||
reloader.shutdown()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reloader_class", [StatReload, WatchGodReload, WatchFilesReload]
|
||||
)
|
||||
@pytest.mark.parametrize("reloader_class", [StatReload, WatchGodReload, WatchFilesReload])
|
||||
def test_should_reload_when_directories_have_same_prefix(self, touch_soon) -> None:
|
||||
app_dir = self.reload_path / "app"
|
||||
app_file = app_dir / "src" / "main.py"
|
||||
@ -230,9 +207,7 @@ class TestBaseReload:
|
||||
pytest.param(WatchFilesReload, marks=skip_if_m1),
|
||||
],
|
||||
)
|
||||
def test_should_not_reload_when_only_subdirectory_is_watched(
|
||||
self, touch_soon
|
||||
) -> None:
|
||||
def test_should_not_reload_when_only_subdirectory_is_watched(self, touch_soon) -> None:
|
||||
app_dir = self.reload_path / "app"
|
||||
app_dir_file = self.reload_path / "app" / "src" / "main.py"
|
||||
root_file = self.reload_path / "main.py"
|
||||
@ -245,9 +220,7 @@ class TestBaseReload:
|
||||
reloader = self._setup_reloader(config)
|
||||
|
||||
assert self._reload_tester(touch_soon, reloader, app_dir_file)
|
||||
assert not self._reload_tester(
|
||||
touch_soon, reloader, root_file, app_dir / "~ignored"
|
||||
)
|
||||
assert not self._reload_tester(touch_soon, reloader, root_file, app_dir / "~ignored")
|
||||
|
||||
reloader.shutdown()
|
||||
|
||||
@ -335,9 +308,7 @@ class TestBaseReload:
|
||||
reloader.shutdown()
|
||||
|
||||
@pytest.mark.parametrize("reloader_class", [WatchGodReload])
|
||||
def test_should_detect_new_reload_dirs(
|
||||
self, touch_soon, caplog: pytest.LogCaptureFixture, tmp_path: Path
|
||||
) -> None:
|
||||
def test_should_detect_new_reload_dirs(self, touch_soon, caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None:
|
||||
app_dir = tmp_path / "app"
|
||||
app_file = app_dir / "file.py"
|
||||
app_dir.mkdir()
|
||||
@ -346,9 +317,7 @@ class TestBaseReload:
|
||||
app_first_file = app_first_dir / "file.py"
|
||||
|
||||
with as_cwd(tmp_path):
|
||||
config = Config(
|
||||
app="tests.test_config:asgi_app", reload=True, reload_includes=["app*"]
|
||||
)
|
||||
config = Config(app="tests.test_config:asgi_app", reload=True, reload_includes=["app*"])
|
||||
reloader = self._setup_reloader(config)
|
||||
assert self._reload_tester(touch_soon, reloader, app_file)
|
||||
|
||||
|
||||
@ -29,9 +29,7 @@ async def test_sigint_finish_req(unused_tcp_port: int):
|
||||
await server_event.wait()
|
||||
await send({"type": "http.response.body", "body": b"end", "more_body": False})
|
||||
|
||||
config = Config(
|
||||
app=wait_app, reload=False, port=unused_tcp_port, timeout_graceful_shutdown=1
|
||||
)
|
||||
config = Config(app=wait_app, reload=False, port=unused_tcp_port, timeout_graceful_shutdown=1)
|
||||
server: Server
|
||||
async with run_server(config) as server:
|
||||
async with httpx.AsyncClient() as client:
|
||||
@ -64,9 +62,7 @@ async def test_sigint_abort_req(unused_tcp_port: int, caplog):
|
||||
await server_event.wait()
|
||||
await send({"type": "http.response.body", "body": b"end", "more_body": False})
|
||||
|
||||
config = Config(
|
||||
app=forever_app, reload=False, port=unused_tcp_port, timeout_graceful_shutdown=1
|
||||
)
|
||||
config = Config(app=forever_app, reload=False, port=unused_tcp_port, timeout_graceful_shutdown=1)
|
||||
server: Server
|
||||
async with run_server(config) as server:
|
||||
async with httpx.AsyncClient() as client:
|
||||
@ -78,10 +74,7 @@ async def test_sigint_abort_req(unused_tcp_port: int, caplog):
|
||||
await req
|
||||
|
||||
# req.result()
|
||||
assert (
|
||||
"Cancel 1 running task(s), timeout graceful shutdown exceeded"
|
||||
in caplog.messages
|
||||
)
|
||||
assert "Cancel 1 running task(s), timeout graceful shutdown exceeded" in caplog.messages
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@ -99,9 +92,7 @@ async def test_sigint_deny_request_after_triggered(unused_tcp_port: int, caplog)
|
||||
await send({"type": "http.response.start", "status": 200, "headers": []})
|
||||
await asyncio.sleep(1)
|
||||
|
||||
config = Config(
|
||||
app=app, reload=False, port=unused_tcp_port, timeout_graceful_shutdown=1
|
||||
)
|
||||
config = Config(app=app, reload=False, port=unused_tcp_port, timeout_graceful_shutdown=1)
|
||||
server: Server
|
||||
async with run_server(config) as server:
|
||||
# exit and ensure we do not accept more requests
|
||||
|
||||
@ -59,7 +59,5 @@ async def test_websocket_auto():
|
||||
server_state = ServerState()
|
||||
|
||||
assert AutoWebSocketsProtocol is not None
|
||||
protocol = AutoWebSocketsProtocol(
|
||||
config=config, server_state=server_state, app_state={}
|
||||
)
|
||||
protocol = AutoWebSocketsProtocol(config=config, server_state=server_state, app_state={})
|
||||
assert type(protocol).__name__ == expected_websockets
|
||||
|
||||
@ -41,12 +41,11 @@ def test_cli_print_version() -> None:
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert (
|
||||
"Running uvicorn %s with %s %s on %s"
|
||||
% (
|
||||
uvicorn.__version__,
|
||||
platform.python_implementation(),
|
||||
platform.python_version(),
|
||||
platform.system(),
|
||||
"Running uvicorn {version} with {py_implementation} {py_version} on {system}".format(
|
||||
version=uvicorn.__version__,
|
||||
py_implementation=platform.python_implementation(),
|
||||
py_version=platform.python_version(),
|
||||
system=platform.system(),
|
||||
)
|
||||
) in result.output
|
||||
|
||||
@ -103,9 +102,7 @@ def test_cli_call_multiprocess_run() -> None:
|
||||
|
||||
|
||||
@pytest.fixture(params=(True, False))
|
||||
def uds_file(
|
||||
tmp_path: Path, request: pytest.FixtureRequest
|
||||
) -> Path: # pragma: py-win32
|
||||
def uds_file(tmp_path: Path, request: pytest.FixtureRequest) -> Path: # pragma: py-win32
|
||||
file = tmp_path / "uvicorn.sock"
|
||||
should_create_file = request.param
|
||||
if should_create_file:
|
||||
@ -119,9 +116,7 @@ def test_cli_uds(uds_file: Path) -> None: # pragma: py-win32
|
||||
|
||||
with mock.patch.object(Config, "bind_socket") as mock_bind_socket:
|
||||
with mock.patch.object(Multiprocess, "run") as mock_run:
|
||||
result = runner.invoke(
|
||||
cli, ["tests.test_cli:App", "--workers=2", "--uds", str(uds_file)]
|
||||
)
|
||||
result = runner.invoke(cli, ["tests.test_cli:App", "--workers=2", "--uds", str(uds_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert result.output == ""
|
||||
@ -136,8 +131,7 @@ def test_cli_incomplete_app_parameter() -> None:
|
||||
result = runner.invoke(cli, ["tests.test_cli"])
|
||||
|
||||
assert (
|
||||
'Error loading ASGI app. Import string "tests.test_cli" '
|
||||
'must be in format "<module>:<attribute>".'
|
||||
'Error loading ASGI app. Import string "tests.test_cli" ' 'must be in format "<module>:<attribute>".'
|
||||
) in result.output
|
||||
assert result.exit_code == 1
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@ -5,7 +7,7 @@ import socket
|
||||
import sys
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
from typing import Any, Literal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -42,9 +44,7 @@ def yaml_logging_config(logging_config: dict) -> str:
|
||||
return yaml.dump(logging_config)
|
||||
|
||||
|
||||
async def asgi_app(
|
||||
scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
|
||||
) -> None:
|
||||
async def asgi_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
pass # pragma: nocover
|
||||
|
||||
|
||||
@ -56,65 +56,44 @@ def wsgi_app(environ: Environ, start_response: StartResponse) -> None:
|
||||
"app, expected_should_reload",
|
||||
[(asgi_app, False), ("tests.test_config:asgi_app", True)],
|
||||
)
|
||||
def test_config_should_reload_is_set(
|
||||
app: "ASGIApplication", expected_should_reload: bool
|
||||
) -> None:
|
||||
def test_config_should_reload_is_set(app: ASGIApplication, expected_should_reload: bool) -> None:
|
||||
config = Config(app=app, reload=True)
|
||||
assert config.reload is True
|
||||
assert config.should_reload is expected_should_reload
|
||||
|
||||
|
||||
def test_should_warn_on_invalid_reload_configuration(
|
||||
tmp_path: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
def test_should_warn_on_invalid_reload_configuration(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
|
||||
config_class = Config(app=asgi_app, reload_dirs=[str(tmp_path)])
|
||||
assert not config_class.should_reload
|
||||
assert len(caplog.records) == 1
|
||||
assert (
|
||||
caplog.records[-1].message
|
||||
== "Current configuration will not reload as not all conditions are met, "
|
||||
caplog.records[-1].message == "Current configuration will not reload as not all conditions are met, "
|
||||
"please refer to documentation."
|
||||
)
|
||||
|
||||
config_no_reload = Config(
|
||||
app="tests.test_config:asgi_app", reload_dirs=[str(tmp_path)]
|
||||
)
|
||||
config_no_reload = Config(app="tests.test_config:asgi_app", reload_dirs=[str(tmp_path)])
|
||||
assert not config_no_reload.should_reload
|
||||
assert len(caplog.records) == 2
|
||||
assert (
|
||||
caplog.records[-1].message
|
||||
== "Current configuration will not reload as not all conditions are met, "
|
||||
caplog.records[-1].message == "Current configuration will not reload as not all conditions are met, "
|
||||
"please refer to documentation."
|
||||
)
|
||||
|
||||
|
||||
def test_reload_dir_is_set(
|
||||
reload_directory_structure: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
def test_reload_dir_is_set(reload_directory_structure: Path, caplog: pytest.LogCaptureFixture) -> None:
|
||||
app_dir = reload_directory_structure / "app"
|
||||
with caplog.at_level(logging.INFO):
|
||||
config = Config(
|
||||
app="tests.test_config:asgi_app", reload=True, reload_dirs=[str(app_dir)]
|
||||
)
|
||||
config = Config(app="tests.test_config:asgi_app", reload=True, reload_dirs=[str(app_dir)])
|
||||
assert len(caplog.records) == 1
|
||||
assert (
|
||||
caplog.records[-1].message
|
||||
== f"Will watch for changes in these directories: {[str(app_dir)]}"
|
||||
)
|
||||
assert caplog.records[-1].message == f"Will watch for changes in these directories: {[str(app_dir)]}"
|
||||
assert config.reload_dirs == [app_dir]
|
||||
config = Config(
|
||||
app="tests.test_config:asgi_app", reload=True, reload_dirs=str(app_dir)
|
||||
)
|
||||
config = Config(app="tests.test_config:asgi_app", reload=True, reload_dirs=str(app_dir))
|
||||
assert config.reload_dirs == [app_dir]
|
||||
|
||||
|
||||
def test_non_existant_reload_dir_is_not_set(
|
||||
reload_directory_structure: Path, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
def test_non_existant_reload_dir_is_not_set(reload_directory_structure: Path, caplog: pytest.LogCaptureFixture) -> None:
|
||||
with as_cwd(reload_directory_structure), caplog.at_level(logging.WARNING):
|
||||
config = Config(
|
||||
app="tests.test_config:asgi_app", reload=True, reload_dirs=["reload"]
|
||||
)
|
||||
config = Config(app="tests.test_config:asgi_app", reload=True, reload_dirs=["reload"])
|
||||
assert config.reload_dirs == [reload_directory_structure]
|
||||
assert (
|
||||
caplog.records[-1].message
|
||||
@ -129,9 +108,7 @@ def test_reload_subdir_removal(reload_directory_structure: Path) -> None:
|
||||
reload_dirs = [str(reload_directory_structure), "app", str(app_dir)]
|
||||
|
||||
with as_cwd(reload_directory_structure):
|
||||
config = Config(
|
||||
app="tests.test_config:asgi_app", reload=True, reload_dirs=reload_dirs
|
||||
)
|
||||
config = Config(app="tests.test_config:asgi_app", reload=True, reload_dirs=reload_dirs)
|
||||
assert config.reload_dirs == [reload_directory_structure]
|
||||
|
||||
|
||||
@ -188,9 +165,7 @@ def test_reload_excluded_subdirectories_are_removed(
|
||||
)
|
||||
assert frozenset(config.reload_dirs) == frozenset([reload_directory_structure])
|
||||
assert frozenset(config.reload_dirs_excludes) == frozenset([app_dir])
|
||||
assert frozenset(config.reload_excludes) == frozenset(
|
||||
[str(app_dir), str(app_sub_dir)]
|
||||
)
|
||||
assert frozenset(config.reload_excludes) == frozenset([str(app_dir), str(app_sub_dir)])
|
||||
|
||||
|
||||
def test_reload_includes_exclude_dir_patterns_are_matched(
|
||||
@ -209,13 +184,10 @@ def test_reload_includes_exclude_dir_patterns_are_matched(
|
||||
)
|
||||
assert len(caplog.records) == 1
|
||||
assert (
|
||||
caplog.records[-1].message
|
||||
== "Will watch for changes in these directories: "
|
||||
caplog.records[-1].message == "Will watch for changes in these directories: "
|
||||
f"{sorted([str(first_app_dir), str(second_app_dir)])}"
|
||||
)
|
||||
assert frozenset(config.reload_dirs) == frozenset(
|
||||
[first_app_dir, second_app_dir]
|
||||
)
|
||||
assert frozenset(config.reload_dirs) == frozenset([first_app_dir, second_app_dir])
|
||||
assert config.reload_includes == ["*/src"]
|
||||
|
||||
|
||||
@ -247,9 +219,7 @@ def test_app_unimportable_other(caplog: pytest.LogCaptureFixture) -> None:
|
||||
with pytest.raises(SystemExit):
|
||||
config.load()
|
||||
error_messages = [
|
||||
record.message
|
||||
for record in caplog.records
|
||||
if record.name == "uvicorn.error" and record.levelname == "ERROR"
|
||||
record.message for record in caplog.records if record.name == "uvicorn.error" and record.levelname == "ERROR"
|
||||
]
|
||||
assert (
|
||||
'Error loading ASGI app. Attribute "app" not found in module "tests.test_config".' # noqa: E501
|
||||
@ -258,7 +228,7 @@ def test_app_unimportable_other(caplog: pytest.LogCaptureFixture) -> None:
|
||||
|
||||
|
||||
def test_app_factory(caplog: pytest.LogCaptureFixture) -> None:
|
||||
def create_app() -> "ASGIApplication":
|
||||
def create_app() -> ASGIApplication:
|
||||
return asgi_app
|
||||
|
||||
config = Config(app=create_app, factory=True, proxy_headers=False)
|
||||
@ -319,21 +289,15 @@ def test_ssl_config_combined(tls_certificate_key_and_chain_path: str) -> None:
|
||||
assert config.is_ssl is True
|
||||
|
||||
|
||||
def asgi2_app(scope: "Scope") -> typing.Callable:
|
||||
async def asgi(
|
||||
receive: "ASGIReceiveCallable", send: "ASGISendCallable"
|
||||
) -> None: # pragma: nocover
|
||||
def asgi2_app(scope: Scope) -> typing.Callable:
|
||||
async def asgi(receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: # pragma: nocover
|
||||
pass
|
||||
|
||||
return asgi # pragma: nocover
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"app, expected_interface", [(asgi_app, "3.0"), (asgi2_app, "2.0")]
|
||||
)
|
||||
def test_asgi_version(
|
||||
app: "ASGIApplication", expected_interface: Literal["2.0", "3.0"]
|
||||
) -> None:
|
||||
@pytest.mark.parametrize("app, expected_interface", [(asgi_app, "3.0"), (asgi2_app, "2.0")])
|
||||
def test_asgi_version(app: ASGIApplication, expected_interface: Literal["2.0", "3.0"]) -> None:
|
||||
config = Config(app=app)
|
||||
config.load()
|
||||
assert config.asgi_version == expected_interface
|
||||
@ -350,9 +314,9 @@ def test_asgi_version(
|
||||
)
|
||||
def test_log_config_default(
|
||||
mocked_logging_config_module: MagicMock,
|
||||
use_colors: typing.Optional[bool],
|
||||
expected: typing.Optional[bool],
|
||||
logging_config,
|
||||
use_colors: bool | None,
|
||||
expected: bool | None,
|
||||
logging_config: dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Test that one can specify the use_colors option when using the default logging
|
||||
@ -369,16 +333,14 @@ def test_log_config_default(
|
||||
|
||||
def test_log_config_json(
|
||||
mocked_logging_config_module: MagicMock,
|
||||
logging_config: dict,
|
||||
logging_config: dict[str, Any],
|
||||
json_logging_config: str,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that one can load a json config from disk.
|
||||
"""
|
||||
mocked_open = mocker.patch(
|
||||
"uvicorn.config.open", mocker.mock_open(read_data=json_logging_config)
|
||||
)
|
||||
mocked_open = mocker.patch("uvicorn.config.open", mocker.mock_open(read_data=json_logging_config))
|
||||
|
||||
config = Config(app=asgi_app, log_config="log_config.json")
|
||||
config.load()
|
||||
@ -390,7 +352,7 @@ def test_log_config_json(
|
||||
@pytest.mark.parametrize("config_filename", ["log_config.yml", "log_config.yaml"])
|
||||
def test_log_config_yaml(
|
||||
mocked_logging_config_module: MagicMock,
|
||||
logging_config: dict,
|
||||
logging_config: dict[str, Any],
|
||||
yaml_logging_config: str,
|
||||
mocker: MockerFixture,
|
||||
config_filename: str,
|
||||
@ -398,9 +360,7 @@ def test_log_config_yaml(
|
||||
"""
|
||||
Test that one can load a yaml config from disk.
|
||||
"""
|
||||
mocked_open = mocker.patch(
|
||||
"uvicorn.config.open", mocker.mock_open(read_data=yaml_logging_config)
|
||||
)
|
||||
mocked_open = mocker.patch("uvicorn.config.open", mocker.mock_open(read_data=yaml_logging_config))
|
||||
|
||||
config = Config(app=asgi_app, log_config=config_filename)
|
||||
config.load()
|
||||
@ -416,9 +376,7 @@ def test_log_config_file(mocked_logging_config_module: MagicMock) -> None:
|
||||
config = Config(app=asgi_app, log_config="log_config")
|
||||
config.load()
|
||||
|
||||
mocked_logging_config_module.fileConfig.assert_called_once_with(
|
||||
"log_config", disable_existing_loggers=False
|
||||
)
|
||||
mocked_logging_config_module.fileConfig.assert_called_once_with("log_config", disable_existing_loggers=False)
|
||||
|
||||
|
||||
@pytest.fixture(params=[0, 1])
|
||||
@ -445,10 +403,7 @@ def test_env_file(
|
||||
Test that one can load environment variables using an env file.
|
||||
"""
|
||||
fp = tmp_path / ".env"
|
||||
content = (
|
||||
f"WEB_CONCURRENCY={web_concurrency}\n"
|
||||
f"FORWARDED_ALLOW_IPS={forwarded_allow_ips}\n"
|
||||
)
|
||||
content = f"WEB_CONCURRENCY={web_concurrency}\n" f"FORWARDED_ALLOW_IPS={forwarded_allow_ips}\n"
|
||||
fp.write_text(content)
|
||||
with caplog.at_level(logging.INFO):
|
||||
config = Config(app=asgi_app, env_file=fp)
|
||||
@ -488,9 +443,7 @@ def test_config_log_level(log_level: int) -> None:
|
||||
|
||||
@pytest.mark.parametrize("log_level", [None, 0, 5, 10, 20, 30, 40, 50])
|
||||
@pytest.mark.parametrize("uvicorn_logger_level", [0, 5, 10, 20, 30, 40, 50])
|
||||
def test_config_log_effective_level(
|
||||
log_level: Optional[int], uvicorn_logger_level: Optional[int]
|
||||
) -> None:
|
||||
def test_config_log_effective_level(log_level: int, uvicorn_logger_level: int) -> None:
|
||||
default_level = 30
|
||||
log_config = {
|
||||
"version": 1,
|
||||
@ -530,7 +483,7 @@ def test_ws_max_queue() -> None:
|
||||
)
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="require unix-like system")
|
||||
def test_bind_unix_socket_works_with_reload_or_workers(
|
||||
tmp_path, reload, workers, short_socket_name
|
||||
tmp_path: Path, reload: bool, workers: int, short_socket_name: str
|
||||
): # pragma: py-win32
|
||||
config = Config(app=asgi_app, uds=short_socket_name, reload=reload, workers=workers)
|
||||
config.load()
|
||||
@ -550,7 +503,7 @@ def test_bind_unix_socket_works_with_reload_or_workers(
|
||||
ids=["--reload=True --workers=1", "--reload=False --workers=2"],
|
||||
)
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="require unix-like system")
|
||||
def test_bind_fd_works_with_reload_or_workers(reload, workers): # pragma: py-win32
|
||||
def test_bind_fd_works_with_reload_or_workers(reload: bool, workers: int): # pragma: py-win32
|
||||
fdsock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
fd = fdsock.fileno()
|
||||
config = Config(app=asgi_app, fd=fd, reload=reload, workers=workers)
|
||||
@ -576,7 +529,7 @@ def test_bind_fd_works_with_reload_or_workers(reload, workers): # pragma: py-wi
|
||||
"--reload=False --workers=1",
|
||||
],
|
||||
)
|
||||
def test_config_use_subprocess(reload, workers, expected):
|
||||
def test_config_use_subprocess(reload: bool, workers: int, expected: bool):
|
||||
config = Config(app=asgi_app, reload=reload, workers=workers)
|
||||
config.load()
|
||||
assert config.use_subprocess == expected
|
||||
@ -585,7 +538,4 @@ def test_config_use_subprocess(reload, workers, expected):
|
||||
def test_warn_when_using_reload_and_workers(caplog: pytest.LogCaptureFixture) -> None:
|
||||
Config(app=asgi_app, reload=True, workers=2)
|
||||
assert len(caplog.records) == 1
|
||||
assert (
|
||||
'"workers" flag is ignored when reloading is enabled.'
|
||||
in caplog.records[0].message
|
||||
)
|
||||
assert '"workers" flag is ignored when reloading is enabled.' in caplog.records[0].message
|
||||
|
||||
@ -32,9 +32,7 @@ async def test_override_server_header(unused_tcp_port: int):
|
||||
async with run_server(config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")
|
||||
assert (
|
||||
response.headers["server"] == "over-ridden" and response.headers["date"]
|
||||
)
|
||||
assert response.headers["server"] == "over-ridden" and response.headers["date"]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@ -64,10 +62,7 @@ async def test_override_server_header_multiple_times(unused_tcp_port: int):
|
||||
async with run_server(config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")
|
||||
assert (
|
||||
response.headers["server"] == "over-ridden, another-value"
|
||||
and response.headers["date"]
|
||||
)
|
||||
assert response.headers["server"] == "over-ridden, another-value" and response.headers["date"]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
@ -132,9 +132,7 @@ def test_lifespan_with_failed_startup(mode, raise_exception, caplog):
|
||||
async def app(scope, receive, send):
|
||||
message = await receive()
|
||||
assert message["type"] == "lifespan.startup"
|
||||
await send(
|
||||
{"type": "lifespan.startup.failed", "message": "the lifespan event failed"}
|
||||
)
|
||||
await send({"type": "lifespan.startup.failed", "message": "the lifespan event failed"})
|
||||
if raise_exception:
|
||||
# App should be able to re-raise an exception if startup failed.
|
||||
raise RuntimeError()
|
||||
@ -153,9 +151,7 @@ def test_lifespan_with_failed_startup(mode, raise_exception, caplog):
|
||||
loop.run_until_complete(test())
|
||||
loop.close()
|
||||
error_messages = [
|
||||
record.message
|
||||
for record in caplog.records
|
||||
if record.name == "uvicorn.error" and record.levelname == "ERROR"
|
||||
record.message for record in caplog.records if record.name == "uvicorn.error" and record.levelname == "ERROR"
|
||||
]
|
||||
assert "the lifespan event failed" in error_messages.pop(0)
|
||||
assert "Application startup failed. Exiting." in error_messages.pop(0)
|
||||
@ -218,9 +214,7 @@ def test_lifespan_with_failed_shutdown(mode, raise_exception, caplog):
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
message = await receive()
|
||||
assert message["type"] == "lifespan.shutdown"
|
||||
await send(
|
||||
{"type": "lifespan.shutdown.failed", "message": "the lifespan event failed"}
|
||||
)
|
||||
await send({"type": "lifespan.shutdown.failed", "message": "the lifespan event failed"})
|
||||
|
||||
if raise_exception:
|
||||
# App should be able to re-raise an exception if startup failed.
|
||||
@ -240,9 +234,7 @@ def test_lifespan_with_failed_shutdown(mode, raise_exception, caplog):
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(test())
|
||||
error_messages = [
|
||||
record.message
|
||||
for record in caplog.records
|
||||
if record.name == "uvicorn.error" and record.levelname == "ERROR"
|
||||
record.message for record in caplog.records if record.name == "uvicorn.error" and record.levelname == "ERROR"
|
||||
]
|
||||
assert "the lifespan event failed" in error_messages.pop(0)
|
||||
assert "Application shutdown failed. Exiting." in error_messages.pop(0)
|
||||
|
||||
@ -47,9 +47,7 @@ def _has_ipv6(host):
|
||||
],
|
||||
)
|
||||
async def test_run(host, url: str, unused_tcp_port: int):
|
||||
config = Config(
|
||||
app=app, host=host, loop="asyncio", limit_max_requests=1, port=unused_tcp_port
|
||||
)
|
||||
config = Config(app=app, host=host, loop="asyncio", limit_max_requests=1, port=unused_tcp_port)
|
||||
async with run_server(config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{url}:{unused_tcp_port}")
|
||||
@ -58,9 +56,7 @@ async def test_run(host, url: str, unused_tcp_port: int):
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_run_multiprocess(unused_tcp_port: int):
|
||||
config = Config(
|
||||
app=app, loop="asyncio", workers=2, limit_max_requests=1, port=unused_tcp_port
|
||||
)
|
||||
config = Config(app=app, loop="asyncio", workers=2, limit_max_requests=1, port=unused_tcp_port)
|
||||
async with run_server(config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")
|
||||
@ -69,9 +65,7 @@ async def test_run_multiprocess(unused_tcp_port: int):
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_run_reload(unused_tcp_port: int):
|
||||
config = Config(
|
||||
app=app, loop="asyncio", reload=True, limit_max_requests=1, port=unused_tcp_port
|
||||
)
|
||||
config = Config(app=app, loop="asyncio", reload=True, limit_max_requests=1, port=unused_tcp_port)
|
||||
async with run_server(config):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")
|
||||
@ -85,8 +79,7 @@ def test_run_invalid_app_config_combination(caplog: pytest.LogCaptureFixture) ->
|
||||
assert caplog.records[-1].name == "uvicorn.error"
|
||||
assert caplog.records[-1].levelno == WARNING
|
||||
assert caplog.records[-1].message == (
|
||||
"You must pass the application as an import string to enable "
|
||||
"'reload' or 'workers'."
|
||||
"You must pass the application as an import string to enable " "'reload' or 'workers'."
|
||||
)
|
||||
|
||||
|
||||
@ -109,9 +102,7 @@ def test_run_match_config_params() -> None:
|
||||
if key not in ("self", "timeout_notify", "callback_notify")
|
||||
}
|
||||
run_params = {
|
||||
key: repr(value)
|
||||
for key, value in inspect.signature(run).parameters.items()
|
||||
if key not in ("app_dir",)
|
||||
key: repr(value) for key, value in inspect.signature(run).parameters.items() if key not in ("app_dir",)
|
||||
}
|
||||
assert config_params == run_params
|
||||
|
||||
|
||||
@ -56,9 +56,7 @@ async def test_run_chain(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_run_chain_only(
|
||||
tls_ca_ssl_context, tls_certificate_key_and_chain_path, unused_tcp_port: int
|
||||
):
|
||||
async def test_run_chain_only(tls_ca_ssl_context, tls_certificate_key_and_chain_path, unused_tcp_port: int):
|
||||
config = Config(
|
||||
app=app,
|
||||
loop="asyncio",
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
from typing import List
|
||||
from unittest.mock import patch
|
||||
|
||||
from uvicorn._subprocess import SpawnProcess, get_subprocess, subprocess_started
|
||||
@ -7,13 +8,11 @@ from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
|
||||
from uvicorn.config import Config
|
||||
|
||||
|
||||
def server_run(sockets: List[socket.socket]): # pragma: no cover
|
||||
def server_run(sockets: list[socket.socket]): # pragma: no cover
|
||||
...
|
||||
|
||||
|
||||
async def app(
|
||||
scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
|
||||
) -> None: # pragma: no cover
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: # pragma: no cover
|
||||
...
|
||||
|
||||
|
||||
|
||||
@ -36,18 +36,16 @@ from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
Iterable,
|
||||
Literal,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Protocol,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 8): # pragma: py-lt-38
|
||||
from typing import Literal, Protocol, TypedDict
|
||||
else: # pragma: py-gte-38
|
||||
from typing_extensions import Literal, Protocol, TypedDict
|
||||
|
||||
if sys.version_info >= (3, 11): # pragma: py-lt-311
|
||||
from typing import NotRequired
|
||||
else: # pragma: py-gte-311
|
||||
@ -239,9 +237,7 @@ class LifespanShutdownFailedEvent(TypedDict):
|
||||
message: str
|
||||
|
||||
|
||||
WebSocketEvent = Union[
|
||||
WebSocketReceiveEvent, WebSocketDisconnectEvent, WebSocketConnectEvent
|
||||
]
|
||||
WebSocketEvent = Union[WebSocketReceiveEvent, WebSocketDisconnectEvent, WebSocketConnectEvent]
|
||||
|
||||
|
||||
ASGIReceiveEvent = Union[
|
||||
@ -281,9 +277,7 @@ class ASGI2Protocol(Protocol):
|
||||
def __init__(self, scope: Scope) -> None:
|
||||
... # pragma: no cover
|
||||
|
||||
async def __call__(
|
||||
self, receive: ASGIReceiveCallable, send: ASGISendCallable
|
||||
) -> None:
|
||||
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
|
||||
@ -127,9 +127,7 @@ def is_dir(path: Path) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def resolve_reload_patterns(
|
||||
patterns_list: list[str], directories_list: list[str]
|
||||
) -> tuple[list[str], list[Path]]:
|
||||
def resolve_reload_patterns(patterns_list: list[str], directories_list: list[str]) -> tuple[list[str], list[Path]]:
|
||||
directories: list[Path] = list(set(map(Path, directories_list.copy())))
|
||||
patterns: list[str] = patterns_list.copy()
|
||||
|
||||
@ -150,9 +148,7 @@ def resolve_reload_patterns(
|
||||
directories = list(set(directories))
|
||||
directories = list(map(Path, directories))
|
||||
directories = list(map(lambda x: x.resolve(), directories))
|
||||
directories = list(
|
||||
{reload_path for reload_path in directories if is_dir(reload_path)}
|
||||
)
|
||||
directories = list({reload_path for reload_path in directories if is_dir(reload_path)})
|
||||
|
||||
children = []
|
||||
for j in range(len(directories)):
|
||||
@ -280,12 +276,9 @@ class Config:
|
||||
self.reload_includes: list[str] = []
|
||||
self.reload_excludes: list[str] = []
|
||||
|
||||
if (
|
||||
reload_dirs or reload_includes or reload_excludes
|
||||
) and not self.should_reload:
|
||||
if (reload_dirs or reload_includes or reload_excludes) and not self.should_reload:
|
||||
logger.warning(
|
||||
"Current configuration will not reload as not all conditions are met, "
|
||||
"please refer to documentation."
|
||||
"Current configuration will not reload as not all conditions are met, " "please refer to documentation."
|
||||
)
|
||||
|
||||
if self.should_reload:
|
||||
@ -293,22 +286,15 @@ class Config:
|
||||
reload_includes = _normalize_dirs(reload_includes)
|
||||
reload_excludes = _normalize_dirs(reload_excludes)
|
||||
|
||||
self.reload_includes, self.reload_dirs = resolve_reload_patterns(
|
||||
reload_includes, reload_dirs
|
||||
)
|
||||
self.reload_includes, self.reload_dirs = resolve_reload_patterns(reload_includes, reload_dirs)
|
||||
|
||||
self.reload_excludes, self.reload_dirs_excludes = resolve_reload_patterns(
|
||||
reload_excludes, []
|
||||
)
|
||||
self.reload_excludes, self.reload_dirs_excludes = resolve_reload_patterns(reload_excludes, [])
|
||||
|
||||
reload_dirs_tmp = self.reload_dirs.copy()
|
||||
|
||||
for directory in self.reload_dirs_excludes:
|
||||
for reload_directory in reload_dirs_tmp:
|
||||
if (
|
||||
directory == reload_directory
|
||||
or directory in reload_directory.parents
|
||||
):
|
||||
if directory == reload_directory or directory in reload_directory.parents:
|
||||
try:
|
||||
self.reload_dirs.remove(reload_directory)
|
||||
except ValueError:
|
||||
@ -343,9 +329,7 @@ class Config:
|
||||
|
||||
self.forwarded_allow_ips: list[str] | str
|
||||
if forwarded_allow_ips is None:
|
||||
self.forwarded_allow_ips = os.environ.get(
|
||||
"FORWARDED_ALLOW_IPS", "127.0.0.1"
|
||||
)
|
||||
self.forwarded_allow_ips = os.environ.get("FORWARDED_ALLOW_IPS", "127.0.0.1")
|
||||
else:
|
||||
self.forwarded_allow_ips = forwarded_allow_ips
|
||||
|
||||
@ -375,12 +359,8 @@ class Config:
|
||||
if self.log_config is not None:
|
||||
if isinstance(self.log_config, dict):
|
||||
if self.use_colors in (True, False):
|
||||
self.log_config["formatters"]["default"][
|
||||
"use_colors"
|
||||
] = self.use_colors
|
||||
self.log_config["formatters"]["access"][
|
||||
"use_colors"
|
||||
] = self.use_colors
|
||||
self.log_config["formatters"]["default"]["use_colors"] = self.use_colors
|
||||
self.log_config["formatters"]["access"]["use_colors"] = self.use_colors
|
||||
logging.config.dictConfig(self.log_config)
|
||||
elif self.log_config.endswith(".json"):
|
||||
with open(self.log_config) as file:
|
||||
@ -397,9 +377,7 @@ class Config:
|
||||
else:
|
||||
# See the note about fileConfig() here:
|
||||
# https://docs.python.org/3/library/logging.config.html#configuration-file-format
|
||||
logging.config.fileConfig(
|
||||
self.log_config, disable_existing_loggers=False
|
||||
)
|
||||
logging.config.fileConfig(self.log_config, disable_existing_loggers=False)
|
||||
|
||||
if self.log_level is not None:
|
||||
if isinstance(self.log_level, str):
|
||||
@ -430,10 +408,7 @@ class Config:
|
||||
else:
|
||||
self.ssl = None
|
||||
|
||||
encoded_headers = [
|
||||
(key.lower().encode("latin1"), value.encode("latin1"))
|
||||
for key, value in self.headers
|
||||
]
|
||||
encoded_headers = [(key.lower().encode("latin1"), value.encode("latin1")) for key, value in self.headers]
|
||||
self.encoded_headers = (
|
||||
[(b"server", b"uvicorn")] + encoded_headers
|
||||
if b"server" not in dict(encoded_headers) and self.server_header
|
||||
@ -469,8 +444,7 @@ class Config:
|
||||
else:
|
||||
if not self.factory:
|
||||
logger.warning(
|
||||
"ASGI app factory detected. Using it, "
|
||||
"but please consider setting the --factory flag explicitly."
|
||||
"ASGI app factory detected. Using it, " "but please consider setting the --factory flag explicitly."
|
||||
)
|
||||
|
||||
if self.interface == "auto":
|
||||
@ -492,9 +466,7 @@ class Config:
|
||||
if logger.getEffectiveLevel() <= TRACE_LOG_LEVEL:
|
||||
self.loaded_app = MessageLoggerMiddleware(self.loaded_app)
|
||||
if self.proxy_headers:
|
||||
self.loaded_app = ProxyHeadersMiddleware(
|
||||
self.loaded_app, trusted_hosts=self.forwarded_allow_ips
|
||||
)
|
||||
self.loaded_app = ProxyHeadersMiddleware(self.loaded_app, trusted_hosts=self.forwarded_allow_ips)
|
||||
|
||||
self.loaded = True
|
||||
|
||||
@ -518,21 +490,13 @@ class Config:
|
||||
|
||||
message = "Uvicorn running on unix socket %s (Press CTRL+C to quit)"
|
||||
sock_name_format = "%s"
|
||||
color_message = (
|
||||
"Uvicorn running on "
|
||||
+ click.style(sock_name_format, bold=True)
|
||||
+ " (Press CTRL+C to quit)"
|
||||
)
|
||||
color_message = "Uvicorn running on " + click.style(sock_name_format, bold=True) + " (Press CTRL+C to quit)"
|
||||
logger_args = [self.uds]
|
||||
elif self.fd: # pragma: py-win32
|
||||
sock = socket.fromfd(self.fd, socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
message = "Uvicorn running on socket %s (Press CTRL+C to quit)"
|
||||
fd_name_format = "%s"
|
||||
color_message = (
|
||||
"Uvicorn running on "
|
||||
+ click.style(fd_name_format, bold=True)
|
||||
+ " (Press CTRL+C to quit)"
|
||||
)
|
||||
color_message = "Uvicorn running on " + click.style(fd_name_format, bold=True) + " (Press CTRL+C to quit)"
|
||||
logger_args = [sock.getsockname()]
|
||||
else:
|
||||
family = socket.AF_INET
|
||||
@ -552,11 +516,7 @@ class Config:
|
||||
sys.exit(1)
|
||||
|
||||
message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
|
||||
color_message = (
|
||||
"Uvicorn running on "
|
||||
+ click.style(addr_format, bold=True)
|
||||
+ " (Press CTRL+C to quit)"
|
||||
)
|
||||
color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
|
||||
protocol_name = "https" if self.is_ssl else "http"
|
||||
logger_args = [protocol_name, self.host, sock.getsockname()[1]]
|
||||
logger.info(message, *logger_args, extra={"color_message": color_message})
|
||||
|
||||
@ -12,9 +12,7 @@ def import_from_string(import_str: Any) -> Any:
|
||||
|
||||
module_str, _, attrs_str = import_str.partition(":")
|
||||
if not module_str or not attrs_str:
|
||||
message = (
|
||||
'Import string "{import_str}" must be in format "<module>:<attribute>".'
|
||||
)
|
||||
message = 'Import string "{import_str}" must be in format "<module>:<attribute>".'
|
||||
raise ImportFromStringError(message.format(import_str=import_str))
|
||||
|
||||
try:
|
||||
@ -31,8 +29,6 @@ def import_from_string(import_str: Any) -> Any:
|
||||
instance = getattr(instance, attr_str)
|
||||
except AttributeError:
|
||||
message = 'Attribute "{attrs_str}" not found in module "{module_str}".'
|
||||
raise ImportFromStringError(
|
||||
message.format(attrs_str=attrs_str, module_str=module_str)
|
||||
)
|
||||
raise ImportFromStringError(message.format(attrs_str=attrs_str, module_str=module_str))
|
||||
|
||||
return instance
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
from typing import Any, Dict
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from uvicorn import Config
|
||||
|
||||
@ -6,7 +8,7 @@ from uvicorn import Config
|
||||
class LifespanOff:
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.should_exit = False
|
||||
self.state: Dict[str, Any] = {}
|
||||
self.state: dict[str, Any] = {}
|
||||
|
||||
async def startup(self) -> None:
|
||||
pass
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from asyncio import Queue
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from uvicorn import Config
|
||||
from uvicorn._types import (
|
||||
@ -35,12 +37,12 @@ class LifespanOn:
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.startup_event = asyncio.Event()
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self.receive_queue: "Queue[LifespanReceiveMessage]" = asyncio.Queue()
|
||||
self.receive_queue: Queue[LifespanReceiveMessage] = asyncio.Queue()
|
||||
self.error_occured = False
|
||||
self.startup_failed = False
|
||||
self.shutdown_failed = False
|
||||
self.should_exit = False
|
||||
self.state: Dict[str, Any] = {}
|
||||
self.state: dict[str, Any] = {}
|
||||
|
||||
async def startup(self) -> None:
|
||||
self.logger.info("Waiting for application startup.")
|
||||
@ -67,9 +69,7 @@ class LifespanOn:
|
||||
await self.receive_queue.put(shutdown_event)
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
if self.shutdown_failed or (
|
||||
self.error_occured and self.config.lifespan == "on"
|
||||
):
|
||||
if self.shutdown_failed or (self.error_occured and self.config.lifespan == "on"):
|
||||
self.logger.error("Application shutdown failed. Exiting.")
|
||||
self.should_exit = True
|
||||
else:
|
||||
@ -99,7 +99,7 @@ class LifespanOn:
|
||||
self.startup_event.set()
|
||||
self.shutdown_event.set()
|
||||
|
||||
async def send(self, message: "LifespanSendMessage") -> None:
|
||||
async def send(self, message: LifespanSendMessage) -> None:
|
||||
assert message["type"] in (
|
||||
"lifespan.startup.complete",
|
||||
"lifespan.startup.failed",
|
||||
@ -133,5 +133,5 @@ class LifespanOn:
|
||||
if message.get("message"):
|
||||
self.logger.error(message["message"])
|
||||
|
||||
async def receive(self) -> "LifespanReceiveMessage":
|
||||
async def receive(self) -> LifespanReceiveMessage:
|
||||
return await self.receive_queue.get()
|
||||
|
||||
@ -26,9 +26,7 @@ class ColourizedFormatter(logging.Formatter):
|
||||
logging.INFO: lambda level_name: click.style(str(level_name), fg="green"),
|
||||
logging.WARNING: lambda level_name: click.style(str(level_name), fg="yellow"),
|
||||
logging.ERROR: lambda level_name: click.style(str(level_name), fg="red"),
|
||||
logging.CRITICAL: lambda level_name: click.style(
|
||||
str(level_name), fg="bright_red"
|
||||
),
|
||||
logging.CRITICAL: lambda level_name: click.style(str(level_name), fg="bright_red"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@ -86,7 +84,7 @@ class AccessFormatter(ColourizedFormatter):
|
||||
status_phrase = http.HTTPStatus(status_code).phrase
|
||||
except ValueError:
|
||||
status_phrase = ""
|
||||
status_and_phrase = "%s %s" % (status_code, status_phrase)
|
||||
status_and_phrase = f"{status_code} {status_phrase}"
|
||||
if self.use_colors:
|
||||
|
||||
def default(code: int) -> str:
|
||||
@ -106,7 +104,7 @@ class AccessFormatter(ColourizedFormatter):
|
||||
status_code,
|
||||
) = recordcopy.args # type: ignore[misc]
|
||||
status_code = self.get_status_code(int(status_code)) # type: ignore[arg-type]
|
||||
request_line = "%s %s HTTP/%s" % (method, full_path, http_version)
|
||||
request_line = f"{method} {full_path} HTTP/{http_version}"
|
||||
if self.use_colors:
|
||||
request_line = click.style(request_line, bold=True)
|
||||
recordcopy.__dict__.update(
|
||||
|
||||
@ -47,12 +47,11 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
if not value or ctx.resilient_parsing:
|
||||
return
|
||||
click.echo(
|
||||
"Running uvicorn %s with %s %s on %s"
|
||||
% (
|
||||
uvicorn.__version__,
|
||||
platform.python_implementation(),
|
||||
platform.python_version(),
|
||||
platform.system(),
|
||||
"Running uvicorn {version} with {py_implementation} {py_version} on {system}".format(
|
||||
version=uvicorn.__version__,
|
||||
py_implementation=platform.python_implementation(),
|
||||
py_version=platform.python_version(),
|
||||
system=platform.system(),
|
||||
)
|
||||
)
|
||||
ctx.exit()
|
||||
@ -75,16 +74,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
show_default=True,
|
||||
)
|
||||
@click.option("--uds", type=str, default=None, help="Bind to a UNIX domain socket.")
|
||||
@click.option(
|
||||
"--fd", type=int, default=None, help="Bind to socket from this file descriptor."
|
||||
)
|
||||
@click.option("--fd", type=int, default=None, help="Bind to socket from this file descriptor.")
|
||||
@click.option("--reload", is_flag=True, default=False, help="Enable auto-reload.")
|
||||
@click.option(
|
||||
"--reload-dir",
|
||||
"reload_dirs",
|
||||
multiple=True,
|
||||
help="Set reload directories explicitly, instead of using the current working"
|
||||
" directory.",
|
||||
help="Set reload directories explicitly, instead of using the current working" " directory.",
|
||||
type=click.Path(exists=True),
|
||||
)
|
||||
@click.option(
|
||||
@ -109,8 +105,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
type=float,
|
||||
default=0.25,
|
||||
show_default=True,
|
||||
help="Delay between previous and next check if application needs to be."
|
||||
" Defaults to 0.25s.",
|
||||
help="Delay between previous and next check if application needs to be." " Defaults to 0.25s.",
|
||||
)
|
||||
@click.option(
|
||||
"--workers",
|
||||
@ -226,8 +221,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
"--proxy-headers/--no-proxy-headers",
|
||||
is_flag=True,
|
||||
default=True,
|
||||
help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Port to "
|
||||
"populate remote address info.",
|
||||
help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Port to " "populate remote address info.",
|
||||
)
|
||||
@click.option(
|
||||
"--server-header/--no-server-header",
|
||||
@ -258,8 +252,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
"--limit-concurrency",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of concurrent connections or tasks to allow, before issuing"
|
||||
" HTTP 503 responses.",
|
||||
help="Maximum number of concurrent connections or tasks to allow, before issuing" " HTTP 503 responses.",
|
||||
)
|
||||
@click.option(
|
||||
"--backlog",
|
||||
@ -286,9 +279,7 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
|
||||
default=None,
|
||||
help="Maximum number of seconds to wait for graceful shutdown.",
|
||||
)
|
||||
@click.option(
|
||||
"--ssl-keyfile", type=str, default=None, help="SSL key file", show_default=True
|
||||
)
|
||||
@click.option("--ssl-keyfile", type=str, default=None, help="SSL key file", show_default=True)
|
||||
@click.option(
|
||||
"--ssl-certfile",
|
||||
type=str,
|
||||
@ -571,10 +562,7 @@ def run(
|
||||
|
||||
if (config.reload or config.workers > 1) and not isinstance(app, str):
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
logger.warning(
|
||||
"You must pass the application as an import string to enable 'reload' or "
|
||||
"'workers'."
|
||||
)
|
||||
logger.warning("You must pass the application as an import string to enable 'reload' or " "'workers'.")
|
||||
sys.exit(1)
|
||||
|
||||
if config.should_reload:
|
||||
|
||||
@ -10,8 +10,6 @@ class ASGI2Middleware:
|
||||
def __init__(self, app: "ASGI2Application"):
|
||||
self.app = app
|
||||
|
||||
async def __call__(
|
||||
self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
|
||||
) -> None:
|
||||
async def __call__(self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable") -> None:
|
||||
instance = self.app(scope)
|
||||
await instance(receive, send)
|
||||
|
||||
@ -8,23 +8,18 @@ the connecting client, rather that the connecting proxy.
|
||||
|
||||
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies
|
||||
"""
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
from __future__ import annotations
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveCallable,
|
||||
ASGISendCallable,
|
||||
HTTPScope,
|
||||
Scope,
|
||||
WebSocketScope,
|
||||
)
|
||||
from typing import Union, cast
|
||||
|
||||
from uvicorn._types import ASGI3Application, ASGIReceiveCallable, ASGISendCallable, HTTPScope, Scope, WebSocketScope
|
||||
|
||||
|
||||
class ProxyHeadersMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: "ASGI3Application",
|
||||
trusted_hosts: Union[List[str], str] = "127.0.0.1",
|
||||
app: ASGI3Application,
|
||||
trusted_hosts: list[str] | str = "127.0.0.1",
|
||||
) -> None:
|
||||
self.app = app
|
||||
if isinstance(trusted_hosts, str):
|
||||
@ -33,9 +28,7 @@ class ProxyHeadersMiddleware:
|
||||
self.trusted_hosts = set(trusted_hosts)
|
||||
self.always_trust = "*" in self.trusted_hosts
|
||||
|
||||
def get_trusted_client_host(
|
||||
self, x_forwarded_for_hosts: List[str]
|
||||
) -> Optional[str]:
|
||||
def get_trusted_client_host(self, x_forwarded_for_hosts: list[str]) -> str | None:
|
||||
if self.always_trust:
|
||||
return x_forwarded_for_hosts[0]
|
||||
|
||||
@ -45,12 +38,10 @@ class ProxyHeadersMiddleware:
|
||||
|
||||
return None
|
||||
|
||||
async def __call__(
|
||||
self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
|
||||
) -> None:
|
||||
async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
if scope["type"] in ("http", "websocket"):
|
||||
scope = cast(Union["HTTPScope", "WebSocketScope"], scope)
|
||||
client_addr: Optional[Tuple[str, int]] = scope.get("client")
|
||||
client_addr: tuple[str, int] | None = scope.get("client")
|
||||
client_host = client_addr[0] if client_addr else None
|
||||
|
||||
if self.always_trust or client_host in self.trusted_hosts:
|
||||
@ -59,9 +50,7 @@ class ProxyHeadersMiddleware:
|
||||
if b"x-forwarded-proto" in headers:
|
||||
# Determine if the incoming request was http or https based on
|
||||
# the X-Forwarded-Proto header.
|
||||
x_forwarded_proto = (
|
||||
headers[b"x-forwarded-proto"].decode("latin1").strip()
|
||||
)
|
||||
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1").strip()
|
||||
if scope["type"] == "websocket":
|
||||
scope["scheme"] = x_forwarded_proto.replace("http", "ws")
|
||||
else:
|
||||
@ -72,9 +61,7 @@ class ProxyHeadersMiddleware:
|
||||
# X-Forwarded-For header. We've lost the connecting client's port
|
||||
# information by now, so only include the host.
|
||||
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
|
||||
x_forwarded_for_hosts = [
|
||||
item.strip() for item in x_forwarded_for.split(",")
|
||||
]
|
||||
x_forwarded_for_hosts = [item.strip() for item in x_forwarded_for.split(",")]
|
||||
host = self.get_trusted_client_host(x_forwarded_for_hosts)
|
||||
port = 0
|
||||
scope["client"] = (host, port) # type: ignore[arg-type]
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import io
|
||||
import sys
|
||||
import warnings
|
||||
from collections import deque
|
||||
from typing import Deque, Iterable, Optional, Tuple
|
||||
from typing import Iterable
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGIReceiveCallable,
|
||||
@ -22,9 +24,7 @@ from uvicorn._types import (
|
||||
)
|
||||
|
||||
|
||||
def build_environ(
|
||||
scope: "HTTPScope", message: "ASGIReceiveEvent", body: io.BytesIO
|
||||
) -> Environ:
|
||||
def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: io.BytesIO) -> Environ:
|
||||
"""
|
||||
Builds a scope and request message into a WSGI environ object.
|
||||
"""
|
||||
@ -91,9 +91,9 @@ class _WSGIMiddleware:
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
scope: "HTTPScope",
|
||||
receive: "ASGIReceiveCallable",
|
||||
send: "ASGISendCallable",
|
||||
scope: HTTPScope,
|
||||
receive: ASGIReceiveCallable,
|
||||
send: ASGISendCallable,
|
||||
) -> None:
|
||||
assert scope["type"] == "http"
|
||||
instance = WSGIResponder(self.app, self.executor, scope)
|
||||
@ -105,7 +105,7 @@ class WSGIResponder:
|
||||
self,
|
||||
app: WSGIApp,
|
||||
executor: concurrent.futures.ThreadPoolExecutor,
|
||||
scope: "HTTPScope",
|
||||
scope: HTTPScope,
|
||||
):
|
||||
self.app = app
|
||||
self.executor = executor
|
||||
@ -113,21 +113,19 @@ class WSGIResponder:
|
||||
self.status = None
|
||||
self.response_headers = None
|
||||
self.send_event = asyncio.Event()
|
||||
self.send_queue: Deque[Optional["ASGISendEvent"]] = deque()
|
||||
self.send_queue: deque[ASGISendEvent | None] = deque()
|
||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
self.response_started = False
|
||||
self.exc_info: Optional[ExcInfo] = None
|
||||
self.exc_info: ExcInfo | None = None
|
||||
|
||||
async def __call__(
|
||||
self, receive: "ASGIReceiveCallable", send: "ASGISendCallable"
|
||||
) -> None:
|
||||
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
message: HTTPRequestEvent = await receive() # type: ignore[assignment]
|
||||
body = io.BytesIO(message.get("body", b""))
|
||||
more_body = message.get("more_body", False)
|
||||
if more_body:
|
||||
body.seek(0, io.SEEK_END)
|
||||
while more_body:
|
||||
body_message: "HTTPRequestEvent" = (
|
||||
body_message: HTTPRequestEvent = (
|
||||
await receive() # type: ignore[assignment]
|
||||
)
|
||||
body.write(body_message.get("body", b""))
|
||||
@ -135,9 +133,7 @@ class WSGIResponder:
|
||||
body.seek(0)
|
||||
environ = build_environ(self.scope, message, body)
|
||||
self.loop = asyncio.get_event_loop()
|
||||
wsgi = self.loop.run_in_executor(
|
||||
self.executor, self.wsgi, environ, self.start_response
|
||||
)
|
||||
wsgi = self.loop.run_in_executor(self.executor, self.wsgi, environ, self.start_response)
|
||||
sender = self.loop.create_task(self.sender(send))
|
||||
try:
|
||||
await asyncio.wait_for(wsgi, None)
|
||||
@ -148,7 +144,7 @@ class WSGIResponder:
|
||||
if self.exc_info is not None:
|
||||
raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
|
||||
|
||||
async def sender(self, send: "ASGISendCallable") -> None:
|
||||
async def sender(self, send: ASGISendCallable) -> None:
|
||||
while True:
|
||||
if self.send_queue:
|
||||
message = self.send_queue.popleft()
|
||||
@ -162,18 +158,15 @@ class WSGIResponder:
|
||||
def start_response(
|
||||
self,
|
||||
status: str,
|
||||
response_headers: Iterable[Tuple[str, str]],
|
||||
exc_info: Optional[ExcInfo] = None,
|
||||
response_headers: Iterable[tuple[str, str]],
|
||||
exc_info: ExcInfo | None = None,
|
||||
) -> None:
|
||||
self.exc_info = exc_info
|
||||
if not self.response_started:
|
||||
self.response_started = True
|
||||
status_code_str, _ = status.split(" ", 1)
|
||||
status_code = int(status_code_str)
|
||||
headers = [
|
||||
(name.encode("ascii"), value.encode("ascii"))
|
||||
for name, value in response_headers
|
||||
]
|
||||
headers = [(name.encode("ascii"), value.encode("ascii")) for name, value in response_headers]
|
||||
http_response_start_event: HTTPResponseStartEvent = {
|
||||
"type": "http.response.start",
|
||||
"status": status_code,
|
||||
|
||||
@ -45,9 +45,7 @@ class FlowControl:
|
||||
self._is_writable_event.set()
|
||||
|
||||
|
||||
async def service_unavailable(
|
||||
scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
|
||||
) -> None:
|
||||
async def service_unavailable(scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable") -> None:
|
||||
response_start: "HTTPResponseStartEvent" = {
|
||||
"type": "http.response.start",
|
||||
"status": 503,
|
||||
|
||||
@ -44,9 +44,7 @@ def _get_status_phrase(status_code: int) -> bytes:
|
||||
return b""
|
||||
|
||||
|
||||
STATUS_PHRASES = {
|
||||
status_code: _get_status_phrase(status_code) for status_code in range(100, 600)
|
||||
}
|
||||
STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)}
|
||||
|
||||
|
||||
class H11Protocol(asyncio.Protocol):
|
||||
@ -228,8 +226,7 @@ class H11Protocol(asyncio.Protocol):
|
||||
|
||||
# Handle 503 responses when 'limit_concurrency' is exceeded.
|
||||
if self.limit_concurrency is not None and (
|
||||
len(self.connections) >= self.limit_concurrency
|
||||
or len(self.tasks) >= self.limit_concurrency
|
||||
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
|
||||
):
|
||||
app = service_unavailable
|
||||
message = "Exceeded concurrency limit."
|
||||
@ -323,9 +320,7 @@ class H11Protocol(asyncio.Protocol):
|
||||
# Set a short Keep-Alive timeout.
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.timeout_keep_alive_task = self.loop.call_later(
|
||||
self.timeout_keep_alive, self.timeout_keep_alive_handler
|
||||
)
|
||||
self.timeout_keep_alive_task = self.loop.call_later(self.timeout_keep_alive, self.timeout_keep_alive_handler)
|
||||
|
||||
# Unpause data reads if needed.
|
||||
self.flow.resume_reading()
|
||||
@ -372,7 +367,7 @@ class H11Protocol(asyncio.Protocol):
|
||||
class RequestResponseCycle:
|
||||
def __init__(
|
||||
self,
|
||||
scope: "HTTPScope",
|
||||
scope: HTTPScope,
|
||||
conn: h11.Connection,
|
||||
transport: asyncio.Transport,
|
||||
flow: FlowControl,
|
||||
@ -408,7 +403,7 @@ class RequestResponseCycle:
|
||||
self.response_complete = False
|
||||
|
||||
# ASGI exception wrapper
|
||||
async def run_asgi(self, app: "ASGI3Application") -> None:
|
||||
async def run_asgi(self, app: ASGI3Application) -> None:
|
||||
try:
|
||||
result = await app( # type: ignore[func-returns-value]
|
||||
self.scope, self.receive, self.send
|
||||
@ -533,9 +528,7 @@ class RequestResponseCycle:
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
if self.waiting_for_100_continue and not self.transport.is_closing():
|
||||
headers: list[tuple[str, str]] = []
|
||||
event = h11.InformationalResponse(
|
||||
status_code=100, headers=headers, reason="Continue"
|
||||
)
|
||||
event = h11.InformationalResponse(status_code=100, headers=headers, reason="Continue")
|
||||
output = self.conn.send(event=event)
|
||||
self.transport.write(output)
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
@ -7,7 +7,7 @@ import re
|
||||
import urllib
|
||||
from asyncio.events import TimerHandle
|
||||
from collections import deque
|
||||
from typing import Any, Callable, Deque, Literal, cast
|
||||
from typing import Any, Callable, Literal, cast
|
||||
|
||||
import httptools
|
||||
|
||||
@ -50,9 +50,7 @@ def _get_status_line(status_code: int) -> bytes:
|
||||
return b"".join([b"HTTP/1.1 ", str(status_code).encode(), b" ", phrase, b"\r\n"])
|
||||
|
||||
|
||||
STATUS_LINE = {
|
||||
status_code: _get_status_line(status_code) for status_code in range(100, 600)
|
||||
}
|
||||
STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)}
|
||||
|
||||
|
||||
class HttpToolsProtocol(asyncio.Protocol):
|
||||
@ -93,7 +91,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["http", "https"] | None = None
|
||||
self.pipeline: Deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
|
||||
self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
|
||||
|
||||
# Per-request state
|
||||
self.scope: HTTPScope = None # type: ignore[assignment]
|
||||
@ -268,8 +266,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
|
||||
# Handle 503 responses when 'limit_concurrency' is exceeded.
|
||||
if self.limit_concurrency is not None and (
|
||||
len(self.connections) >= self.limit_concurrency
|
||||
or len(self.tasks) >= self.limit_concurrency
|
||||
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
|
||||
):
|
||||
app = service_unavailable
|
||||
message = "Exceeded concurrency limit."
|
||||
@ -302,9 +299,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
self.pipeline.appendleft((self.cycle, app))
|
||||
|
||||
def on_body(self, body: bytes) -> None:
|
||||
if (
|
||||
self.parser.should_upgrade() and self._should_upgrade()
|
||||
) or self.cycle.response_complete:
|
||||
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
|
||||
return
|
||||
self.cycle.body += body
|
||||
if len(self.cycle.body) > HIGH_WATER_LIMIT:
|
||||
@ -312,9 +307,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
self.cycle.message_event.set()
|
||||
|
||||
def on_message_complete(self) -> None:
|
||||
if (
|
||||
self.parser.should_upgrade() and self._should_upgrade()
|
||||
) or self.cycle.response_complete:
|
||||
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
|
||||
return
|
||||
self.cycle.more_body = False
|
||||
self.cycle.message_event.set()
|
||||
@ -376,7 +369,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
class RequestResponseCycle:
|
||||
def __init__(
|
||||
self,
|
||||
scope: "HTTPScope",
|
||||
scope: HTTPScope,
|
||||
transport: asyncio.Transport,
|
||||
flow: FlowControl,
|
||||
logger: logging.Logger,
|
||||
@ -517,11 +510,7 @@ class RequestResponseCycle:
|
||||
self.keep_alive = False
|
||||
content.extend([name, b": ", value, b"\r\n"])
|
||||
|
||||
if (
|
||||
self.chunked_encoding is None
|
||||
and self.scope["method"] != "HEAD"
|
||||
and status_code not in (204, 304)
|
||||
):
|
||||
if self.chunked_encoding is None and self.scope["method"] != "HEAD" and status_code not in (204, 304):
|
||||
# Neither content-length nor transfer-encoding specified
|
||||
self.chunked_encoding = True
|
||||
content.append(b"transfer-encoding: chunked\r\n")
|
||||
|
||||
@ -53,7 +53,5 @@ def get_client_addr(scope: WWWScope) -> str:
|
||||
def get_path_with_query_string(scope: WWWScope) -> str:
|
||||
path_with_query_string = urllib.parse.quote(scope["path"])
|
||||
if scope["query_string"]:
|
||||
path_with_query_string = "{}?{}".format(
|
||||
path_with_query_string, scope["query_string"].decode("ascii")
|
||||
)
|
||||
path_with_query_string = "{}?{}".format(path_with_query_string, scope["query_string"].decode("ascii"))
|
||||
return path_with_query_string
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
AutoWebSocketsProtocol: typing.Optional[typing.Callable[..., asyncio.Protocol]]
|
||||
AutoWebSocketsProtocol: typing.Callable[..., asyncio.Protocol] | None
|
||||
try:
|
||||
import websockets # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
|
||||
@ -3,16 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import http
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from typing import Any, Literal, Optional, Sequence, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import websockets
|
||||
@ -61,7 +52,7 @@ class Server:
|
||||
|
||||
|
||||
class WebSocketProtocol(WebSocketServerProtocol):
|
||||
extra_headers: List[Tuple[str, str]]
|
||||
extra_headers: list[tuple[str, str]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -117,8 +108,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
)
|
||||
self.server_header = None
|
||||
self.extra_headers = [
|
||||
(name.decode("latin-1"), value.decode("latin-1"))
|
||||
for name, value in server_state.default_headers
|
||||
(name.decode("latin-1"), value.decode("latin-1")) for name, value in server_state.default_headers
|
||||
]
|
||||
|
||||
def connection_made( # type: ignore[override]
|
||||
@ -136,16 +126,14 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
|
||||
super().connection_made(transport)
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.connections.remove(self)
|
||||
|
||||
if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
|
||||
|
||||
self.lost_connection_before_handshake = (
|
||||
not self.handshake_completed_event.is_set()
|
||||
)
|
||||
self.lost_connection_before_handshake = not self.handshake_completed_event.is_set()
|
||||
self.handshake_completed_event.set()
|
||||
super().connection_lost(exc)
|
||||
if exc is None:
|
||||
@ -162,9 +150,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
def on_task_complete(self, task: asyncio.Task) -> None:
|
||||
self.tasks.discard(task)
|
||||
|
||||
async def process_request(
|
||||
self, path: str, headers: Headers
|
||||
) -> Optional[HTTPResponse]:
|
||||
async def process_request(self, path: str, headers: Headers) -> HTTPResponse | None:
|
||||
"""
|
||||
This hook is called to determine if the websocket should return
|
||||
an HTTP response and close.
|
||||
@ -212,8 +198,8 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
return self.initial_response
|
||||
|
||||
def process_subprotocol(
|
||||
self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
|
||||
) -> Optional[Subprotocol]:
|
||||
self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
|
||||
) -> Subprotocol | None:
|
||||
"""
|
||||
We override the standard 'process_subprotocol' behavior here so that
|
||||
we return whatever subprotocol is sent in the 'accept' message.
|
||||
@ -223,8 +209,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
def send_500_response(self) -> None:
|
||||
msg = b"Internal Server Error"
|
||||
content = [
|
||||
b"HTTP/1.1 500 Internal Server Error\r\n"
|
||||
b"content-type: text/plain; charset=utf-8\r\n",
|
||||
b"HTTP/1.1 500 Internal Server Error\r\n" b"content-type: text/plain; charset=utf-8\r\n",
|
||||
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
|
||||
b"connection: close\r\n",
|
||||
b"\r\n",
|
||||
@ -278,7 +263,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
await self.handshake_completed_event.wait()
|
||||
self.transport.close()
|
||||
|
||||
async def asgi_send(self, message: "ASGISendEvent") -> None:
|
||||
async def asgi_send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
if not self.handshake_started_event.is_set():
|
||||
@ -290,9 +275,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
self.initial_response = None
|
||||
self.accepted_subprotocol = cast(
|
||||
Optional[Subprotocol], message.get("subprotocol")
|
||||
)
|
||||
self.accepted_subprotocol = cast(Optional[Subprotocol], message.get("subprotocol"))
|
||||
if "headers" in message:
|
||||
self.extra_headers.extend(
|
||||
# ASGI spec requires bytes
|
||||
@ -324,8 +307,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
# websockets requires the status to be an enum. look it up.
|
||||
status = http.HTTPStatus(message["status"])
|
||||
headers = [
|
||||
(name.decode("latin-1"), value.decode("latin-1"))
|
||||
for name, value in message.get("headers", [])
|
||||
(name.decode("latin-1"), value.decode("latin-1")) for name, value in message.get("headers", [])
|
||||
]
|
||||
self.initial_response = (status, headers, b"")
|
||||
self.handshake_started_event.set()
|
||||
@ -356,10 +338,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
self.closed_event.set()
|
||||
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.send' or 'websocket.close',"
|
||||
" but got '%s'."
|
||||
)
|
||||
msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
except ConnectionClosed as exc:
|
||||
raise ClientDisconnected from exc
|
||||
@ -372,24 +351,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
if not message.get("more_body", False):
|
||||
self.closed_event.set()
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.http.response.body' "
|
||||
"but got '%s'."
|
||||
)
|
||||
msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
else:
|
||||
msg = (
|
||||
"Unexpected ASGI message '%s', after sending 'websocket.close' "
|
||||
"or response already completed."
|
||||
)
|
||||
msg = "Unexpected ASGI message '%s', after sending 'websocket.close' " "or response already completed."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
async def asgi_receive(
|
||||
self,
|
||||
) -> Union[
|
||||
"WebSocketDisconnectEvent", "WebSocketConnectEvent", "WebSocketReceiveEvent"
|
||||
]:
|
||||
) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent:
|
||||
if not self.connect_sent:
|
||||
self.connect_sent = True
|
||||
return {"type": "websocket.connect"}
|
||||
|
||||
@ -168,7 +168,7 @@ class WSProtocol(asyncio.Protocol):
|
||||
path = unquote(raw_path)
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + raw_path.encode("ascii")
|
||||
self.scope: "WebSocketScope" = {
|
||||
self.scope: WebSocketScope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
|
||||
"http_version": "1.1",
|
||||
@ -224,14 +224,8 @@ class WSProtocol(asyncio.Protocol):
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
]
|
||||
output = self.conn.send(
|
||||
wsproto.events.RejectConnection(
|
||||
status_code=500, headers=headers, has_body=True
|
||||
)
|
||||
)
|
||||
output += self.conn.send(
|
||||
wsproto.events.RejectData(data=b"Internal Server Error")
|
||||
)
|
||||
output = self.conn.send(wsproto.events.RejectConnection(status_code=500, headers=headers, has_body=True))
|
||||
output += self.conn.send(wsproto.events.RejectData(data=b"Internal Server Error"))
|
||||
self.transport.write(output)
|
||||
|
||||
async def run_asgi(self) -> None:
|
||||
@ -269,7 +263,7 @@ class WSProtocol(asyncio.Protocol):
|
||||
)
|
||||
subprotocol = message.get("subprotocol")
|
||||
extra_headers = self.default_headers + list(message.get("headers", []))
|
||||
extensions: typing.List[Extension] = []
|
||||
extensions: list[Extension] = []
|
||||
if self.config.ws_per_message_deflate:
|
||||
extensions.append(PerMessageDeflate())
|
||||
if not self.transport.is_closing():
|
||||
@ -343,21 +337,14 @@ class WSProtocol(asyncio.Protocol):
|
||||
self.close_sent = True
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
self.queue.put_nowait(
|
||||
{"type": "websocket.disconnect", "code": code}
|
||||
)
|
||||
output = self.conn.send(
|
||||
wsproto.events.CloseConnection(code=code, reason=reason)
|
||||
)
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
|
||||
output = self.conn.send(wsproto.events.CloseConnection(code=code, reason=reason))
|
||||
if not self.transport.is_closing():
|
||||
self.transport.write(output)
|
||||
self.transport.close()
|
||||
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.send' or 'websocket.close',"
|
||||
" but got '%s'."
|
||||
)
|
||||
msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
except LocalProtocolError as exc:
|
||||
raise ClientDisconnected from exc
|
||||
@ -365,24 +352,17 @@ class WSProtocol(asyncio.Protocol):
|
||||
if message_type == "websocket.http.response.body":
|
||||
message = typing.cast("WebSocketResponseBodyEvent", message)
|
||||
body_finished = not message.get("more_body", False)
|
||||
reject_data = events.RejectData(
|
||||
data=message["body"], body_finished=body_finished
|
||||
)
|
||||
reject_data = events.RejectData(data=message["body"], body_finished=body_finished)
|
||||
output = self.conn.send(reject_data)
|
||||
self.transport.write(output)
|
||||
|
||||
if body_finished:
|
||||
self.queue.put_nowait(
|
||||
{"type": "websocket.disconnect", "code": 1006}
|
||||
)
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.http.response.body' "
|
||||
"but got '%s'."
|
||||
)
|
||||
msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
else:
|
||||
|
||||
@ -126,18 +126,14 @@ class Server:
|
||||
is_windows = platform.system() == "Windows"
|
||||
if config.workers > 1 and is_windows: # pragma: py-not-win32
|
||||
sock = _share_socket(sock) # type: ignore[assignment]
|
||||
server = await loop.create_server(
|
||||
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
|
||||
)
|
||||
server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
|
||||
self.servers.append(server)
|
||||
listeners = sockets
|
||||
|
||||
elif config.fd is not None: # pragma: py-win32
|
||||
# Use an existing socket, from a file descriptor.
|
||||
sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server = await loop.create_server(
|
||||
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
|
||||
)
|
||||
server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
|
||||
assert server.sockets is not None # mypy
|
||||
listeners = server.sockets
|
||||
self.servers = [server]
|
||||
@ -194,9 +190,7 @@ class Server:
|
||||
)
|
||||
|
||||
elif config.uds is not None: # pragma: py-win32
|
||||
logger.info(
|
||||
"Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds
|
||||
)
|
||||
logger.info("Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds)
|
||||
|
||||
else:
|
||||
addr_format = "%s://%s:%d"
|
||||
@ -211,11 +205,7 @@ class Server:
|
||||
|
||||
protocol_name = "https" if config.ssl else "http"
|
||||
message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
|
||||
color_message = (
|
||||
"Uvicorn running on "
|
||||
+ click.style(addr_format, bold=True)
|
||||
+ " (Press CTRL+C to quit)"
|
||||
)
|
||||
color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
|
||||
logger.info(
|
||||
message,
|
||||
protocol_name,
|
||||
@ -244,9 +234,7 @@ class Server:
|
||||
else:
|
||||
date_header = []
|
||||
|
||||
self.server_state.default_headers = (
|
||||
date_header + self.config.encoded_headers
|
||||
)
|
||||
self.server_state.default_headers = date_header + self.config.encoded_headers
|
||||
|
||||
# Callback to `callback_notify` once every `timeout_notify` seconds.
|
||||
if self.config.callback_notify is not None:
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from typing import TYPE_CHECKING, Type
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from uvicorn.supervisors.basereload import BaseReload
|
||||
from uvicorn.supervisors.multiprocess import Multiprocess
|
||||
|
||||
if TYPE_CHECKING:
|
||||
ChangeReload: Type[BaseReload]
|
||||
ChangeReload: type[BaseReload]
|
||||
else:
|
||||
try:
|
||||
from uvicorn.supervisors.watchfilesreload import (
|
||||
|
||||
@ -81,9 +81,7 @@ class BaseReload:
|
||||
for sig in HANDLED_SIGNALS:
|
||||
signal.signal(sig, self.signal_handler)
|
||||
|
||||
self.process = get_subprocess(
|
||||
config=self.config, target=self.target, sockets=self.sockets
|
||||
)
|
||||
self.process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets)
|
||||
self.process.start()
|
||||
|
||||
def restart(self) -> None:
|
||||
@ -95,9 +93,7 @@ class BaseReload:
|
||||
self.process.terminate()
|
||||
self.process.join()
|
||||
|
||||
self.process = get_subprocess(
|
||||
config=self.config, target=self.target, sockets=self.sockets
|
||||
)
|
||||
self.process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets)
|
||||
self.process.start()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
@ -110,10 +106,8 @@ class BaseReload:
|
||||
for sock in self.sockets:
|
||||
sock.close()
|
||||
|
||||
message = "Stopping reloader process [{}]".format(str(self.pid))
|
||||
color_message = "Stopping reloader process [{}]".format(
|
||||
click.style(str(self.pid), fg="cyan", bold=True)
|
||||
)
|
||||
message = f"Stopping reloader process [{str(self.pid)}]"
|
||||
color_message = "Stopping reloader process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True))
|
||||
logger.info(message, extra={"color_message": color_message})
|
||||
|
||||
def should_restart(self) -> list[Path] | None:
|
||||
|
||||
@ -48,19 +48,15 @@ class Multiprocess:
|
||||
self.shutdown()
|
||||
|
||||
def startup(self) -> None:
|
||||
message = "Started parent process [{}]".format(str(self.pid))
|
||||
color_message = "Started parent process [{}]".format(
|
||||
click.style(str(self.pid), fg="cyan", bold=True)
|
||||
)
|
||||
message = f"Started parent process [{str(self.pid)}]"
|
||||
color_message = "Started parent process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True))
|
||||
logger.info(message, extra={"color_message": color_message})
|
||||
|
||||
for sig in HANDLED_SIGNALS:
|
||||
signal.signal(sig, self.signal_handler)
|
||||
|
||||
for _idx in range(self.config.workers):
|
||||
process = get_subprocess(
|
||||
config=self.config, target=self.target, sockets=self.sockets
|
||||
)
|
||||
process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets)
|
||||
process.start()
|
||||
self.processes.append(process)
|
||||
|
||||
@ -69,8 +65,6 @@ class Multiprocess:
|
||||
process.terminate()
|
||||
process.join()
|
||||
|
||||
message = "Stopping parent process [{}]".format(str(self.pid))
|
||||
color_message = "Stopping parent process [{}]".format(
|
||||
click.style(str(self.pid), fg="cyan", bold=True)
|
||||
)
|
||||
message = f"Stopping parent process [{str(self.pid)}]"
|
||||
color_message = "Stopping parent process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True))
|
||||
logger.info(message, extra={"color_message": color_message})
|
||||
|
||||
@ -23,10 +23,7 @@ class StatReload(BaseReload):
|
||||
self.mtimes: dict[Path, float] = {}
|
||||
|
||||
if config.reload_excludes or config.reload_includes:
|
||||
logger.warning(
|
||||
"--reload-include and --reload-exclude have no effect unless "
|
||||
"watchfiles is installed."
|
||||
)
|
||||
logger.warning("--reload-include and --reload-exclude have no effect unless " "watchfiles is installed.")
|
||||
|
||||
def should_restart(self) -> list[Path] | None:
|
||||
self.pause()
|
||||
|
||||
@ -13,20 +13,12 @@ from uvicorn.supervisors.basereload import BaseReload
|
||||
class FileFilter:
|
||||
def __init__(self, config: Config):
|
||||
default_includes = ["*.py"]
|
||||
self.includes = [
|
||||
default
|
||||
for default in default_includes
|
||||
if default not in config.reload_excludes
|
||||
]
|
||||
self.includes = [default for default in default_includes if default not in config.reload_excludes]
|
||||
self.includes.extend(config.reload_includes)
|
||||
self.includes = list(set(self.includes))
|
||||
|
||||
default_excludes = [".*", ".py[cod]", ".sw.*", "~*"]
|
||||
self.excludes = [
|
||||
default
|
||||
for default in default_excludes
|
||||
if default not in config.reload_includes
|
||||
]
|
||||
self.excludes = [default for default in default_excludes if default not in config.reload_includes]
|
||||
self.exclude_dirs = []
|
||||
for e in config.reload_excludes:
|
||||
p = Path(e)
|
||||
|
||||
@ -22,20 +22,12 @@ logger = logging.getLogger("uvicorn.error")
|
||||
class CustomWatcher(DefaultWatcher):
|
||||
def __init__(self, root_path: Path, config: Config):
|
||||
default_includes = ["*.py"]
|
||||
self.includes = [
|
||||
default
|
||||
for default in default_includes
|
||||
if default not in config.reload_excludes
|
||||
]
|
||||
self.includes = [default for default in default_includes if default not in config.reload_excludes]
|
||||
self.includes.extend(config.reload_includes)
|
||||
self.includes = list(set(self.includes))
|
||||
|
||||
default_excludes = [".*", ".py[cod]", ".sw.*", "~*"]
|
||||
self.excludes = [
|
||||
default
|
||||
for default in default_excludes
|
||||
if default not in config.reload_includes
|
||||
]
|
||||
self.excludes = [default for default in default_excludes if default not in config.reload_includes]
|
||||
self.excludes.extend(config.reload_excludes)
|
||||
self.excludes = list(set(self.excludes))
|
||||
|
||||
@ -46,7 +38,7 @@ class CustomWatcher(DefaultWatcher):
|
||||
self.resolved_root = root_path
|
||||
super().__init__(str(root_path))
|
||||
|
||||
def should_watch_file(self, entry: "DirEntry") -> bool:
|
||||
def should_watch_file(self, entry: DirEntry) -> bool:
|
||||
cached_result = self.watched_files.get(entry.path)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
@ -71,7 +63,7 @@ class CustomWatcher(DefaultWatcher):
|
||||
self.watched_files[entry.path] = False
|
||||
return False
|
||||
|
||||
def should_watch_dir(self, entry: "DirEntry") -> bool:
|
||||
def should_watch_dir(self, entry: DirEntry) -> bool:
|
||||
cached_result = self.watched_dirs.get(entry.path)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
@ -94,8 +86,7 @@ class CustomWatcher(DefaultWatcher):
|
||||
|
||||
if is_watched:
|
||||
logger.debug(
|
||||
"WatchGodReload detected a new excluded dir '%s' in '%s'; "
|
||||
"Adding to exclude list.",
|
||||
"WatchGodReload detected a new excluded dir '%s' in '%s'; " "Adding to exclude list.",
|
||||
entry_path.relative_to(self.resolved_root),
|
||||
str(self.resolved_root),
|
||||
)
|
||||
@ -115,8 +106,7 @@ class CustomWatcher(DefaultWatcher):
|
||||
for include_pattern in self.includes:
|
||||
if entry_path.match(include_pattern):
|
||||
logger.info(
|
||||
"WatchGodReload detected a new reload dir '%s' in '%s'; "
|
||||
"Adding to watch list.",
|
||||
"WatchGodReload detected a new reload dir '%s' in '%s'; " "Adding to watch list.",
|
||||
str(entry_path.relative_to(self.resolved_root)),
|
||||
str(self.resolved_root),
|
||||
)
|
||||
@ -136,8 +126,7 @@ class WatchGodReload(BaseReload):
|
||||
sockets: list[socket],
|
||||
) -> None:
|
||||
warnings.warn(
|
||||
'"watchgod" is deprecated, you should switch '
|
||||
"to watchfiles (`pip install watchfiles`).",
|
||||
'"watchgod" is deprecated, you should switch ' "to watchfiles (`pip install watchfiles`).",
|
||||
DeprecationWarning,
|
||||
)
|
||||
super().__init__(config, target, sockets)
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from gunicorn.arbiter import Arbiter
|
||||
from gunicorn.workers.base import Worker
|
||||
@ -17,10 +19,10 @@ class UvicornWorker(Worker):
|
||||
rather than a WSGI callable.
|
||||
"""
|
||||
|
||||
CONFIG_KWARGS: Dict[str, Any] = {"loop": "auto", "http": "auto"}
|
||||
CONFIG_KWARGS: dict[str, Any] = {"loop": "auto", "http": "auto"}
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super(UvicornWorker, self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
logger.handlers = self.log.error_log.handlers
|
||||
@ -63,7 +65,7 @@ class UvicornWorker(Worker):
|
||||
|
||||
def init_process(self) -> None:
|
||||
self.config.setup_event_loop()
|
||||
super(UvicornWorker, self).init_process()
|
||||
super().init_process()
|
||||
|
||||
def init_signals(self) -> None:
|
||||
# Reset signals so Gunicorn doesn't swallow subprocess return codes
|
||||
|
||||
Loading…
Reference in New Issue
Block a user