Compare commits

...

3 Commits

Author SHA1 Message Date
Marcelo Trylesinski
4927f77b1a
Merge branch 'main' into define-base-http-class 2025-09-23 06:10:17 -07:00
Marcelo Trylesinski
e37847705b
Merge branch 'main' into define-base-http-class 2025-09-21 06:56:14 -07:00
Marcelo Trylesinski
806b429bb9 Define HTTPProtocol class 2025-09-14 15:00:40 +02:00
7 changed files with 175 additions and 171 deletions

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import logging
import socket
import threading
@ -14,6 +15,7 @@ from uvicorn._types import ASGIApplication, ASGIReceiveCallable, ASGISendCallabl
from uvicorn.config import WS_PROTOCOLS, Config
from uvicorn.lifespan.off import LifespanOff
from uvicorn.lifespan.on import LifespanOn
from uvicorn.protocols.http.base import HTTPProtocol
from uvicorn.protocols.http.h11_impl import H11Protocol
from uvicorn.server import ServerState
@ -36,7 +38,6 @@ if TYPE_CHECKING:
else: # pragma: no cover
from typing_extensions import TypeAlias
HTTPProtocol: TypeAlias = "type[HttpToolsProtocol | H11Protocol]"
WSProtocol: TypeAlias = "type[WebSocketProtocol | _WSProtocol]"
pytestmark = pytest.mark.anyio
@ -173,7 +174,9 @@ UPGRADE_REQUEST_ERROR_FIELD = b"\r\n".join(
class MockTransport:
def __init__(self, sockname=None, peername=None, sslcontext=False):
def __init__(
self, sockname: tuple[str, int] | None = None, peername: tuple[str, int] | None = None, sslcontext: bool = False
):
self.sockname = ("127.0.0.1", 8000) if sockname is None else sockname
self.peername = ("127.0.0.1", 8001) if peername is None else peername
self.sslcontext = sslcontext
@ -181,14 +184,10 @@ class MockTransport:
self.buffer = b""
self.read_paused = False
def get_extra_info(self, key):
return {
"sockname": self.sockname,
"peername": self.peername,
"sslcontext": self.sslcontext,
}.get(key)
def get_extra_info(self, key: Any):
return {"sockname": self.sockname, "peername": self.peername, "sslcontext": self.sslcontext}.get(key)
def write(self, data):
def write(self, data: bytes):
assert not self.closed
self.buffer += data
@ -208,7 +207,7 @@ class MockTransport:
def clear_buffer(self):
self.buffer = b""
def set_protocol(self, protocol):
def set_protocol(self, protocol: asyncio.Protocol):
pass
@ -258,12 +257,17 @@ class MockTask:
pass
class MockProtocol(HTTPProtocol):
loop: MockLoop # type: ignore[assignment]
transport: MockTransport # type: ignore[assignment]
def get_connected_protocol(
app: ASGIApplication,
http_protocol_cls: HTTPProtocol,
http_protocol_cls: type[HTTPProtocol],
lifespan: LifespanOff | LifespanOn | None = None,
**kwargs: Any,
):
) -> MockProtocol:
loop = MockLoop()
transport = MockTransport()
config = Config(app=app, **kwargs)
@ -273,13 +277,13 @@ def get_connected_protocol(
config=config,
server_state=server_state,
app_state=lifespan.state,
_loop=loop, # type: ignore
_loop=loop, # type: ignore[arg-type]
)
protocol.connection_made(transport) # type: ignore
return protocol
protocol.connection_made(transport) # type: ignore[arg-type]
return protocol # type: ignore[return-value]
async def test_get_request(http_protocol_cls: HTTPProtocol):
async def test_get_request(http_protocol_cls: type[HTTPProtocol]):
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, http_protocol_cls)
@ -298,7 +302,7 @@ async def test_get_request(http_protocol_cls: HTTPProtocol):
pytest.param("µ", id="allow_non_ascii_char"),
],
)
async def test_header_value_allowed_characters(http_protocol_cls: HTTPProtocol, char: str):
async def test_header_value_allowed_characters(http_protocol_cls: type[HTTPProtocol], char: str):
app = Response("Hello, world", media_type="text/plain", headers={"key": f"<{char}>"})
protocol = get_connected_protocol(app, http_protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
@ -309,7 +313,7 @@ async def test_header_value_allowed_characters(http_protocol_cls: HTTPProtocol,
@pytest.mark.parametrize("path", ["/", "/?foo", "/?foo=bar", "/?foo=bar&baz=1"])
async def test_request_logging(path: str, http_protocol_cls: HTTPProtocol, caplog: pytest.LogCaptureFixture):
async def test_request_logging(path: str, http_protocol_cls: type[HTTPProtocol], caplog: pytest.LogCaptureFixture):
get_request_with_query_string = b"\r\n".join(
[f"GET {path} HTTP/1.1".encode("ascii"), b"Host: example.org", b"", b""]
)
@ -324,7 +328,7 @@ async def test_request_logging(path: str, http_protocol_cls: HTTPProtocol, caplo
assert f'"GET {path} HTTP/1.1" 200' in caplog.records[0].message
async def test_head_request(http_protocol_cls: HTTPProtocol):
async def test_head_request(http_protocol_cls: type[HTTPProtocol]):
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, http_protocol_cls)
@ -334,7 +338,7 @@ async def test_head_request(http_protocol_cls: HTTPProtocol):
assert b"Hello, world" not in protocol.transport.buffer
async def test_post_request(http_protocol_cls: HTTPProtocol):
async def test_post_request(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
body = b""
more_body = True
@ -353,7 +357,7 @@ async def test_post_request(http_protocol_cls: HTTPProtocol):
assert b'Body: {"hello": "world"}' in protocol.transport.buffer
async def test_keepalive(http_protocol_cls: HTTPProtocol):
async def test_keepalive(http_protocol_cls: type[HTTPProtocol]):
app = Response(b"", status_code=204)
protocol = get_connected_protocol(app, http_protocol_cls)
@ -364,7 +368,7 @@ async def test_keepalive(http_protocol_cls: HTTPProtocol):
assert not protocol.transport.is_closing()
async def test_keepalive_timeout(http_protocol_cls: HTTPProtocol):
async def test_keepalive_timeout(http_protocol_cls: type[HTTPProtocol]):
app = Response(b"", status_code=204)
protocol = get_connected_protocol(app, http_protocol_cls)
@ -378,9 +382,7 @@ async def test_keepalive_timeout(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_keepalive_timeout_with_pipelined_requests(
http_protocol_cls: HTTPProtocol,
):
async def test_keepalive_timeout_with_pipelined_requests(http_protocol_cls: type[HTTPProtocol]):
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, http_protocol_cls)
@ -403,7 +405,7 @@ async def test_keepalive_timeout_with_pipelined_requests(
assert protocol.timeout_keep_alive_task is not None
async def test_close(http_protocol_cls: HTTPProtocol):
async def test_close(http_protocol_cls: type[HTTPProtocol]):
app = Response(b"", status_code=204, headers={"connection": "close"})
protocol = get_connected_protocol(app, http_protocol_cls)
@ -413,7 +415,7 @@ async def test_close(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_chunked_encoding(http_protocol_cls: HTTPProtocol):
async def test_chunked_encoding(http_protocol_cls: type[HTTPProtocol]):
app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"})
protocol = get_connected_protocol(app, http_protocol_cls)
@ -424,7 +426,7 @@ async def test_chunked_encoding(http_protocol_cls: HTTPProtocol):
assert not protocol.transport.is_closing()
async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol):
async def test_chunked_encoding_empty_body(http_protocol_cls: type[HTTPProtocol]):
app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"})
protocol = get_connected_protocol(app, http_protocol_cls)
@ -435,9 +437,7 @@ async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol):
assert not protocol.transport.is_closing()
async def test_chunked_encoding_head_request(
http_protocol_cls: HTTPProtocol,
):
async def test_chunked_encoding_head_request(http_protocol_cls: type[HTTPProtocol]):
app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"})
protocol = get_connected_protocol(app, http_protocol_cls)
@ -447,7 +447,7 @@ async def test_chunked_encoding_head_request(
assert not protocol.transport.is_closing()
async def test_pipelined_requests(http_protocol_cls: HTTPProtocol):
async def test_pipelined_requests(http_protocol_cls: type[HTTPProtocol]):
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, http_protocol_cls)
@ -468,7 +468,7 @@ async def test_pipelined_requests(http_protocol_cls: HTTPProtocol):
protocol.transport.clear_buffer()
async def test_undersized_request(http_protocol_cls: HTTPProtocol):
async def test_undersized_request(http_protocol_cls: type[HTTPProtocol]):
app = Response(b"xxx", headers={"content-length": "10"})
protocol = get_connected_protocol(app, http_protocol_cls)
@ -477,7 +477,7 @@ async def test_undersized_request(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_oversized_request(http_protocol_cls: HTTPProtocol):
async def test_oversized_request(http_protocol_cls: type[HTTPProtocol]):
app = Response(b"xxx" * 20, headers={"content-length": "10"})
protocol = get_connected_protocol(app, http_protocol_cls)
@ -486,7 +486,7 @@ async def test_oversized_request(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_large_post_request(http_protocol_cls: HTTPProtocol):
async def test_large_post_request(http_protocol_cls: type[HTTPProtocol]):
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, http_protocol_cls)
@ -496,7 +496,7 @@ async def test_large_post_request(http_protocol_cls: HTTPProtocol):
assert not protocol.transport.read_paused
async def test_invalid_http(http_protocol_cls: HTTPProtocol):
async def test_invalid_http(http_protocol_cls: type[HTTPProtocol]):
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, http_protocol_cls)
@ -504,7 +504,7 @@ async def test_invalid_http(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_app_exception(http_protocol_cls: HTTPProtocol):
async def test_app_exception(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
raise Exception()
@ -515,7 +515,7 @@ async def test_app_exception(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_exception_during_response(http_protocol_cls: HTTPProtocol):
async def test_exception_during_response(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b"1", "more_body": True})
@ -528,7 +528,7 @@ async def test_exception_during_response(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_no_response_returned(http_protocol_cls: HTTPProtocol):
async def test_no_response_returned(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): ...
protocol = get_connected_protocol(app, http_protocol_cls)
@ -538,7 +538,7 @@ async def test_no_response_returned(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_partial_response_returned(http_protocol_cls: HTTPProtocol):
async def test_partial_response_returned(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
@ -549,7 +549,7 @@ async def test_partial_response_returned(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_response_header_splitting(http_protocol_cls: HTTPProtocol):
async def test_response_header_splitting(http_protocol_cls: type[HTTPProtocol]):
app = Response(b"", headers={"key": "value\r\nCookie: smuggled=value"})
protocol = get_connected_protocol(app, http_protocol_cls)
@ -560,7 +560,7 @@ async def test_response_header_splitting(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_duplicate_start_message(http_protocol_cls: HTTPProtocol):
async def test_duplicate_start_message(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.start", "status": 200})
@ -572,7 +572,7 @@ async def test_duplicate_start_message(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_missing_start_message(http_protocol_cls: HTTPProtocol):
async def test_missing_start_message(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.body", "body": b""})
@ -583,7 +583,7 @@ async def test_missing_start_message(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_message_after_body_complete(http_protocol_cls: HTTPProtocol):
async def test_message_after_body_complete(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b""})
@ -596,7 +596,7 @@ async def test_message_after_body_complete(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_value_returned(http_protocol_cls: HTTPProtocol):
async def test_value_returned(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b""})
@ -609,7 +609,7 @@ async def test_value_returned(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
async def test_early_disconnect(http_protocol_cls: HTTPProtocol):
async def test_early_disconnect(http_protocol_cls: type[HTTPProtocol]):
got_disconnect_event = False
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
@ -630,7 +630,7 @@ async def test_early_disconnect(http_protocol_cls: HTTPProtocol):
assert got_disconnect_event
async def test_early_response(http_protocol_cls: HTTPProtocol):
async def test_early_response(http_protocol_cls: type[HTTPProtocol]):
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, http_protocol_cls)
@ -641,7 +641,7 @@ async def test_early_response(http_protocol_cls: HTTPProtocol):
assert not protocol.transport.is_closing()
async def test_read_after_response(http_protocol_cls: HTTPProtocol):
async def test_read_after_response(http_protocol_cls: type[HTTPProtocol]):
message_after_response = None
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
@ -658,7 +658,7 @@ async def test_read_after_response(http_protocol_cls: HTTPProtocol):
assert message_after_response == {"type": "http.disconnect"}
async def test_http10_request(http_protocol_cls: HTTPProtocol):
async def test_http10_request(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"
content = "Version: %s" % scope["http_version"]
@ -672,7 +672,7 @@ async def test_http10_request(http_protocol_cls: HTTPProtocol):
assert b"Version: 1.0" in protocol.transport.buffer
async def test_root_path(http_protocol_cls: HTTPProtocol):
async def test_root_path(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"
root_path = scope.get("root_path", "")
@ -687,7 +687,7 @@ async def test_root_path(http_protocol_cls: HTTPProtocol):
assert b"root_path=/app path=/app/" in protocol.transport.buffer
async def test_raw_path(http_protocol_cls: HTTPProtocol):
async def test_raw_path(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"
path = scope["path"]
@ -704,7 +704,7 @@ async def test_raw_path(http_protocol_cls: HTTPProtocol):
assert b"Done" in protocol.transport.buffer
async def test_max_concurrency(http_protocol_cls: HTTPProtocol):
async def test_max_concurrency(http_protocol_cls: type[HTTPProtocol]):
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, http_protocol_cls, limit_concurrency=1)
@ -725,27 +725,27 @@ async def test_max_concurrency(http_protocol_cls: HTTPProtocol):
)
async def test_shutdown_during_request(http_protocol_cls: HTTPProtocol):
async def test_shutdown_during_request(http_protocol_cls: type[HTTPProtocol]):
app = Response(b"", status_code=204)
protocol = get_connected_protocol(app, http_protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.shutdown()
protocol.shutdown() # type: ignore[attr-defined]
await protocol.loop.run_one()
assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer
assert protocol.transport.is_closing()
async def test_shutdown_during_idle(http_protocol_cls: HTTPProtocol):
async def test_shutdown_during_idle(http_protocol_cls: type[HTTPProtocol]):
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, http_protocol_cls)
protocol.shutdown()
protocol.shutdown() # type: ignore[attr-defined]
assert protocol.transport.buffer == b""
assert protocol.transport.is_closing()
async def test_100_continue_sent_when_body_consumed(http_protocol_cls: HTTPProtocol):
async def test_100_continue_sent_when_body_consumed(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
body = b""
more_body = True
@ -777,7 +777,7 @@ async def test_100_continue_sent_when_body_consumed(http_protocol_cls: HTTPProto
async def test_100_continue_not_sent_when_body_not_consumed(
http_protocol_cls: HTTPProtocol,
http_protocol_cls: type[HTTPProtocol],
):
app = Response(b"", status_code=204)
@ -799,7 +799,7 @@ async def test_100_continue_not_sent_when_body_not_consumed(
assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer
async def test_supported_upgrade_request(http_protocol_cls: HTTPProtocol):
async def test_supported_upgrade_request(http_protocol_cls: type[HTTPProtocol]):
pytest.importorskip("wsproto")
app = Response("Hello, world", media_type="text/plain")
@ -809,7 +809,7 @@ async def test_supported_upgrade_request(http_protocol_cls: HTTPProtocol):
assert b"HTTP/1.1 426 " in protocol.transport.buffer
async def test_unsupported_ws_upgrade_request(http_protocol_cls: HTTPProtocol):
async def test_unsupported_ws_upgrade_request(http_protocol_cls: type[HTTPProtocol]):
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, http_protocol_cls, ws="none")
@ -820,7 +820,7 @@ async def test_unsupported_ws_upgrade_request(http_protocol_cls: HTTPProtocol):
async def test_unsupported_ws_upgrade_request_warn_on_auto(
caplog: pytest.LogCaptureFixture, http_protocol_cls: HTTPProtocol
caplog: pytest.LogCaptureFixture, http_protocol_cls: type[HTTPProtocol]
):
app = Response("Hello, world", media_type="text/plain")
@ -836,7 +836,7 @@ async def test_unsupported_ws_upgrade_request_warn_on_auto(
assert msg in warnings
async def test_http2_upgrade_request(http_protocol_cls: HTTPProtocol, ws_protocol_cls: WSProtocol):
async def test_http2_upgrade_request(http_protocol_cls: type[HTTPProtocol], ws_protocol_cls: WSProtocol):
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, http_protocol_cls, ws=ws_protocol_cls)
@ -867,7 +867,7 @@ def asgi2app(scope: Scope):
async def test_scopes(
asgi2or3_app: ASGIApplication,
expected_scopes: dict[str, str],
http_protocol_cls: HTTPProtocol,
http_protocol_cls: type[HTTPProtocol],
):
protocol = get_connected_protocol(asgi2or3_app, http_protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
@ -884,7 +884,7 @@ async def test_scopes(
],
)
async def test_invalid_http_request(
request_line: str, http_protocol_cls: HTTPProtocol, caplog: pytest.LogCaptureFixture
request_line: str, http_protocol_cls: type[HTTPProtocol], caplog: pytest.LogCaptureFixture
):
app = Response("Hello, world", media_type="text/plain")
request = INVALID_REQUEST_TEMPLATE % request_line
@ -1007,7 +1007,7 @@ async def test_huge_headers_h11_max_incomplete():
assert b"Hello, world" in protocol.transport.buffer
async def test_return_close_header(http_protocol_cls: HTTPProtocol):
async def test_return_close_header(http_protocol_cls: type[HTTPProtocol]):
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, http_protocol_cls)
@ -1021,7 +1021,7 @@ async def test_return_close_header(http_protocol_cls: HTTPProtocol):
assert b"connection: close" in protocol.transport.buffer.lower()
async def test_close_connection_with_multiple_requests(http_protocol_cls: HTTPProtocol):
async def test_close_connection_with_multiple_requests(http_protocol_cls: type[HTTPProtocol]):
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, http_protocol_cls)
@ -1035,7 +1035,7 @@ async def test_close_connection_with_multiple_requests(http_protocol_cls: HTTPPr
assert b"connection: close" in protocol.transport.buffer.lower()
async def test_close_connection_with_post_request(http_protocol_cls: HTTPProtocol):
async def test_close_connection_with_post_request(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
body = b""
more_body = True
@ -1054,7 +1054,7 @@ async def test_close_connection_with_post_request(http_protocol_cls: HTTPProtoco
assert b"Body: {'hello': 'world'}" in protocol.transport.buffer
async def test_iterator_headers(http_protocol_cls: HTTPProtocol):
async def test_iterator_headers(http_protocol_cls: type[HTTPProtocol]):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
headers = iter([(b"x-test-header", b"test value")])
await send({"type": "http.response.start", "status": 200, "headers": headers})
@ -1066,7 +1066,7 @@ async def test_iterator_headers(http_protocol_cls: HTTPProtocol):
assert b"x-test-header: test value" in protocol.transport.buffer
async def test_lifespan_state(http_protocol_cls: HTTPProtocol):
async def test_lifespan_state(http_protocol_cls: type[HTTPProtocol]):
expected_states = [{"a": 123, "b": [1]}, {"a": 123, "b": [1, 2]}]
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
@ -1095,7 +1095,7 @@ async def test_lifespan_state(http_protocol_cls: HTTPProtocol):
async def test_header_upgrade_is_not_websocket_depend_installed(
caplog: pytest.LogCaptureFixture, http_protocol_cls: HTTPProtocol
caplog: pytest.LogCaptureFixture, http_protocol_cls: type[HTTPProtocol]
):
caplog.set_level(logging.WARNING, logger="uvicorn.error")
app = Response("Hello, world", media_type="text/plain")
@ -1111,7 +1111,7 @@ async def test_header_upgrade_is_not_websocket_depend_installed(
async def test_header_upgrade_is_websocket_depend_not_installed(
caplog: pytest.LogCaptureFixture, http_protocol_cls: HTTPProtocol
caplog: pytest.LogCaptureFixture, http_protocol_cls: type[HTTPProtocol]
):
caplog.set_level(logging.WARNING, logger="uvicorn.error")
app = Response("Hello, world", media_type="text/plain")

View File

@ -270,12 +270,5 @@ class ASGI2Protocol(Protocol):
ASGI2Application = type[ASGI2Protocol]
ASGI3Application = Callable[
[
Scope,
ASGIReceiveCallable,
ASGISendCallable,
],
Awaitable[None],
]
ASGI3Application = Callable[[Scope, ASGIReceiveCallable, ASGISendCallable], Awaitable[None]]
ASGIApplication = Union[ASGI2Application, ASGI3Application]

View File

@ -96,13 +96,7 @@ class AccessFormatter(ColourizedFormatter):
def formatMessage(self, record: logging.LogRecord) -> str:
recordcopy = copy(record)
(
client_addr,
method,
full_path,
http_version,
status_code,
) = recordcopy.args # type: ignore[misc]
(client_addr, method, full_path, http_version, status_code) = recordcopy.args # type: ignore[misc]
status_code = self.get_status_code(int(status_code)) # type: ignore[arg-type]
request_line = f"{method} {full_path} HTTP/{http_version}"
if self.use_colors:

View File

@ -0,0 +1,82 @@
from __future__ import annotations as _annotations
import asyncio
import logging
from typing import Any
from uvicorn._types import HTTPScope
from uvicorn.config import Config
from uvicorn.protocols.http.flow_control import FlowControl
from uvicorn.server import ServerState
class HTTPProtocol(asyncio.Protocol):
__slots__ = (
"config",
"app",
"loop",
"logger",
"access_logger",
"access_log",
"ws_protocol_class",
"root_path",
"limit_concurrency",
"app_state",
# Timeouts
"timeout_keep_alive_task",
"timeout_keep_alive",
# Global state
"server_state",
"connections",
"tasks",
# Per-connection state
"transport",
"flow",
"server",
"client",
# Per-request state
"scope",
"headers",
)
def __init__(
self,
config: Config,
server_state: ServerState,
app_state: dict[str, Any],
_loop: asyncio.AbstractEventLoop | None = None,
) -> None:
if not config.loaded:
config.load()
self.config = config
self.app = config.loaded_app
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.access_logger = logging.getLogger("uvicorn.access")
self.access_log = self.access_logger.hasHandlers()
self.ws_protocol_class = config.ws_protocol_class
self.root_path = config.root_path
self.limit_concurrency = config.limit_concurrency
self.app_state = app_state
# Timeouts
self.timeout_keep_alive_task: asyncio.TimerHandle | None = None
self.timeout_keep_alive = config.timeout_keep_alive
# Global state
self.server_state = server_state
self.connections = server_state.connections
self.tasks = server_state.tasks
# Per-connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.flow: FlowControl = None # type: ignore[assignment]
self.server: tuple[str, int] | None = None
self.client: tuple[str, int] | None = None
# Per-request state
self.scope: HTTPScope = None # type: ignore[assignment]
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
import http
import logging
from typing import Any, Callable, Literal, cast
from typing import Any, Callable, cast
from urllib.parse import unquote
import h11
@ -20,6 +20,7 @@ from uvicorn._types import (
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.http.base import HTTPProtocol
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
from uvicorn.server import ServerState
@ -35,7 +36,7 @@ def _get_status_phrase(status_code: int) -> bytes:
STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)}
class H11Protocol(asyncio.Protocol):
class H11Protocol(HTTPProtocol):
def __init__(
self,
config: Config,
@ -43,55 +44,24 @@ class H11Protocol(asyncio.Protocol):
app_state: dict[str, Any],
_loop: asyncio.AbstractEventLoop | None = None,
) -> None:
if not config.loaded:
config.load()
super().__init__(config, server_state, app_state, _loop)
self.config = config
self.app = config.loaded_app
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.access_logger = logging.getLogger("uvicorn.access")
self.access_log = self.access_logger.hasHandlers()
self.conn = h11.Connection(
h11.SERVER,
config.h11_max_incomplete_event_size
if config.h11_max_incomplete_event_size is not None
else DEFAULT_MAX_INCOMPLETE_EVENT_SIZE,
)
self.ws_protocol_class = config.ws_protocol_class
self.root_path = config.root_path
self.limit_concurrency = config.limit_concurrency
self.app_state = app_state
# Timeouts
self.timeout_keep_alive_task: asyncio.TimerHandle | None = None
self.timeout_keep_alive = config.timeout_keep_alive
# Shared server state
self.server_state = server_state
self.connections = server_state.connections
self.tasks = server_state.tasks
# Per-connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.flow: FlowControl = None # type: ignore[assignment]
self.server: tuple[str, int] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["http", "https"] | None = None
# Per-request state
self.scope: HTTPScope = None # type: ignore[assignment]
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
# Protocol interface
def connection_made( # type: ignore[override]
self, transport: asyncio.Transport
) -> None:
def connection_made(self, transport: asyncio.BaseTransport) -> None:
self.connections.add(self)
self.transport = transport
self.flow = FlowControl(transport)
self.transport = cast(asyncio.Transport, transport)
self.flow = FlowControl(self.transport)
self.server = get_local_addr(transport)
self.client = get_remote_addr(transport)
self.scheme = "https" if is_ssl(transport) else "http"
@ -204,7 +174,7 @@ class H11Protocol(asyncio.Protocol):
"http_version": event.http_version.decode("ascii"),
"server": self.server,
"client": self.client,
"scheme": self.scheme, # type: ignore[typeddict-item]
"scheme": self.scheme,
"method": event.method.decode("ascii"),
"root_path": self.root_path,
"path": full_path,
@ -534,10 +504,6 @@ class RequestResponseCycle:
if self.disconnected or self.response_complete:
return {"type": "http.disconnect"}
message: HTTPRequestEvent = {
"type": "http.request",
"body": self.body,
"more_body": self.more_body,
}
message: HTTPRequestEvent = {"type": "http.request", "body": self.body, "more_body": self.more_body}
self.body = b""
return message

View File

@ -5,9 +5,8 @@ import http
import logging
import re
import urllib
from asyncio.events import TimerHandle
from collections import deque
from typing import Any, Callable, Literal, cast
from typing import Any, Callable, cast
import httptools
@ -21,6 +20,7 @@ from uvicorn._types import (
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.http.base import HTTPProtocol
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
from uvicorn.server import ServerState
@ -40,7 +40,7 @@ def _get_status_line(status_code: int) -> bytes:
STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)}
class HttpToolsProtocol(asyncio.Protocol):
class HttpToolsProtocol(HTTPProtocol):
def __init__(
self,
config: Config,
@ -48,15 +48,7 @@ class HttpToolsProtocol(asyncio.Protocol):
app_state: dict[str, Any],
_loop: asyncio.AbstractEventLoop | None = None,
) -> None:
if not config.loaded:
config.load()
self.config = config
self.app = config.loaded_app
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.access_logger = logging.getLogger("uvicorn.access")
self.access_log = self.access_logger.hasHandlers()
super().__init__(config, server_state, app_state, _loop)
self.parser = httptools.HttpRequestParser(self)
try:
@ -66,42 +58,19 @@ class HttpToolsProtocol(asyncio.Protocol):
# httptools < 0.6.3
pass
self.ws_protocol_class = config.ws_protocol_class
self.root_path = config.root_path
self.limit_concurrency = config.limit_concurrency
self.app_state = app_state
# Timeouts
self.timeout_keep_alive_task: TimerHandle | None = None
self.timeout_keep_alive = config.timeout_keep_alive
# Global state
self.server_state = server_state
self.connections = server_state.connections
self.tasks = server_state.tasks
# Per-connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.flow: FlowControl = None # type: ignore[assignment]
self.server: tuple[str, int] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["http", "https"] | None = None
self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
# Per-request state
self.scope: HTTPScope = None # type: ignore[assignment]
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
self.expect_100_continue = False
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
# Protocol interface
def connection_made( # type: ignore[override]
self, transport: asyncio.Transport
) -> None:
def connection_made(self, transport: asyncio.BaseTransport) -> None:
self.connections.add(self)
self.transport = transport
self.flow = FlowControl(transport)
self.transport = cast(asyncio.Transport, transport)
self.flow = FlowControl(self.transport)
self.server = get_local_addr(transport)
self.client = get_remote_addr(transport)
self.scheme = "https" if is_ssl(transport) else "http"
@ -226,7 +195,7 @@ class HttpToolsProtocol(asyncio.Protocol):
"http_version": "1.1",
"server": self.server,
"client": self.client,
"scheme": self.scheme, # type: ignore[typeddict-item]
"scheme": self.scheme,
"root_path": self.root_path,
"headers": self.headers,
"state": self.app_state.copy(),

View File

@ -9,7 +9,7 @@ from uvicorn._types import WWWScope
class ClientDisconnected(OSError): ...
def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
def get_remote_addr(transport: asyncio.BaseTransport) -> tuple[str, int] | None:
socket_info = transport.get_extra_info("socket")
if socket_info is not None:
try:
@ -26,7 +26,7 @@ def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
return None
def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
def get_local_addr(transport: asyncio.BaseTransport) -> tuple[str, int] | None:
socket_info = transport.get_extra_info("socket")
if socket_info is not None:
info = socket_info.getsockname()
@ -38,7 +38,7 @@ def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
return None
def is_ssl(transport: asyncio.Transport) -> bool:
def is_ssl(transport: asyncio.BaseTransport) -> bool:
return bool(transport.get_extra_info("sslcontext"))