Compare commits

...

2 Commits

Author SHA1 Message Date
florimondmanca
9449195a27
Use httparse==0.1.2 2022-11-13 22:36:49 +01:00
florimondmanca
92163a2456
Experimental: integrate httparse 2022-11-13 13:11:12 +01:00
9 changed files with 721 additions and 12 deletions

View File

@ -57,7 +57,8 @@ Options:
$WEB_CONCURRENCY environment variable if
available, or 1. Not valid with --reload.
--loop [auto|asyncio|uvloop] Event loop implementation. [default: auto]
--http [auto|h11|httptools] HTTP protocol implementation. [default:
--http [auto|h11|httptools|httparse]
HTTP protocol implementation. [default:
auto]
--ws [auto|none|websockets|wsproto]
WebSocket protocol implementation.

View File

@ -124,7 +124,8 @@ Options:
$WEB_CONCURRENCY environment variable if
available, or 1. Not valid with --reload.
--loop [auto|asyncio|uvloop] Event loop implementation. [default: auto]
--http [auto|h11|httptools] HTTP protocol implementation. [default:
--http [auto|h11|httptools|httparse]
HTTP protocol implementation. [default:
auto]
--ws [auto|none|websockets|wsproto]
WebSocket protocol implementation.

View File

@ -1,5 +1,8 @@
-e .[standard]
# Experimental
httparse==0.1.2
# Type annotation
asgiref==3.5.2

View File

@ -16,8 +16,17 @@ try:
except ImportError: # pragma: nocover
HttpToolsProtocol = None
try:
from uvicorn.protocols.http.httparse_impl import HttparseProtocol
except ImportError: # pragma: nocover
HttparseProtocol = None
HTTP_PROTOCOLS = [p for p in [H11Protocol, HttpToolsProtocol] if p is not None]
HTTP_PROTOCOLS = [
pytest.param(p, id=p.__name__.replace("Protocol", "").lower())
for p in [H11Protocol, HttpToolsProtocol, HttparseProtocol]
if p is not None
]
WEBSOCKET_PROTOCOLS = WS_PROTOCOLS.keys()
SIMPLE_GET_REQUEST = b"\r\n".join([b"GET / HTTP/1.1", b"Host: example.org", b"", b""])
@ -197,6 +206,7 @@ async def test_get_request(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer
@ -216,6 +226,7 @@ async def test_request_logging(path, protocol_cls, caplog):
protocol = get_connected_protocol(app, protocol_cls, log_config=None)
protocol.data_received(get_request_with_query_string)
protocol.eof_received()
await protocol.loop.run_one()
assert '"GET {} HTTP/1.1" 200'.format(path) in caplog.records[0].message
@ -227,6 +238,7 @@ async def test_head_request(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_HEAD_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" not in protocol.transport.buffer
@ -247,6 +259,7 @@ async def test_post_request(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_POST_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b'Body: {"hello": "world"}' in protocol.transport.buffer
@ -259,6 +272,7 @@ async def test_keepalive(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer
@ -272,6 +286,7 @@ async def test_keepalive_timeout(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer
assert not protocol.transport.is_closing()
@ -288,6 +303,7 @@ async def test_close(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer
assert protocol.transport.is_closing()
@ -302,6 +318,7 @@ async def test_chunked_encoding(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"0\r\n\r\n" in protocol.transport.buffer
@ -317,6 +334,7 @@ async def test_chunked_encoding_empty_body(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert protocol.transport.buffer.count(b"0\r\n\r\n") == 1
@ -332,6 +350,7 @@ async def test_chunked_encoding_head_request(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_HEAD_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert not protocol.transport.is_closing()
@ -344,8 +363,11 @@ async def test_pipelined_requests(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer
@ -367,6 +389,7 @@ async def test_undersized_request(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert protocol.transport.is_closing()
@ -378,6 +401,7 @@ async def test_oversized_request(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert protocol.transport.is_closing()
@ -389,6 +413,7 @@ async def test_large_post_request(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(LARGE_POST_REQUEST)
protocol.eof_received()
assert protocol.transport.read_paused
await protocol.loop.run_one()
assert not protocol.transport.read_paused
@ -401,6 +426,7 @@ async def test_invalid_http(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(b"x" * 100000)
protocol.eof_received()
assert protocol.transport.is_closing()
@ -412,6 +438,7 @@ async def test_app_exception(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer
assert protocol.transport.is_closing()
@ -427,6 +454,7 @@ async def test_exception_during_response(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer
assert protocol.transport.is_closing()
@ -440,6 +468,7 @@ async def test_no_response_returned(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer
assert protocol.transport.is_closing()
@ -453,6 +482,7 @@ async def test_partial_response_returned(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer
assert protocol.transport.is_closing()
@ -467,6 +497,7 @@ async def test_duplicate_start_message(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer
assert protocol.transport.is_closing()
@ -480,6 +511,7 @@ async def test_missing_start_message(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer
assert protocol.transport.is_closing()
@ -495,6 +527,7 @@ async def test_message_after_body_complete(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert protocol.transport.is_closing()
@ -510,6 +543,7 @@ async def test_value_returned(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert protocol.transport.is_closing()
@ -548,6 +582,7 @@ async def test_early_response(protocol_cls):
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
protocol.data_received(FINISH_POST_REQUEST)
protocol.eof_received()
assert not protocol.transport.is_closing()
@ -565,6 +600,7 @@ async def test_read_after_response(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(SIMPLE_POST_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert message_after_response == {"type": "http.disconnect"}
@ -580,6 +616,7 @@ async def test_http10_request(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(HTTP10_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Version: 1.0" in protocol.transport.buffer
@ -595,6 +632,7 @@ async def test_root_path(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls, root_path="/app")
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Path: /app/" in protocol.transport.buffer
@ -614,6 +652,7 @@ async def test_raw_path(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls, root_path="/app")
protocol.data_received(GET_REQUEST_WITH_RAW_PATH)
protocol.eof_received()
await protocol.loop.run_one()
assert b"Done" in protocol.transport.buffer
@ -625,6 +664,7 @@ async def test_max_concurrency(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls, limit_concurrency=1)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 503 Service Unavailable" in protocol.transport.buffer
@ -679,6 +719,7 @@ async def test_100_continue_sent_when_body_consumed(protocol_cls):
]
)
protocol.data_received(EXPECT_100_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 100 Continue" in protocol.transport.buffer
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
@ -703,6 +744,7 @@ async def test_100_continue_not_sent_when_body_not_consumed(protocol_cls):
]
)
protocol.data_received(EXPECT_100_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 100 Continue" not in protocol.transport.buffer
assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer
@ -715,6 +757,7 @@ async def test_supported_upgrade_request(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls, ws="wsproto")
protocol.data_received(UPGRADE_REQUEST)
protocol.eof_received()
assert b"HTTP/1.1 426 " in protocol.transport.buffer
@ -725,6 +768,7 @@ async def test_unsupported_ws_upgrade_request(protocol_cls):
protocol = get_connected_protocol(app, protocol_cls, ws="none")
protocol.data_received(UPGRADE_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer
@ -740,6 +784,7 @@ async def test_unsupported_ws_upgrade_request_warn_on_auto(
protocol = get_connected_protocol(app, protocol_cls, ws="auto")
protocol.ws_protocol_class = None
protocol.data_received(UPGRADE_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer
@ -762,6 +807,7 @@ async def test_http2_upgrade_request(protocol_cls, ws):
protocol = get_connected_protocol(app, protocol_cls, ws=ws)
protocol.data_received(UPGRADE_HTTP2_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer
@ -790,6 +836,7 @@ asgi_scope_data = [
async def test_scopes(asgi2or3_app, expected_scopes, protocol_cls):
protocol = get_connected_protocol(asgi2or3_app, protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
await protocol.loop.run_one()
assert expected_scopes == protocol.scope.get("asgi")
@ -813,6 +860,7 @@ async def test_invalid_http_request(request_line, protocol_cls, caplog):
protocol = get_connected_protocol(app, protocol_cls)
protocol.data_received(request)
protocol.eof_received()
assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer
assert b"Invalid HTTP request received." in protocol.transport.buffer
@ -874,6 +922,7 @@ async def test_huge_headers_h11protocol_failure():
# Huge headers make h11 fail in it's default config
# h11 sends back a 400 in this case
protocol.data_received(GET_REQUEST_HUGE_HEADERS[0])
protocol.eof_received()
assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer
assert b"Connection: close" in protocol.transport.buffer
assert b"Invalid HTTP request received." in protocol.transport.buffer
@ -889,6 +938,7 @@ async def test_huge_headers_httptools_will_pass():
# httptools protocol will always pass
protocol.data_received(GET_REQUEST_HUGE_HEADERS[0])
protocol.data_received(GET_REQUEST_HUGE_HEADERS[1])
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer
@ -904,6 +954,7 @@ async def test_huge_headers_h11protocol_failure_with_setting():
# Huge headers make h11 fail in it's default config
# h11 sends back a 400 in this case
protocol.data_received(GET_REQUEST_HUGE_HEADERS[0])
protocol.eof_received()
assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer
assert b"Connection: close" in protocol.transport.buffer
assert b"Invalid HTTP request received." in protocol.transport.buffer
@ -919,6 +970,21 @@ async def test_huge_headers_httptools():
# httptools protocol will always pass
protocol.data_received(GET_REQUEST_HUGE_HEADERS[0])
protocol.data_received(GET_REQUEST_HUGE_HEADERS[1])
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer
@pytest.mark.anyio
@pytest.mark.skipif(HttparseProtocol is None, reason="httparse is not installed")
async def test_huge_headers_httparse():
app = Response("Hello, world", media_type="text/plain")
protocol = get_connected_protocol(app, HttparseProtocol)
protocol.data_received(GET_REQUEST_HUGE_HEADERS[0])
protocol.data_received(GET_REQUEST_HUGE_HEADERS[1])
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer
@ -933,6 +999,7 @@ async def test_huge_headers_h11_max_incomplete():
)
protocol.data_received(GET_REQUEST_HUGE_HEADERS[0])
protocol.data_received(GET_REQUEST_HUGE_HEADERS[1])
protocol.eof_received()
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hello, world" in protocol.transport.buffer

View File

@ -13,6 +13,11 @@ try:
except ImportError: # pragma: no cover
uvloop = None
try:
import httparse
except ImportError: # pragma: no cover
httparse = None
try:
import httptools
except ImportError: # pragma: no cover
@ -46,7 +51,11 @@ async def test_http_auto():
config = Config(app=app)
server_state = ServerState()
protocol = AutoHTTPProtocol(config=config, server_state=server_state)
expected_http = "H11Protocol" if httptools is None else "HttpToolsProtocol"
expected_http = (
"H11Protocol"
if httptools is None
else ("HttpToolsProtocol" if httparse is None else "HttparseProtocol")
)
assert type(protocol).__name__ == expected_http

View File

@ -49,7 +49,7 @@ from uvicorn.middleware.wsgi import WSGIMiddleware
if TYPE_CHECKING:
from asgiref.typing import ASGIApplication
HTTPProtocolType = Literal["auto", "h11", "httptools"]
HTTPProtocolType = Literal["auto", "h11", "httptools", "httparse"]
WSProtocolType = Literal["auto", "none", "websockets", "wsproto"]
LifespanType = Literal["auto", "on", "off"]
LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"]
@ -67,6 +67,7 @@ HTTP_PROTOCOLS: Dict[HTTPProtocolType, str] = {
"auto": "uvicorn.protocols.http.auto:AutoHTTPProtocol",
"h11": "uvicorn.protocols.http.h11_impl:H11Protocol",
"httptools": "uvicorn.protocols.http.httptools_impl:HttpToolsProtocol",
"httparse": "uvicorn.protocols.http.httparse_impl:HttparseProtocol",
}
WS_PROTOCOLS: Dict[WSProtocolType, Optional[str]] = {
"auto": "uvicorn.protocols.websockets.auto:AutoWebSocketsProtocol",

View File

@ -2,13 +2,21 @@ import asyncio
from typing import Type
AutoHTTPProtocol: Type[asyncio.Protocol]
try:
import httptools # noqa
import httparse # noqa
except ImportError: # pragma: no cover
from uvicorn.protocols.http.h11_impl import H11Protocol
try:
import httptools # noqa
except ImportError: # pragma: no cover
from uvicorn.protocols.http.h11_impl import H11Protocol
AutoHTTPProtocol = H11Protocol
else: # pragma: no cover
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
AutoHTTPProtocol = H11Protocol
else: # pragma: no cover
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
AutoHTTPProtocol = HttpToolsProtocol
AutoHTTPProtocol = HttpToolsProtocol
else:
from uvicorn.protocols.http.httparse_impl import HttparseProtocol
AutoHTTPProtocol = HttparseProtocol

View File

@ -0,0 +1,616 @@
import asyncio
import http
import logging
import re
import sys
import urllib
from asyncio.events import TimerHandle
from collections import deque
from typing import TYPE_CHECKING, Callable, Deque, List, Optional, Tuple, Union, cast
import httparse
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.http.flow_control import (
CLOSE_HEADER,
HIGH_WATER_LIMIT,
FlowControl,
service_unavailable,
)
from uvicorn.protocols.utils import (
get_client_addr,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.server import ServerState
if sys.version_info < (3, 8): # pragma: py-gte-38
from typing_extensions import Literal
else: # pragma: py-lt-38
from typing import Literal
if TYPE_CHECKING:
from asgiref.typing import (
ASGI3Application,
ASGIReceiveEvent,
ASGISendEvent,
HTTPDisconnectEvent,
HTTPRequestEvent,
HTTPResponseBodyEvent,
HTTPResponseStartEvent,
HTTPScope,
)
HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]')
HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]")
def _get_status_line(status_code: int) -> bytes:
try:
phrase = http.HTTPStatus(status_code).phrase.encode()
except ValueError:
phrase = b""
return b"".join([b"HTTP/1.1 ", str(status_code).encode(), b" ", phrase, b"\r\n"])
STATUS_LINE = {
status_code: _get_status_line(status_code) for status_code in range(100, 600)
}
class HttparseProtocol(asyncio.Protocol):
def __init__(
self,
config: Config,
server_state: ServerState,
_loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
if not config.loaded:
config.load()
self.config = config
self.app = config.loaded_app
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.access_logger = logging.getLogger("uvicorn.access")
self.access_log = self.access_logger.hasHandlers()
self.parser = httparse.RequestParser()
self.ws_protocol_class = config.ws_protocol_class
self.root_path = config.root_path
self.limit_concurrency = config.limit_concurrency
# Timeouts
self.timeout_keep_alive_task: Optional[TimerHandle] = None
self.timeout_keep_alive = config.timeout_keep_alive
# Global state
self.server_state = server_state
self.connections = server_state.connections
self.tasks = server_state.tasks
# Per-connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.flow: FlowControl = None # type: ignore[assignment]
self.server: Optional[Tuple[str, int]] = None
self.client: Optional[Tuple[str, int]] = None
self.scheme: Optional[Literal["http", "https"]] = None
self.pipeline: Deque[Tuple[RequestResponseCycle, ASGI3Application]] = deque()
# Per-request state
self._buffer = b""
self._parsed: Optional[httparse.ParsedRequest] = None
self.scope: HTTPScope = None # type: ignore[assignment]
self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment]
self.expect_100_continue = False
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
# Protocol interface
def connection_made( # type: ignore[override]
self, transport: asyncio.Transport
) -> None:
self.connections.add(self)
self.transport = transport
self.flow = FlowControl(transport)
self.server = get_local_addr(transport)
self.client = get_remote_addr(transport)
self.scheme = "https" if is_ssl(transport) else "http"
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)
def connection_lost(self, exc: Optional[Exception]) -> None:
self.connections.discard(self)
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)
if self.cycle and not self.cycle.response_complete:
self.cycle.disconnected = True
if self.cycle is not None:
self.cycle.message_event.set()
if self.flow is not None:
self.flow.resume_writing()
if exc is None:
self.transport.close()
self._unset_keepalive_if_required()
self.parser = httparse.RequestParser()
def _unset_keepalive_if_required(self) -> None:
if self.timeout_keep_alive_task is not None:
self.timeout_keep_alive_task.cancel()
self.timeout_keep_alive_task = None
def _get_upgrade(self) -> Optional[bytes]:
connection = []
upgrade = None
for name, value in self.headers:
if name == b"connection":
connection = [token.lower().strip() for token in value.split(b",")]
if name == b"upgrade":
upgrade = value.lower()
if b"upgrade" in connection:
return upgrade
return None
def _should_upgrade_to_ws(self) -> bool:
if self.ws_protocol_class is None:
if self.config.ws == "auto":
msg = "Unsupported upgrade request."
self.logger.warning(msg)
msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501
self.logger.warning(msg)
return False
return True
def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()
self._buffer += data
if self._parsed is None:
self._attempt_to_parse_request()
else:
self._consume_body()
def eof_received(self) -> None:
if self._parsed is None:
msg = "Invalid HTTP request received."
self.logger.warning(msg)
self.send_400_response(msg)
return
if self.cycle is not None:
self.cycle.more_body = False
self.cycle.message_event.set()
self._parsed = None
def _attempt_to_parse_request(self) -> None:
try:
parsed = self.parser.parse(self._buffer)
except httparse.ParsingError:
return
if parsed is None:
return
# Rust's httparse does not validate HTTP methods, but our tests
# expect some basic checks like for h11 and httptools.
if not parsed.method.isalpha():
return
self._parsed = parsed
self.expect_100_continue = False
self.headers = []
for h in parsed.headers:
name = h.name.lower().encode("utf-8")
value = h.value
if name == b"expect" and value.lower() == b"100-continue":
self.expect_100_continue = True
self.headers.append((name, value))
http_version = {
0: "1.0",
1: "1.1",
}.get(parsed.version, str(parsed.version))
method = parsed.method
raw_path = parsed.path
path, _, query = raw_path.partition("?")
if "%" in path:
path = urllib.parse.unquote(path)
self.scope = {
"type": "http",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
"http_version": http_version,
"server": self.server,
"client": self.client,
"scheme": str(self.scheme),
"root_path": self.root_path,
"headers": self.headers,
"method": method,
"path": path,
"raw_path": raw_path.encode("ascii"),
"query_string": query.encode("ascii"),
"extensions": {},
}
upgrade = self._get_upgrade()
if upgrade == b"websocket" and self._should_upgrade_to_ws():
self._handle_websocket_upgrade()
return
# Handle 503 responses when 'limit_concurrency' is exceeded.
if self.limit_concurrency is not None and (
len(self.connections) >= self.limit_concurrency
or len(self.tasks) >= self.limit_concurrency
):
app = service_unavailable
message = "Exceeded concurrency limit."
self.logger.warning(message)
else:
app = self.app
existing_cycle = self.cycle
self.cycle = RequestResponseCycle(
scope=self.scope,
transport=self.transport,
flow=self.flow,
logger=self.logger,
access_logger=self.access_logger,
access_log=self.access_log,
default_headers=self.server_state.default_headers,
message_event=asyncio.Event(),
expect_100_continue=self.expect_100_continue,
keep_alive=http_version != "1.0",
on_response=self._on_response_complete,
)
if existing_cycle is None or existing_cycle.response_complete:
# Standard case - start processing the request.
task = self.loop.create_task(self.cycle.run_asgi(app))
task.add_done_callback(self.tasks.discard)
self.tasks.add(task)
else:
# Pipelined HTTP requests need to be queued up.
self.flow.pause_reading()
self.pipeline.appendleft((self.cycle, app))
self._buffer = self._buffer[parsed.body_start_offset :]
if self._buffer:
self._consume_body()
def _handle_websocket_upgrade(self) -> None:
assert self._parsed is not None
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
self.connections.discard(self)
method = self.scope["method"].encode("ascii")
target = self._parsed.path.encode("ascii")
output = [method, b" ", target, b" HTTP/1.1\r\n"]
for name, value in self.scope["headers"]:
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
config=self.config, server_state=self.server_state
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
self.transport.set_protocol(protocol)
def send_400_response(self, msg: str) -> None:
content = [STATUS_LINE[400]]
for name, value in self.server_state.default_headers:
content.extend([name, b": ", value, b"\r\n"])
content.extend(
[
b"content-type: text/plain; charset=utf-8\r\n",
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
b"connection: close\r\n",
b"\r\n",
msg.encode("ascii"),
]
)
self.transport.write(b"".join(content))
self.transport.close()
def _consume_body(self) -> None:
if self.cycle.response_complete:
return
body = self._buffer
assert body
self.cycle.body += body
self._buffer = b""
if len(self.cycle.body) > HIGH_WATER_LIMIT:
self.flow.pause_reading()
self.cycle.message_event.set()
def _on_response_complete(self) -> None:
# Callback for pipelined HTTP requests to be started.
self.server_state.total_requests += 1
if self.transport.is_closing():
return
# Set a short Keep-Alive timeout.
self._unset_keepalive_if_required()
self.timeout_keep_alive_task = self.loop.call_later(
self.timeout_keep_alive, self.timeout_keep_alive_handler
)
# Unpause data reads if needed.
self.flow.resume_reading()
# Unblock any pipelined events.
if self.pipeline:
cycle, app = self.pipeline.pop()
task = self.loop.create_task(cycle.run_asgi(app))
task.add_done_callback(self.tasks.discard)
self.tasks.add(task)
def shutdown(self) -> None:
"""
Called by the server to commence a graceful shutdown.
"""
if self.cycle is None or self.cycle.response_complete:
self.transport.close()
else:
self.cycle.keep_alive = False
def pause_writing(self) -> None:
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.flow.pause_writing()
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.flow.resume_writing()
def timeout_keep_alive_handler(self) -> None:
"""
Called on a keep-alive connection if no new data is received after a short
delay.
"""
if not self.transport.is_closing():
self.transport.close()
class RequestResponseCycle:
def __init__(
self,
scope: "HTTPScope",
transport: asyncio.Transport,
flow: FlowControl,
logger: logging.Logger,
access_logger: logging.Logger,
access_log: bool,
default_headers: List[Tuple[bytes, bytes]],
message_event: asyncio.Event,
expect_100_continue: bool,
keep_alive: bool,
on_response: Callable[..., None],
):
self.scope = scope
self.transport = transport
self.flow = flow
self.logger = logger
self.access_logger = access_logger
self.access_log = access_log
self.default_headers = default_headers
self.message_event = message_event
self.on_response = on_response
# Connection state
self.disconnected = False
self.keep_alive = keep_alive
self.waiting_for_100_continue = expect_100_continue
# Request state
self.body = b""
self.more_body = True
# Response state
self.response_started = False
self.response_complete = False
self.chunked_encoding: Optional[bool] = None
self.expected_content_length = 0
# ASGI exception wrapper
async def run_asgi(self, app: "ASGI3Application") -> None:
try:
result = await app( # type: ignore[func-returns-value]
self.scope, self.receive, self.send
)
except BaseException as exc:
msg = "Exception in ASGI application\n"
self.logger.error(msg, exc_info=exc)
if not self.response_started:
await self.send_500_response()
else:
self.transport.close()
else:
if result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
self.transport.close()
elif not self.response_started and not self.disconnected:
msg = "ASGI callable returned without starting response."
self.logger.error(msg)
await self.send_500_response()
elif not self.response_complete and not self.disconnected:
msg = "ASGI callable returned without completing response."
self.logger.error(msg)
self.transport.close()
finally:
self.on_response = lambda: None
async def send_500_response(self) -> None:
response_start_event: "HTTPResponseStartEvent" = {
"type": "http.response.start",
"status": 500,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
],
}
await self.send(response_start_event)
response_body_event: "HTTPResponseBodyEvent" = {
"type": "http.response.body",
"body": b"Internal Server Error",
"more_body": False,
}
await self.send(response_body_event)
# ASGI interface
async def send(self, message: "ASGISendEvent") -> None:
message_type = message["type"]
if self.flow.write_paused and not self.disconnected:
await self.flow.drain()
if self.disconnected:
return
if not self.response_started:
# Sending response status line and headers
if message_type != "http.response.start":
msg = "Expected ASGI message 'http.response.start', but got '%s'."
raise RuntimeError(msg % message_type)
message = cast("HTTPResponseStartEvent", message)
self.response_started = True
self.waiting_for_100_continue = False
status_code = message["status"]
headers = self.default_headers + list(message.get("headers", []))
if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers:
headers = headers + [CLOSE_HEADER]
if self.access_log:
self.access_logger.info(
'%s - "%s %s HTTP/%s" %d',
get_client_addr(self.scope),
self.scope["method"],
get_path_with_query_string(self.scope),
self.scope["http_version"],
status_code,
)
# Write response status line and headers
content = [STATUS_LINE[status_code]]
for name, value in headers:
if HEADER_RE.search(name):
raise RuntimeError("Invalid HTTP header name.")
if HEADER_VALUE_RE.search(value):
raise RuntimeError("Invalid HTTP header value.")
name = name.lower()
if name == b"content-length" and self.chunked_encoding is None:
self.expected_content_length = int(value.decode())
self.chunked_encoding = False
elif name == b"transfer-encoding" and value.lower() == b"chunked":
self.expected_content_length = 0
self.chunked_encoding = True
elif name == b"connection" and value.lower() == b"close":
self.keep_alive = False
content.extend([name, b": ", value, b"\r\n"])
if (
self.chunked_encoding is None
and self.scope["method"] != "HEAD"
and status_code not in (204, 304)
):
# Neither content-length nor transfer-encoding specified
self.chunked_encoding = True
content.append(b"transfer-encoding: chunked\r\n")
content.append(b"\r\n")
self.transport.write(b"".join(content))
elif not self.response_complete:
# Sending response body
if message_type != "http.response.body":
msg = "Expected ASGI message 'http.response.body', but got '%s'."
raise RuntimeError(msg % message_type)
body = cast(bytes, message.get("body", b""))
more_body = message.get("more_body", False)
# Write response body
if self.scope["method"] == "HEAD":
self.expected_content_length = 0
elif self.chunked_encoding:
if body:
content = [b"%x\r\n" % len(body), body, b"\r\n"]
else:
content = []
if not more_body:
content.append(b"0\r\n\r\n")
self.transport.write(b"".join(content))
else:
num_bytes = len(body)
if num_bytes > self.expected_content_length:
raise RuntimeError("Response content longer than Content-Length")
else:
self.expected_content_length -= num_bytes
self.transport.write(body)
# Handle response completion
if not more_body:
if self.expected_content_length != 0:
raise RuntimeError("Response content shorter than Content-Length")
self.response_complete = True
self.message_event.set()
if not self.keep_alive:
self.transport.close()
self.on_response()
else:
# Response already sent
msg = "Unexpected ASGI message '%s' sent, after response already completed."
raise RuntimeError(msg % message_type)
async def receive(self) -> "ASGIReceiveEvent":
if self.waiting_for_100_continue and not self.transport.is_closing():
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
self.waiting_for_100_continue = False
if not self.disconnected and not self.response_complete:
self.flow.resume_reading()
await self.message_event.wait()
self.message_event.clear()
message: "Union[HTTPDisconnectEvent, HTTPRequestEvent]"
if self.disconnected or self.response_complete:
message = {"type": "http.disconnect"}
else:
message = {
"type": "http.request",
"body": self.body,
"more_body": self.more_body,
}
self.body = b""
return message

View File

@ -18,11 +18,14 @@ from uvicorn.config import Config
if TYPE_CHECKING:
from uvicorn.protocols.http.h11_impl import H11Protocol
from uvicorn.protocols.http.httparse_impl import HttparseProtocol
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol]
Protocols = Union[
H11Protocol, HttpToolsProtocol, HttparseProtocol, WSProtocol, WebSocketProtocol
]
HANDLED_SIGNALS = (