Compare commits
3 Commits
main
...
define-bas
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4927f77b1a | ||
|
|
e37847705b | ||
|
|
806b429bb9 |
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
import threading
|
||||
@ -14,6 +15,7 @@ from uvicorn._types import ASGIApplication, ASGIReceiveCallable, ASGISendCallabl
|
||||
from uvicorn.config import WS_PROTOCOLS, Config
|
||||
from uvicorn.lifespan.off import LifespanOff
|
||||
from uvicorn.lifespan.on import LifespanOn
|
||||
from uvicorn.protocols.http.base import HTTPProtocol
|
||||
from uvicorn.protocols.http.h11_impl import H11Protocol
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
@ -36,7 +38,6 @@ if TYPE_CHECKING:
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
HTTPProtocol: TypeAlias = "type[HttpToolsProtocol | H11Protocol]"
|
||||
WSProtocol: TypeAlias = "type[WebSocketProtocol | _WSProtocol]"
|
||||
|
||||
pytestmark = pytest.mark.anyio
|
||||
@ -173,7 +174,9 @@ UPGRADE_REQUEST_ERROR_FIELD = b"\r\n".join(
|
||||
|
||||
|
||||
class MockTransport:
|
||||
def __init__(self, sockname=None, peername=None, sslcontext=False):
|
||||
def __init__(
|
||||
self, sockname: tuple[str, int] | None = None, peername: tuple[str, int] | None = None, sslcontext: bool = False
|
||||
):
|
||||
self.sockname = ("127.0.0.1", 8000) if sockname is None else sockname
|
||||
self.peername = ("127.0.0.1", 8001) if peername is None else peername
|
||||
self.sslcontext = sslcontext
|
||||
@ -181,14 +184,10 @@ class MockTransport:
|
||||
self.buffer = b""
|
||||
self.read_paused = False
|
||||
|
||||
def get_extra_info(self, key):
|
||||
return {
|
||||
"sockname": self.sockname,
|
||||
"peername": self.peername,
|
||||
"sslcontext": self.sslcontext,
|
||||
}.get(key)
|
||||
def get_extra_info(self, key: Any):
|
||||
return {"sockname": self.sockname, "peername": self.peername, "sslcontext": self.sslcontext}.get(key)
|
||||
|
||||
def write(self, data):
|
||||
def write(self, data: bytes):
|
||||
assert not self.closed
|
||||
self.buffer += data
|
||||
|
||||
@ -208,7 +207,7 @@ class MockTransport:
|
||||
def clear_buffer(self):
|
||||
self.buffer = b""
|
||||
|
||||
def set_protocol(self, protocol):
|
||||
def set_protocol(self, protocol: asyncio.Protocol):
|
||||
pass
|
||||
|
||||
|
||||
@ -258,12 +257,17 @@ class MockTask:
|
||||
pass
|
||||
|
||||
|
||||
class MockProtocol(HTTPProtocol):
|
||||
loop: MockLoop # type: ignore[assignment]
|
||||
transport: MockTransport # type: ignore[assignment]
|
||||
|
||||
|
||||
def get_connected_protocol(
|
||||
app: ASGIApplication,
|
||||
http_protocol_cls: HTTPProtocol,
|
||||
http_protocol_cls: type[HTTPProtocol],
|
||||
lifespan: LifespanOff | LifespanOn | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
) -> MockProtocol:
|
||||
loop = MockLoop()
|
||||
transport = MockTransport()
|
||||
config = Config(app=app, **kwargs)
|
||||
@ -273,13 +277,13 @@ def get_connected_protocol(
|
||||
config=config,
|
||||
server_state=server_state,
|
||||
app_state=lifespan.state,
|
||||
_loop=loop, # type: ignore
|
||||
_loop=loop, # type: ignore[arg-type]
|
||||
)
|
||||
protocol.connection_made(transport) # type: ignore
|
||||
return protocol
|
||||
protocol.connection_made(transport) # type: ignore[arg-type]
|
||||
return protocol # type: ignore[return-value]
|
||||
|
||||
|
||||
async def test_get_request(http_protocol_cls: HTTPProtocol):
|
||||
async def test_get_request(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -298,7 +302,7 @@ async def test_get_request(http_protocol_cls: HTTPProtocol):
|
||||
pytest.param("µ", id="allow_non_ascii_char"),
|
||||
],
|
||||
)
|
||||
async def test_header_value_allowed_characters(http_protocol_cls: HTTPProtocol, char: str):
|
||||
async def test_header_value_allowed_characters(http_protocol_cls: type[HTTPProtocol], char: str):
|
||||
app = Response("Hello, world", media_type="text/plain", headers={"key": f"<{char}>"})
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_GET_REQUEST)
|
||||
@ -309,7 +313,7 @@ async def test_header_value_allowed_characters(http_protocol_cls: HTTPProtocol,
|
||||
|
||||
|
||||
@pytest.mark.parametrize("path", ["/", "/?foo", "/?foo=bar", "/?foo=bar&baz=1"])
|
||||
async def test_request_logging(path: str, http_protocol_cls: HTTPProtocol, caplog: pytest.LogCaptureFixture):
|
||||
async def test_request_logging(path: str, http_protocol_cls: type[HTTPProtocol], caplog: pytest.LogCaptureFixture):
|
||||
get_request_with_query_string = b"\r\n".join(
|
||||
[f"GET {path} HTTP/1.1".encode("ascii"), b"Host: example.org", b"", b""]
|
||||
)
|
||||
@ -324,7 +328,7 @@ async def test_request_logging(path: str, http_protocol_cls: HTTPProtocol, caplo
|
||||
assert f'"GET {path} HTTP/1.1" 200' in caplog.records[0].message
|
||||
|
||||
|
||||
async def test_head_request(http_protocol_cls: HTTPProtocol):
|
||||
async def test_head_request(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -334,7 +338,7 @@ async def test_head_request(http_protocol_cls: HTTPProtocol):
|
||||
assert b"Hello, world" not in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_post_request(http_protocol_cls: HTTPProtocol):
|
||||
async def test_post_request(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
body = b""
|
||||
more_body = True
|
||||
@ -353,7 +357,7 @@ async def test_post_request(http_protocol_cls: HTTPProtocol):
|
||||
assert b'Body: {"hello": "world"}' in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_keepalive(http_protocol_cls: HTTPProtocol):
|
||||
async def test_keepalive(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response(b"", status_code=204)
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -364,7 +368,7 @@ async def test_keepalive(http_protocol_cls: HTTPProtocol):
|
||||
assert not protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_keepalive_timeout(http_protocol_cls: HTTPProtocol):
|
||||
async def test_keepalive_timeout(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response(b"", status_code=204)
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -378,9 +382,7 @@ async def test_keepalive_timeout(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_keepalive_timeout_with_pipelined_requests(
|
||||
http_protocol_cls: HTTPProtocol,
|
||||
):
|
||||
async def test_keepalive_timeout_with_pipelined_requests(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -403,7 +405,7 @@ async def test_keepalive_timeout_with_pipelined_requests(
|
||||
assert protocol.timeout_keep_alive_task is not None
|
||||
|
||||
|
||||
async def test_close(http_protocol_cls: HTTPProtocol):
|
||||
async def test_close(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response(b"", status_code=204, headers={"connection": "close"})
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -413,7 +415,7 @@ async def test_close(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_chunked_encoding(http_protocol_cls: HTTPProtocol):
|
||||
async def test_chunked_encoding(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"})
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -424,7 +426,7 @@ async def test_chunked_encoding(http_protocol_cls: HTTPProtocol):
|
||||
assert not protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol):
|
||||
async def test_chunked_encoding_empty_body(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"})
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -435,9 +437,7 @@ async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol):
|
||||
assert not protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_chunked_encoding_head_request(
|
||||
http_protocol_cls: HTTPProtocol,
|
||||
):
|
||||
async def test_chunked_encoding_head_request(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"})
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -447,7 +447,7 @@ async def test_chunked_encoding_head_request(
|
||||
assert not protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_pipelined_requests(http_protocol_cls: HTTPProtocol):
|
||||
async def test_pipelined_requests(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -468,7 +468,7 @@ async def test_pipelined_requests(http_protocol_cls: HTTPProtocol):
|
||||
protocol.transport.clear_buffer()
|
||||
|
||||
|
||||
async def test_undersized_request(http_protocol_cls: HTTPProtocol):
|
||||
async def test_undersized_request(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response(b"xxx", headers={"content-length": "10"})
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -477,7 +477,7 @@ async def test_undersized_request(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_oversized_request(http_protocol_cls: HTTPProtocol):
|
||||
async def test_oversized_request(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response(b"xxx" * 20, headers={"content-length": "10"})
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -486,7 +486,7 @@ async def test_oversized_request(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_large_post_request(http_protocol_cls: HTTPProtocol):
|
||||
async def test_large_post_request(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -496,7 +496,7 @@ async def test_large_post_request(http_protocol_cls: HTTPProtocol):
|
||||
assert not protocol.transport.read_paused
|
||||
|
||||
|
||||
async def test_invalid_http(http_protocol_cls: HTTPProtocol):
|
||||
async def test_invalid_http(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -504,7 +504,7 @@ async def test_invalid_http(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_app_exception(http_protocol_cls: HTTPProtocol):
|
||||
async def test_app_exception(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
raise Exception()
|
||||
|
||||
@ -515,7 +515,7 @@ async def test_app_exception(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_exception_during_response(http_protocol_cls: HTTPProtocol):
|
||||
async def test_exception_during_response(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send({"type": "http.response.start", "status": 200})
|
||||
await send({"type": "http.response.body", "body": b"1", "more_body": True})
|
||||
@ -528,7 +528,7 @@ async def test_exception_during_response(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_no_response_returned(http_protocol_cls: HTTPProtocol):
|
||||
async def test_no_response_returned(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): ...
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -538,7 +538,7 @@ async def test_no_response_returned(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_partial_response_returned(http_protocol_cls: HTTPProtocol):
|
||||
async def test_partial_response_returned(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send({"type": "http.response.start", "status": 200})
|
||||
|
||||
@ -549,7 +549,7 @@ async def test_partial_response_returned(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_response_header_splitting(http_protocol_cls: HTTPProtocol):
|
||||
async def test_response_header_splitting(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response(b"", headers={"key": "value\r\nCookie: smuggled=value"})
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -560,7 +560,7 @@ async def test_response_header_splitting(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_duplicate_start_message(http_protocol_cls: HTTPProtocol):
|
||||
async def test_duplicate_start_message(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send({"type": "http.response.start", "status": 200})
|
||||
await send({"type": "http.response.start", "status": 200})
|
||||
@ -572,7 +572,7 @@ async def test_duplicate_start_message(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_missing_start_message(http_protocol_cls: HTTPProtocol):
|
||||
async def test_missing_start_message(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send({"type": "http.response.body", "body": b""})
|
||||
|
||||
@ -583,7 +583,7 @@ async def test_missing_start_message(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_message_after_body_complete(http_protocol_cls: HTTPProtocol):
|
||||
async def test_message_after_body_complete(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send({"type": "http.response.start", "status": 200})
|
||||
await send({"type": "http.response.body", "body": b""})
|
||||
@ -596,7 +596,7 @@ async def test_message_after_body_complete(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_value_returned(http_protocol_cls: HTTPProtocol):
|
||||
async def test_value_returned(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send({"type": "http.response.start", "status": 200})
|
||||
await send({"type": "http.response.body", "body": b""})
|
||||
@ -609,7 +609,7 @@ async def test_value_returned(http_protocol_cls: HTTPProtocol):
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_early_disconnect(http_protocol_cls: HTTPProtocol):
|
||||
async def test_early_disconnect(http_protocol_cls: type[HTTPProtocol]):
|
||||
got_disconnect_event = False
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
@ -630,7 +630,7 @@ async def test_early_disconnect(http_protocol_cls: HTTPProtocol):
|
||||
assert got_disconnect_event
|
||||
|
||||
|
||||
async def test_early_response(http_protocol_cls: HTTPProtocol):
|
||||
async def test_early_response(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -641,7 +641,7 @@ async def test_early_response(http_protocol_cls: HTTPProtocol):
|
||||
assert not protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_read_after_response(http_protocol_cls: HTTPProtocol):
|
||||
async def test_read_after_response(http_protocol_cls: type[HTTPProtocol]):
|
||||
message_after_response = None
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
@ -658,7 +658,7 @@ async def test_read_after_response(http_protocol_cls: HTTPProtocol):
|
||||
assert message_after_response == {"type": "http.disconnect"}
|
||||
|
||||
|
||||
async def test_http10_request(http_protocol_cls: HTTPProtocol):
|
||||
async def test_http10_request(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
assert scope["type"] == "http"
|
||||
content = "Version: %s" % scope["http_version"]
|
||||
@ -672,7 +672,7 @@ async def test_http10_request(http_protocol_cls: HTTPProtocol):
|
||||
assert b"Version: 1.0" in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_root_path(http_protocol_cls: HTTPProtocol):
|
||||
async def test_root_path(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
assert scope["type"] == "http"
|
||||
root_path = scope.get("root_path", "")
|
||||
@ -687,7 +687,7 @@ async def test_root_path(http_protocol_cls: HTTPProtocol):
|
||||
assert b"root_path=/app path=/app/" in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_raw_path(http_protocol_cls: HTTPProtocol):
|
||||
async def test_raw_path(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
assert scope["type"] == "http"
|
||||
path = scope["path"]
|
||||
@ -704,7 +704,7 @@ async def test_raw_path(http_protocol_cls: HTTPProtocol):
|
||||
assert b"Done" in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_max_concurrency(http_protocol_cls: HTTPProtocol):
|
||||
async def test_max_concurrency(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls, limit_concurrency=1)
|
||||
@ -725,27 +725,27 @@ async def test_max_concurrency(http_protocol_cls: HTTPProtocol):
|
||||
)
|
||||
|
||||
|
||||
async def test_shutdown_during_request(http_protocol_cls: HTTPProtocol):
|
||||
async def test_shutdown_during_request(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response(b"", status_code=204)
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_GET_REQUEST)
|
||||
protocol.shutdown()
|
||||
protocol.shutdown() # type: ignore[attr-defined]
|
||||
await protocol.loop.run_one()
|
||||
assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_shutdown_during_idle(http_protocol_cls: HTTPProtocol):
|
||||
async def test_shutdown_during_idle(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
protocol.shutdown()
|
||||
protocol.shutdown() # type: ignore[attr-defined]
|
||||
assert protocol.transport.buffer == b""
|
||||
assert protocol.transport.is_closing()
|
||||
|
||||
|
||||
async def test_100_continue_sent_when_body_consumed(http_protocol_cls: HTTPProtocol):
|
||||
async def test_100_continue_sent_when_body_consumed(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
body = b""
|
||||
more_body = True
|
||||
@ -777,7 +777,7 @@ async def test_100_continue_sent_when_body_consumed(http_protocol_cls: HTTPProto
|
||||
|
||||
|
||||
async def test_100_continue_not_sent_when_body_not_consumed(
|
||||
http_protocol_cls: HTTPProtocol,
|
||||
http_protocol_cls: type[HTTPProtocol],
|
||||
):
|
||||
app = Response(b"", status_code=204)
|
||||
|
||||
@ -799,7 +799,7 @@ async def test_100_continue_not_sent_when_body_not_consumed(
|
||||
assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_supported_upgrade_request(http_protocol_cls: HTTPProtocol):
|
||||
async def test_supported_upgrade_request(http_protocol_cls: type[HTTPProtocol]):
|
||||
pytest.importorskip("wsproto")
|
||||
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
@ -809,7 +809,7 @@ async def test_supported_upgrade_request(http_protocol_cls: HTTPProtocol):
|
||||
assert b"HTTP/1.1 426 " in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_unsupported_ws_upgrade_request(http_protocol_cls: HTTPProtocol):
|
||||
async def test_unsupported_ws_upgrade_request(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls, ws="none")
|
||||
@ -820,7 +820,7 @@ async def test_unsupported_ws_upgrade_request(http_protocol_cls: HTTPProtocol):
|
||||
|
||||
|
||||
async def test_unsupported_ws_upgrade_request_warn_on_auto(
|
||||
caplog: pytest.LogCaptureFixture, http_protocol_cls: HTTPProtocol
|
||||
caplog: pytest.LogCaptureFixture, http_protocol_cls: type[HTTPProtocol]
|
||||
):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
@ -836,7 +836,7 @@ async def test_unsupported_ws_upgrade_request_warn_on_auto(
|
||||
assert msg in warnings
|
||||
|
||||
|
||||
async def test_http2_upgrade_request(http_protocol_cls: HTTPProtocol, ws_protocol_cls: WSProtocol):
|
||||
async def test_http2_upgrade_request(http_protocol_cls: type[HTTPProtocol], ws_protocol_cls: WSProtocol):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls, ws=ws_protocol_cls)
|
||||
@ -867,7 +867,7 @@ def asgi2app(scope: Scope):
|
||||
async def test_scopes(
|
||||
asgi2or3_app: ASGIApplication,
|
||||
expected_scopes: dict[str, str],
|
||||
http_protocol_cls: HTTPProtocol,
|
||||
http_protocol_cls: type[HTTPProtocol],
|
||||
):
|
||||
protocol = get_connected_protocol(asgi2or3_app, http_protocol_cls)
|
||||
protocol.data_received(SIMPLE_GET_REQUEST)
|
||||
@ -884,7 +884,7 @@ async def test_scopes(
|
||||
],
|
||||
)
|
||||
async def test_invalid_http_request(
|
||||
request_line: str, http_protocol_cls: HTTPProtocol, caplog: pytest.LogCaptureFixture
|
||||
request_line: str, http_protocol_cls: type[HTTPProtocol], caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
request = INVALID_REQUEST_TEMPLATE % request_line
|
||||
@ -1007,7 +1007,7 @@ async def test_huge_headers_h11_max_incomplete():
|
||||
assert b"Hello, world" in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_return_close_header(http_protocol_cls: HTTPProtocol):
|
||||
async def test_return_close_header(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -1021,7 +1021,7 @@ async def test_return_close_header(http_protocol_cls: HTTPProtocol):
|
||||
assert b"connection: close" in protocol.transport.buffer.lower()
|
||||
|
||||
|
||||
async def test_close_connection_with_multiple_requests(http_protocol_cls: HTTPProtocol):
|
||||
async def test_close_connection_with_multiple_requests(http_protocol_cls: type[HTTPProtocol]):
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
protocol = get_connected_protocol(app, http_protocol_cls)
|
||||
@ -1035,7 +1035,7 @@ async def test_close_connection_with_multiple_requests(http_protocol_cls: HTTPPr
|
||||
assert b"connection: close" in protocol.transport.buffer.lower()
|
||||
|
||||
|
||||
async def test_close_connection_with_post_request(http_protocol_cls: HTTPProtocol):
|
||||
async def test_close_connection_with_post_request(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
body = b""
|
||||
more_body = True
|
||||
@ -1054,7 +1054,7 @@ async def test_close_connection_with_post_request(http_protocol_cls: HTTPProtoco
|
||||
assert b"Body: {'hello': 'world'}" in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_iterator_headers(http_protocol_cls: HTTPProtocol):
|
||||
async def test_iterator_headers(http_protocol_cls: type[HTTPProtocol]):
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
headers = iter([(b"x-test-header", b"test value")])
|
||||
await send({"type": "http.response.start", "status": 200, "headers": headers})
|
||||
@ -1066,7 +1066,7 @@ async def test_iterator_headers(http_protocol_cls: HTTPProtocol):
|
||||
assert b"x-test-header: test value" in protocol.transport.buffer
|
||||
|
||||
|
||||
async def test_lifespan_state(http_protocol_cls: HTTPProtocol):
|
||||
async def test_lifespan_state(http_protocol_cls: type[HTTPProtocol]):
|
||||
expected_states = [{"a": 123, "b": [1]}, {"a": 123, "b": [1, 2]}]
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
@ -1095,7 +1095,7 @@ async def test_lifespan_state(http_protocol_cls: HTTPProtocol):
|
||||
|
||||
|
||||
async def test_header_upgrade_is_not_websocket_depend_installed(
|
||||
caplog: pytest.LogCaptureFixture, http_protocol_cls: HTTPProtocol
|
||||
caplog: pytest.LogCaptureFixture, http_protocol_cls: type[HTTPProtocol]
|
||||
):
|
||||
caplog.set_level(logging.WARNING, logger="uvicorn.error")
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
@ -1111,7 +1111,7 @@ async def test_header_upgrade_is_not_websocket_depend_installed(
|
||||
|
||||
|
||||
async def test_header_upgrade_is_websocket_depend_not_installed(
|
||||
caplog: pytest.LogCaptureFixture, http_protocol_cls: HTTPProtocol
|
||||
caplog: pytest.LogCaptureFixture, http_protocol_cls: type[HTTPProtocol]
|
||||
):
|
||||
caplog.set_level(logging.WARNING, logger="uvicorn.error")
|
||||
app = Response("Hello, world", media_type="text/plain")
|
||||
|
||||
@ -270,12 +270,5 @@ class ASGI2Protocol(Protocol):
|
||||
|
||||
|
||||
ASGI2Application = type[ASGI2Protocol]
|
||||
ASGI3Application = Callable[
|
||||
[
|
||||
Scope,
|
||||
ASGIReceiveCallable,
|
||||
ASGISendCallable,
|
||||
],
|
||||
Awaitable[None],
|
||||
]
|
||||
ASGI3Application = Callable[[Scope, ASGIReceiveCallable, ASGISendCallable], Awaitable[None]]
|
||||
ASGIApplication = Union[ASGI2Application, ASGI3Application]
|
||||
|
||||
@ -96,13 +96,7 @@ class AccessFormatter(ColourizedFormatter):
|
||||
|
||||
def formatMessage(self, record: logging.LogRecord) -> str:
|
||||
recordcopy = copy(record)
|
||||
(
|
||||
client_addr,
|
||||
method,
|
||||
full_path,
|
||||
http_version,
|
||||
status_code,
|
||||
) = recordcopy.args # type: ignore[misc]
|
||||
(client_addr, method, full_path, http_version, status_code) = recordcopy.args # type: ignore[misc]
|
||||
status_code = self.get_status_code(int(status_code)) # type: ignore[arg-type]
|
||||
request_line = f"{method} {full_path} HTTP/{http_version}"
|
||||
if self.use_colors:
|
||||
|
||||
82
uvicorn/protocols/http/base.py
Normal file
82
uvicorn/protocols/http/base.py
Normal file
@ -0,0 +1,82 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from uvicorn._types import HTTPScope
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.protocols.http.flow_control import FlowControl
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
|
||||
class HTTPProtocol(asyncio.Protocol):
|
||||
__slots__ = (
|
||||
"config",
|
||||
"app",
|
||||
"loop",
|
||||
"logger",
|
||||
"access_logger",
|
||||
"access_log",
|
||||
"ws_protocol_class",
|
||||
"root_path",
|
||||
"limit_concurrency",
|
||||
"app_state",
|
||||
# Timeouts
|
||||
"timeout_keep_alive_task",
|
||||
"timeout_keep_alive",
|
||||
# Global state
|
||||
"server_state",
|
||||
"connections",
|
||||
"tasks",
|
||||
# Per-connection state
|
||||
"transport",
|
||||
"flow",
|
||||
"server",
|
||||
"client",
|
||||
# Per-request state
|
||||
"scope",
|
||||
"headers",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.access_logger = logging.getLogger("uvicorn.access")
|
||||
self.access_log = self.access_logger.hasHandlers()
|
||||
|
||||
self.ws_protocol_class = config.ws_protocol_class
|
||||
self.root_path = config.root_path
|
||||
self.limit_concurrency = config.limit_concurrency
|
||||
self.app_state = app_state
|
||||
|
||||
# Timeouts
|
||||
self.timeout_keep_alive_task: asyncio.TimerHandle | None = None
|
||||
self.timeout_keep_alive = config.timeout_keep_alive
|
||||
|
||||
# Global state
|
||||
self.server_state = server_state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
|
||||
# Per-connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.flow: FlowControl = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
|
||||
# Per-request state
|
||||
self.scope: HTTPScope = None # type: ignore[assignment]
|
||||
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import http
|
||||
import logging
|
||||
from typing import Any, Callable, Literal, cast
|
||||
from typing import Any, Callable, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import h11
|
||||
@ -20,6 +20,7 @@ from uvicorn._types import (
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.http.base import HTTPProtocol
|
||||
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
|
||||
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
|
||||
from uvicorn.server import ServerState
|
||||
@ -35,7 +36,7 @@ def _get_status_phrase(status_code: int) -> bytes:
|
||||
STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)}
|
||||
|
||||
|
||||
class H11Protocol(asyncio.Protocol):
|
||||
class H11Protocol(HTTPProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
@ -43,55 +44,24 @@ class H11Protocol(asyncio.Protocol):
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
super().__init__(config, server_state, app_state, _loop)
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.access_logger = logging.getLogger("uvicorn.access")
|
||||
self.access_log = self.access_logger.hasHandlers()
|
||||
self.conn = h11.Connection(
|
||||
h11.SERVER,
|
||||
config.h11_max_incomplete_event_size
|
||||
if config.h11_max_incomplete_event_size is not None
|
||||
else DEFAULT_MAX_INCOMPLETE_EVENT_SIZE,
|
||||
)
|
||||
self.ws_protocol_class = config.ws_protocol_class
|
||||
self.root_path = config.root_path
|
||||
self.limit_concurrency = config.limit_concurrency
|
||||
self.app_state = app_state
|
||||
|
||||
# Timeouts
|
||||
self.timeout_keep_alive_task: asyncio.TimerHandle | None = None
|
||||
self.timeout_keep_alive = config.timeout_keep_alive
|
||||
|
||||
# Shared server state
|
||||
self.server_state = server_state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
|
||||
# Per-connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.flow: FlowControl = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["http", "https"] | None = None
|
||||
|
||||
# Per-request state
|
||||
self.scope: HTTPScope = None # type: ignore[assignment]
|
||||
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
|
||||
|
||||
# Protocol interface
|
||||
def connection_made( # type: ignore[override]
|
||||
self, transport: asyncio.Transport
|
||||
) -> None:
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self.connections.add(self)
|
||||
|
||||
self.transport = transport
|
||||
self.flow = FlowControl(transport)
|
||||
self.transport = cast(asyncio.Transport, transport)
|
||||
self.flow = FlowControl(self.transport)
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "https" if is_ssl(transport) else "http"
|
||||
@ -204,7 +174,7 @@ class H11Protocol(asyncio.Protocol):
|
||||
"http_version": event.http_version.decode("ascii"),
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"scheme": self.scheme, # type: ignore[typeddict-item]
|
||||
"scheme": self.scheme,
|
||||
"method": event.method.decode("ascii"),
|
||||
"root_path": self.root_path,
|
||||
"path": full_path,
|
||||
@ -534,10 +504,6 @@ class RequestResponseCycle:
|
||||
if self.disconnected or self.response_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
message: HTTPRequestEvent = {
|
||||
"type": "http.request",
|
||||
"body": self.body,
|
||||
"more_body": self.more_body,
|
||||
}
|
||||
message: HTTPRequestEvent = {"type": "http.request", "body": self.body, "more_body": self.more_body}
|
||||
self.body = b""
|
||||
return message
|
||||
|
||||
@ -5,9 +5,8 @@ import http
|
||||
import logging
|
||||
import re
|
||||
import urllib
|
||||
from asyncio.events import TimerHandle
|
||||
from collections import deque
|
||||
from typing import Any, Callable, Literal, cast
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
import httptools
|
||||
|
||||
@ -21,6 +20,7 @@ from uvicorn._types import (
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.http.base import HTTPProtocol
|
||||
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
|
||||
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
|
||||
from uvicorn.server import ServerState
|
||||
@ -40,7 +40,7 @@ def _get_status_line(status_code: int) -> bytes:
|
||||
STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)}
|
||||
|
||||
|
||||
class HttpToolsProtocol(asyncio.Protocol):
|
||||
class HttpToolsProtocol(HTTPProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
@ -48,15 +48,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.access_logger = logging.getLogger("uvicorn.access")
|
||||
self.access_log = self.access_logger.hasHandlers()
|
||||
super().__init__(config, server_state, app_state, _loop)
|
||||
self.parser = httptools.HttpRequestParser(self)
|
||||
|
||||
try:
|
||||
@ -66,42 +58,19 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
# httptools < 0.6.3
|
||||
pass
|
||||
|
||||
self.ws_protocol_class = config.ws_protocol_class
|
||||
self.root_path = config.root_path
|
||||
self.limit_concurrency = config.limit_concurrency
|
||||
self.app_state = app_state
|
||||
|
||||
# Timeouts
|
||||
self.timeout_keep_alive_task: TimerHandle | None = None
|
||||
self.timeout_keep_alive = config.timeout_keep_alive
|
||||
|
||||
# Global state
|
||||
self.server_state = server_state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
|
||||
# Per-connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.flow: FlowControl = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["http", "https"] | None = None
|
||||
self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
|
||||
|
||||
# Per-request state
|
||||
self.scope: HTTPScope = None # type: ignore[assignment]
|
||||
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
self.expect_100_continue = False
|
||||
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
|
||||
|
||||
# Protocol interface
|
||||
def connection_made( # type: ignore[override]
|
||||
self, transport: asyncio.Transport
|
||||
) -> None:
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self.connections.add(self)
|
||||
|
||||
self.transport = transport
|
||||
self.flow = FlowControl(transport)
|
||||
self.transport = cast(asyncio.Transport, transport)
|
||||
self.flow = FlowControl(self.transport)
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "https" if is_ssl(transport) else "http"
|
||||
@ -226,7 +195,7 @@ class HttpToolsProtocol(asyncio.Protocol):
|
||||
"http_version": "1.1",
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"scheme": self.scheme, # type: ignore[typeddict-item]
|
||||
"scheme": self.scheme,
|
||||
"root_path": self.root_path,
|
||||
"headers": self.headers,
|
||||
"state": self.app_state.copy(),
|
||||
|
||||
@ -9,7 +9,7 @@ from uvicorn._types import WWWScope
|
||||
class ClientDisconnected(OSError): ...
|
||||
|
||||
|
||||
def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
|
||||
def get_remote_addr(transport: asyncio.BaseTransport) -> tuple[str, int] | None:
|
||||
socket_info = transport.get_extra_info("socket")
|
||||
if socket_info is not None:
|
||||
try:
|
||||
@ -26,7 +26,7 @@ def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
|
||||
return None
|
||||
|
||||
|
||||
def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
|
||||
def get_local_addr(transport: asyncio.BaseTransport) -> tuple[str, int] | None:
|
||||
socket_info = transport.get_extra_info("socket")
|
||||
if socket_info is not None:
|
||||
info = socket_info.getsockname()
|
||||
@ -38,7 +38,7 @@ def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
|
||||
return None
|
||||
|
||||
|
||||
def is_ssl(transport: asyncio.Transport) -> bool:
|
||||
def is_ssl(transport: asyncio.BaseTransport) -> bool:
|
||||
return bool(transport.get_extra_info("sslcontext"))
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user