# uvicorn/tests/protocols/test_http.py
# (source listing metadata: 1157 lines, 43 KiB, Python)

from __future__ import annotations
import asyncio
import logging
import socket
import threading
import time
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, TypeAlias
import pytest
from tests.response import Response
from uvicorn import Server
from uvicorn._types import ASGIApplication, ASGIReceiveCallable, ASGISendCallable, Scope
from uvicorn.config import WS_PROTOCOLS, Config
from uvicorn.lifespan.off import LifespanOff
from uvicorn.lifespan.on import LifespanOn
from uvicorn.protocols.http.h11_impl import H11Protocol
from uvicorn.server import ServerState
# httptools is optional: import it when available and expose a skip marker
# for the tests that require it.
try:
    from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol

    skip_if_no_httptools = pytest.mark.skipif(False, reason="httptools is installed")
except ModuleNotFoundError:  # pragma: no cover
    skip_if_no_httptools = pytest.mark.skipif(True, reason="httptools is not installed")

if TYPE_CHECKING:
    # Imports and aliases used only in annotations; they never execute at
    # runtime, so the optional websocket/httptools packages are not required.
    from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
    from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
    from uvicorn.protocols.websockets.wsproto_impl import WSProtocol as _WSProtocol

    WSProtocol: TypeAlias = WebSocketProtocol | _WSProtocol
    HTTPProtocol: TypeAlias = H11Protocol | HttpToolsProtocol
# Apply `pytest.mark.anyio` to every test in this module so the async tests
# are run by the AnyIO pytest plugin.
pytestmark = pytest.mark.anyio

# Keys of uvicorn.config.WS_PROTOCOLS (the configured WebSocket implementations).
WEBSOCKET_PROTOCOLS = WS_PROTOCOLS.keys()

# ---------------------------------------------------------------------------
# Raw request bytes fed to protocol.data_received() by the tests below.
# ---------------------------------------------------------------------------

SIMPLE_GET_REQUEST = b"\r\n".join([b"GET / HTTP/1.1", b"Host: example.org", b"", b""])

SIMPLE_HEAD_REQUEST = b"\r\n".join([b"HEAD / HTTP/1.1", b"Host: example.org", b"", b""])

# POST with an 18-byte JSON body that matches its Content-Length.
SIMPLE_POST_REQUEST = b"\r\n".join(
    [
        b"POST / HTTP/1.1",
        b"Host: example.org",
        b"Content-Type: application/json",
        b"Content-Length: 18",
        b"",
        b'{"hello": "world"}',
    ]
)

# GET asking the server to close the connection after responding.
CONNECTION_CLOSE_REQUEST = b"\r\n".join([b"GET / HTTP/1.1", b"Host: example.org", b"Connection: close", b"", b""])

# POST with "Connection: close".
# NOTE(review): the body uses single quotes, so it is not valid JSON despite
# the Content-Type; its length (18) still matches Content-Length.
CONNECTION_CLOSE_POST_REQUEST = b"\r\n".join(
    [
        b"POST / HTTP/1.1",
        b"Host: example.org",
        b"Connection: close",
        b"Content-Type: application/json",
        b"Content-Length: 18",
        b"",
        b"{'hello': 'world'}",
    ]
)

# Two pipelined GETs where the first one carries "Connection: close".
REQUEST_AFTER_CONNECTION_CLOSE = b"\r\n".join(
    [
        b"GET / HTTP/1.1",
        b"Host: example.org",
        b"Connection: close",
        b"",
        b"",
        b"GET / HTTP/1.1",
        b"Host: example.org",
        b"",
        b"",
    ]
)

# POST with a 100000-byte body, sent in a single data_received() call.
LARGE_POST_REQUEST = b"\r\n".join(
    [
        b"POST / HTTP/1.1",
        b"Host: example.org",
        b"Content-Type: text/plain",
        b"Content-Length: 100000",
        b"",
        b"x" * 100000,
    ]
)

# Headers of an 18-byte POST with the body withheld; send FINISH_POST_REQUEST
# afterwards to complete it.
START_POST_REQUEST = b"\r\n".join(
    [
        b"POST / HTTP/1.1",
        b"Host: example.org",
        b"Content-Type: application/json",
        b"Content-Length: 18",
        b"",
        b"",
    ]
)

FINISH_POST_REQUEST = b'{"hello": "world"}'

HTTP10_GET_REQUEST = b"\r\n".join([b"GET / HTTP/1.0", b"Host: example.org", b"", b""])

# Path containing a percent-encoded slash, to check raw_path vs. path.
GET_REQUEST_WITH_RAW_PATH = b"\r\n".join([b"GET /one%2Ftwo HTTP/1.1", b"Host: example.org", b"", b""])

# WebSocket upgrade request with Sec-WebSocket-Version: 11 and no Sec-WebSocket-Key.
UPGRADE_REQUEST = b"\r\n".join(
    [
        b"GET / HTTP/1.1",
        b"Host: example.org",
        b"Connection: upgrade",
        b"Upgrade: websocket",
        b"Sec-WebSocket-Version: 11",
        b"",
        b"",
    ]
)

# Cleartext HTTP/2 ("h2c") upgrade request (it also carries a Sec-WebSocket-Version header).
UPGRADE_HTTP2_REQUEST = b"\r\n".join(
    [
        b"GET / HTTP/1.1",
        b"Host: example.org",
        b"Connection: upgrade",
        b"Upgrade: h2c",
        b"Sec-WebSocket-Version: 11",
        b"",
        b"",
    ]
)

# Bytes %-template: the request line is substituted for "%s".
INVALID_REQUEST_TEMPLATE = b"\r\n".join(
    [
        b"%s",
        b"Host: example.org",
        b"",
        b"",
    ]
)

# A GET whose Cookie header value is 64 KiB of "x", split across two chunks:
# the first chunk ends in the middle of the header, the second finishes it,
# ends the header block, and adds one extra trailing CRLF.
GET_REQUEST_HUGE_HEADERS = [
    b"".join(
        [
            b"GET / HTTP/1.1\r\n",
            b"Host: example.org\r\n",
            b"Cookie: " + b"x" * 32 * 1024,
        ]
    ),
    b"".join([b"x" * 32 * 1024 + b"\r\n", b"\r\n", b"\r\n"]),
]

# Upgrade request naming a protocol other than websocket.
UPGRADE_REQUEST_ERROR_FIELD = b"\r\n".join(
    [
        b"GET / HTTP/1.1",
        b"Host: example.org",
        b"Connection: upgrade",
        b"Upgrade: not-websocket",
        b"Sec-WebSocket-Version: 11",
        b"",
        b"",
    ]
)
class MockTransport:
    """In-memory stand-in for an asyncio transport.

    Everything written is accumulated in ``buffer``. Writing to or closing an
    already-closed transport trips an ``AssertionError`` so that tests notice
    protocol misuse.
    """

    def __init__(
        self, sockname: tuple[str, int] | None = None, peername: tuple[str, int] | None = None, sslcontext: bool = False
    ):
        self.sockname = sockname if sockname is not None else ("127.0.0.1", 8000)
        self.peername = peername if peername is not None else ("127.0.0.1", 8001)
        self.sslcontext = sslcontext
        self.closed = False
        self.buffer = b""
        self.read_paused = False

    def get_extra_info(self, key: Any):
        """Return one of ``sockname``/``peername``/``sslcontext``, else ``None``."""
        extra_info = {
            "sockname": self.sockname,
            "peername": self.peername,
            "sslcontext": self.sslcontext,
        }
        return extra_info.get(key)

    def write(self, data: bytes):
        assert not self.closed
        self.buffer = self.buffer + data

    def close(self):
        assert not self.closed
        self.closed = True

    def pause_reading(self):
        self.read_paused = True

    def resume_reading(self):
        self.read_paused = False

    def is_closing(self):
        return self.closed

    def clear_buffer(self):
        """Forget everything written so far (used between pipelined responses)."""
        self.buffer = b""

    def set_protocol(self, protocol: asyncio.Protocol):
        # Protocol switching is not modelled by this mock.
        return None
class MockTimerHandle:
    """Fake ``asyncio.TimerHandle`` returned by ``MockLoop.call_later``.

    The handle keeps a reference to its loop's pending-timer list so that
    ``cancel`` can unschedule it.
    """

    def __init__(
        self, loop_later_list: list[MockTimerHandle], delay: float, callback: Callable[[], None], args: tuple[Any, ...]
    ):
        self.loop_later_list = loop_later_list
        self.delay = delay
        self.callback = callback
        self.args = args
        self.cancelled = False

    def cancel(self):
        """Unschedule the timer.

        Cancelling twice, or after the timer already fired, is a no-op, which
        mirrors ``asyncio.TimerHandle.cancel``.
        """
        if not self.cancelled:
            self.cancelled = True
            # A handle that already fired has been dropped from the pending list.
            if self in self.loop_later_list:
                self.loop_later_list.remove(self)


class MockLoop:
    """Minimal event-loop stand-in that the tests drive by hand.

    ``create_task`` only queues the coroutine, and ``run_one`` awaits the oldest
    queued one. ``call_later`` records timers, and ``run_later`` fires those
    whose delay has elapsed.
    """

    def __init__(self):
        # Coroutines queued by create_task, newest first (run_one pops the oldest).
        self._tasks: list[Any] = []
        # Pending timers, newest first. Handles keep a reference to this exact
        # list object, so it must only ever be mutated in place.
        self._later: list[MockTimerHandle] = []

    def create_task(self, coroutine: Any) -> Any:
        self._tasks.insert(0, coroutine)
        return MockTask()

    def call_later(self, delay: float, callback: Callable[[], None], *args: Any) -> MockTimerHandle:
        handle = MockTimerHandle(self._later, delay, callback, args)
        self._later.insert(0, handle)
        return handle

    async def run_one(self):
        """Await the oldest queued coroutine and return its result."""
        return await self._tasks.pop()

    def run_later(self, with_delay: float) -> None:
        """Fire every pending timer whose delay is <= ``with_delay``.

        Due timers are removed from the pending list before any callback runs.
        A callback may therefore safely schedule new timers (they wait for the
        next call) or cancel other ones.
        """
        due: list[MockTimerHandle] = []
        pending: list[MockTimerHandle] = []
        for timer_handle in self._later:
            if with_delay >= timer_handle.delay:
                due.append(timer_handle)
            else:
                pending.append(timer_handle)
        # Update in place rather than rebinding: still-pending handles point at
        # this list, and their cancel() must actually unschedule them.
        self._later[:] = pending
        for timer_handle in due:
            # An earlier callback in this batch may have cancelled this timer.
            if not timer_handle.cancelled:
                timer_handle.callback(*timer_handle.args)
class MockTask:
    """Placeholder for the task object returned by ``MockLoop.create_task``."""

    def add_done_callback(self, callback: Callable[[], None]):
        # The mock has no completion state, so callbacks are ignored.
        return None
class MockProtocol(asyncio.Protocol):
    """Typing-only view of a connected protocol, as used by these tests.

    ``get_connected_protocol`` is annotated to return this type, which lets
    tests reach the mocked loop/transport and protocol state under type checking.
    """

    # Collaborators installed by get_connected_protocol.
    loop: MockLoop
    transport: MockTransport
    # Protocol attributes that individual tests inspect or override.
    timeout_keep_alive_task: asyncio.TimerHandle | None
    ws_protocol_class: type[WSProtocol] | None
    scope: Scope
def get_connected_protocol(
    app: ASGIApplication,
    http_protocol_cls: type[HTTPProtocol],
    lifespan: LifespanOff | LifespanOn | None = None,
    **kwargs: Any,
) -> MockProtocol:
    """Instantiate ``http_protocol_cls`` on a MockLoop and connect it to a MockTransport.

    Extra keyword arguments are forwarded to ``Config``. When ``lifespan`` is
    falsy, a ``LifespanOff`` built from that config supplies the app state.
    """
    mock_loop = MockLoop()
    mock_transport = MockTransport()
    config = Config(app=app, **kwargs)
    if not lifespan:
        lifespan = LifespanOff(config)
    protocol = http_protocol_cls(  # type: ignore
        config=config,
        server_state=ServerState(),
        app_state=lifespan.state,
        _loop=mock_loop,
    )
    protocol.connection_made(mock_transport)  # type: ignore[arg-type]
    return protocol  # type: ignore[return-value]
async def test_get_request(http_protocol_cls: type[HTTPProtocol]):
    """A plain GET is answered with 200 and the app's body."""
    protocol = get_connected_protocol(Response("Hello, world", media_type="text/plain"), http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 200 OK" in output
    assert b"Hello, world" in output
@pytest.mark.parametrize(
    "char",
    [
        pytest.param("c", id="allow_ascii_letter"),
        pytest.param("\t", id="allow_tab"),
        pytest.param(" ", id="allow_space"),
        pytest.param("µ", id="allow_non_ascii_char"),
    ],
)
async def test_header_value_allowed_characters(http_protocol_cls: type[HTTPProtocol], char: str):
    """Letters, tabs, spaces and non-ASCII characters pass through in header values."""
    header_value = f"<{char}>"
    app = Response("Hello, world", media_type="text/plain", headers={"key": header_value})
    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    expected_header_line = b"\r\nkey: " + header_value.encode() + b"\r\n"
    assert b"HTTP/1.1 200 OK" in output
    assert expected_header_line in output
    assert b"Hello, world" in output
@pytest.mark.parametrize(
    "name",
    [
        pytest.param("bad header", id="reject_space"),
        pytest.param("bad\x00header", id="reject_null"),
        pytest.param("bad(header", id="reject_open_paren"),
        pytest.param("bad)header", id="reject_close_paren"),
        pytest.param("bad<header", id="reject_less_than"),
        pytest.param("bad>header", id="reject_greater_than"),
        pytest.param("bad@header", id="reject_at"),
        pytest.param("bad,header", id="reject_comma"),
        pytest.param("bad;header", id="reject_semicolon"),
        pytest.param("bad:header", id="reject_colon"),
        pytest.param("bad[header", id="reject_open_bracket"),
        pytest.param("bad]header", id="reject_close_bracket"),
        pytest.param("bad{header", id="reject_open_brace"),
        pytest.param("bad}header", id="reject_close_brace"),
        pytest.param("bad=header", id="reject_equals"),
        pytest.param('bad"header', id="reject_double_quote"),
        pytest.param("bad\\header", id="reject_backslash"),
        pytest.param("bad\theader", id="reject_tab"),
        pytest.param("bad\x7fheader", id="reject_del"),
    ],
)
async def test_invalid_header_name(http_protocol_cls: type[HTTPProtocol], name: str):
    """A response header name with a forbidden character is never written out."""
    offending_app = Response("Hello, world", media_type="text/plain", headers={name: "value"})
    protocol = get_connected_protocol(offending_app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    # `response_started` is already set when header validation fails, so the
    # error path does not send a 500 and only closes the connection.
    assert b"HTTP/1.1 500 Internal Server Error" not in output
    assert name.encode() not in output
    assert protocol.transport.is_closing()
@pytest.mark.parametrize("path", ["/", "/?foo", "/?foo=bar", "/?foo=bar&baz=1"])
async def test_request_logging(path: str, http_protocol_cls: type[HTTPProtocol], caplog: pytest.LogCaptureFixture):
    """The access log line includes the full request target, query string included."""
    request_bytes = b"\r\n".join([f"GET {path} HTTP/1.1".encode("ascii"), b"Host: example.org", b"", b""])
    access_logger_name = "uvicorn.access"
    caplog.set_level(logging.INFO, logger=access_logger_name)
    logging.getLogger(access_logger_name).propagate = True
    protocol = get_connected_protocol(
        Response("Hello, world", media_type="text/plain"), http_protocol_cls, log_config=None
    )
    protocol.data_received(request_bytes)
    await protocol.loop.run_one()
    first_record = caplog.records[0]
    assert f'"GET {path} HTTP/1.1" 200' in first_record.message
async def test_head_request(http_protocol_cls: type[HTTPProtocol]):
    """HEAD receives the status line but none of the body bytes."""
    protocol = get_connected_protocol(Response("Hello, world", media_type="text/plain"), http_protocol_cls)
    protocol.data_received(SIMPLE_HEAD_REQUEST)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 200 OK" in output
    assert b"Hello, world" not in output
async def test_post_request(http_protocol_cls: type[HTTPProtocol]):
    """The app can read the full POST body and echo it back."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        chunks: list[bytes] = []
        while True:
            message = await receive()
            assert message["type"] == "http.request"
            chunks.append(message.get("body", b""))
            if not message.get("more_body", False):
                break
        echo = Response(b"Body: " + b"".join(chunks), media_type="text/plain")
        await echo(scope, receive, send)

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(SIMPLE_POST_REQUEST)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 200 OK" in output
    assert b'Body: {"hello": "world"}' in output
async def test_keepalive(http_protocol_cls: type[HTTPProtocol]):
    """After a complete response the connection is left open for reuse."""
    no_content = Response(b"", status_code=204)
    protocol = get_connected_protocol(no_content, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer
    assert not protocol.transport.is_closing()
async def test_keepalive_timeout(http_protocol_cls: type[HTTPProtocol]):
    """An idle keep-alive connection is closed once its timer elapses."""
    no_content = Response(b"", status_code=204)
    protocol = get_connected_protocol(no_content, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    transport = protocol.transport
    assert b"HTTP/1.1 204 No Content" in transport.buffer
    assert not transport.is_closing()

    # Presumably the default keep-alive timeout is 5s: 1s is too early,
    # 5s is enough to fire the timer.
    protocol.loop.run_later(with_delay=1)
    assert not transport.is_closing()
    protocol.loop.run_later(with_delay=5)
    assert transport.is_closing()
async def test_keepalive_timeout_with_pipelined_requests(http_protocol_cls: type[HTTPProtocol]):
    """The keep-alive timer is only armed once no pipelined request is pending."""
    protocol = get_connected_protocol(Response("Hello, world", media_type="text/plain"), http_protocol_cls)
    for _ in range(2):
        protocol.data_received(SIMPLE_GET_REQUEST)

    # First response: the second request is still waiting, so no timer yet.
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 200 OK" in output
    assert b"Hello, world" in output
    assert protocol.timeout_keep_alive_task is None

    # Second response: the connection is idle again, so the timer is re-armed.
    protocol.transport.clear_buffer()
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 200 OK" in output
    assert b"Hello, world" in output
    assert protocol.timeout_keep_alive_task is not None
async def test_close(http_protocol_cls: type[HTTPProtocol]):
    """A response carrying `connection: close` closes the transport."""
    closing_app = Response(b"", status_code=204, headers={"connection": "close"})
    protocol = get_connected_protocol(closing_app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer
    assert protocol.transport.is_closing()
async def test_chunked_encoding(http_protocol_cls: type[HTTPProtocol]):
    """A chunked response ends with the zero-length terminating chunk."""
    chunked_app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"})
    protocol = get_connected_protocol(chunked_app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 200 OK" in output
    assert b"0\r\n\r\n" in output
    assert not protocol.transport.is_closing()
async def test_chunked_encoding_empty_body(http_protocol_cls: type[HTTPProtocol]):
    """An empty chunked body produces only the terminating chunk.

    A zero-length data chunk is indistinguishable from the end-of-message
    marker, so the server must skip it and emit the terminator exactly once.
    The app sends raw ASGI messages so that ``transfer-encoding`` is the only
    framing header involved.
    """

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        await send({"type": "http.response.start", "status": 200, "headers": [(b"transfer-encoding", b"chunked")]})
        await send({"type": "http.response.body", "body": b""})

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 200 OK" in output
    # Everything after the header block must be just the terminating chunk.
    _, _, body = output.partition(b"\r\n\r\n")
    assert body == b"0\r\n\r\n"
    assert output.count(b"0\r\n\r\n") == 1
    assert not protocol.transport.is_closing()
async def test_chunked_encoding_head_request(http_protocol_cls: type[HTTPProtocol]):
    """HEAD against a chunked response succeeds and keeps the connection open."""
    chunked_app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"})
    protocol = get_connected_protocol(chunked_app, http_protocol_cls)
    protocol.data_received(SIMPLE_HEAD_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
    assert not protocol.transport.is_closing()
async def test_pipelined_requests(http_protocol_cls: type[HTTPProtocol]):
    """Three pipelined GETs are answered one after another."""
    protocol = get_connected_protocol(Response("Hello, world", media_type="text/plain"), http_protocol_cls)
    pipelined = 3
    for _ in range(pipelined):
        protocol.data_received(SIMPLE_GET_REQUEST)
    for _ in range(pipelined):
        await protocol.loop.run_one()
        output = protocol.transport.buffer
        assert b"HTTP/1.1 200 OK" in output
        assert b"Hello, world" in output
        protocol.transport.clear_buffer()
async def test_undersized_request(http_protocol_cls: type[HTTPProtocol]):
    """A body shorter than the declared content-length closes the connection."""
    short_body = Response(b"xxx", headers={"content-length": "10"})
    protocol = get_connected_protocol(short_body, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    assert protocol.transport.is_closing()
async def test_oversized_request(http_protocol_cls: type[HTTPProtocol]):
    """A body longer than the declared content-length closes the connection."""
    long_body = Response(b"xxx" * 20, headers={"content-length": "10"})
    protocol = get_connected_protocol(long_body, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    assert protocol.transport.is_closing()
async def test_large_post_request(http_protocol_cls: type[HTTPProtocol]):
    """A 100 kB body pauses reading until the app has run."""
    protocol = get_connected_protocol(Response("Hello, world", media_type="text/plain"), http_protocol_cls)
    transport = protocol.transport
    protocol.data_received(LARGE_POST_REQUEST)
    assert transport.read_paused
    await protocol.loop.run_one()
    assert not transport.read_paused
async def test_invalid_http(http_protocol_cls: type[HTTPProtocol]):
    """Unparseable bytes close the connection straight away."""
    protocol = get_connected_protocol(Response("Hello, world", media_type="text/plain"), http_protocol_cls)
    garbage = b"x" * 100000
    protocol.data_received(garbage)
    assert protocol.transport.is_closing()
async def test_app_exception(http_protocol_cls: type[HTTPProtocol]):
    """An exception before the response starts yields a 500 and a close."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        raise Exception()

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer
    assert protocol.transport.is_closing()
async def test_exception_during_response(http_protocol_cls: type[HTTPProtocol]):
    """An exception after the response started closes without sending a 500."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        start_message = {"type": "http.response.start", "status": 200}
        partial_body = {"type": "http.response.body", "body": b"1", "more_body": True}
        await send(start_message)
        await send(partial_body)
        raise Exception()

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer
    assert protocol.transport.is_closing()
async def test_no_response_returned(http_protocol_cls: type[HTTPProtocol]):
    """An app that returns without sending anything yields a 500 and a close."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        return None

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer
    assert protocol.transport.is_closing()
async def test_partial_response_returned(http_protocol_cls: type[HTTPProtocol]):
    """Returning after only the start message closes without a 500."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        start_message = {"type": "http.response.start", "status": 200}
        await send(start_message)

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer
    assert protocol.transport.is_closing()
async def test_response_header_splitting(http_protocol_cls: type[HTTPProtocol]):
    """A CRLF inside a header value must not smuggle in an extra header."""
    smuggling_app = Response(b"", headers={"key": "value\r\nCookie: smuggled=value"})
    protocol = get_connected_protocol(smuggling_app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 500 Internal Server Error" not in output
    assert b"\r\nCookie: smuggled=value\r\n" not in output
    assert protocol.transport.is_closing()
async def test_duplicate_start_message(http_protocol_cls: type[HTTPProtocol]):
    """A second start message closes the connection without a 500."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        for _ in range(2):
            await send({"type": "http.response.start", "status": 200})

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer
    assert protocol.transport.is_closing()
async def test_missing_start_message(http_protocol_cls: type[HTTPProtocol]):
    """Sending a body before the start message yields a 500 and a close."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        body_without_start = {"type": "http.response.body", "body": b""}
        await send(body_without_start)

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer
    assert protocol.transport.is_closing()
async def test_message_after_body_complete(http_protocol_cls: type[HTTPProtocol]):
    """A body message after the final one keeps the 200 but closes the connection."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        await send({"type": "http.response.start", "status": 200})
        for _ in range(2):
            await send({"type": "http.response.body", "body": b""})

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
    assert protocol.transport.is_closing()
async def test_value_returned(http_protocol_cls: type[HTTPProtocol]):
    """An app returning a non-None value after responding gets its connection closed."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        await send({"type": "http.response.start", "status": 200})
        await send({"type": "http.response.body", "body": b""})
        unexpected_return_value = 123
        return unexpected_return_value

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
    assert protocol.transport.is_closing()
async def test_early_disconnect(http_protocol_cls: type[HTTPProtocol]):
    """A client that goes away before the app runs surfaces as http.disconnect."""
    disconnect_seen = False

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        nonlocal disconnect_seen
        while (await receive())["type"] != "http.disconnect":
            pass
        disconnect_seen = True

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(SIMPLE_POST_REQUEST)
    protocol.eof_received()
    protocol.connection_lost(None)
    await protocol.loop.run_one()
    assert disconnect_seen
async def test_early_response(http_protocol_cls: type[HTTPProtocol]):
    """Responding before the request body arrives leaves the connection usable."""
    protocol = get_connected_protocol(Response("Hello, world", media_type="text/plain"), http_protocol_cls)
    protocol.data_received(START_POST_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
    # The rest of the body arrives after the response was already sent.
    protocol.data_received(FINISH_POST_REQUEST)
    assert not protocol.transport.is_closing()
async def test_read_after_response(http_protocol_cls: type[HTTPProtocol]):
    """receive() after the response is complete reports http.disconnect."""
    late_message = None

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        nonlocal late_message
        await Response("Hello, world", media_type="text/plain")(scope, receive, send)
        late_message = await receive()

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(SIMPLE_POST_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
    assert late_message == {"type": "http.disconnect"}
async def test_http10_request(http_protocol_cls: type[HTTPProtocol]):
    """An HTTP/1.0 request reports http_version "1.0" in the scope."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        assert scope["type"] == "http"
        version_response = Response(f"Version: {scope['http_version']}", media_type="text/plain")
        await version_response(scope, receive, send)

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(HTTP10_GET_REQUEST)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 200 OK" in output
    assert b"Version: 1.0" in output
async def test_root_path(http_protocol_cls: type[HTTPProtocol]):
    """The configured root_path is exposed and prefixed onto scope["path"]."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        assert scope["type"] == "http"
        text = "root_path={} path={}".format(scope.get("root_path", ""), scope["path"])
        await Response(text, media_type="text/plain")(scope, receive, send)

    protocol = get_connected_protocol(app, http_protocol_cls, root_path="/app")
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 200 OK" in output
    assert b"root_path=/app path=/app/" in output
async def test_raw_path(http_protocol_cls: type[HTTPProtocol]):
    """path is percent-decoded while raw_path keeps the original encoding."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        assert scope["type"] == "http"
        assert scope["path"] == "/app/one/two"
        assert scope.get("raw_path") == b"/app/one%2Ftwo"
        await Response("Done", media_type="text/plain")(scope, receive, send)

    protocol = get_connected_protocol(app, http_protocol_cls, root_path="/app")
    protocol.data_received(GET_REQUEST_WITH_RAW_PATH)
    await protocol.loop.run_one()
    assert b"Done" in protocol.transport.buffer
async def test_max_concurrency(http_protocol_cls: type[HTTPProtocol]):
    """With limit_concurrency=1 the whole output is the canned 503 response."""
    protocol = get_connected_protocol(
        Response("Hello, world", media_type="text/plain"), http_protocol_cls, limit_concurrency=1
    )
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    expected_response = (
        b"HTTP/1.1 503 Service Unavailable\r\n"
        b"content-type: text/plain; charset=utf-8\r\n"
        b"content-length: 19\r\n"
        b"connection: close\r\n"
        b"\r\n"
        b"Service Unavailable"
    )
    assert protocol.transport.buffer == expected_response
async def test_shutdown_during_request(http_protocol_cls: type[HTTPProtocol]):
    """Shutdown lets the in-flight request finish, then closes the connection."""
    protocol = get_connected_protocol(Response(b"", status_code=204), http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    protocol.shutdown()  # type: ignore[attr-defined]
    await protocol.loop.run_one()
    transport = protocol.transport
    assert b"HTTP/1.1 204 No Content" in transport.buffer
    assert transport.is_closing()
async def test_shutdown_during_idle(http_protocol_cls: type[HTTPProtocol]):
    """Shutdown of an idle connection closes it without writing anything."""
    protocol = get_connected_protocol(Response("Hello, world", media_type="text/plain"), http_protocol_cls)
    protocol.shutdown()  # type: ignore[attr-defined]
    transport = protocol.transport
    assert transport.buffer == b""
    assert transport.is_closing()
async def test_100_continue_sent_when_body_consumed(http_protocol_cls: type[HTTPProtocol]):
    """"Expect: 100-continue" gets an interim 100 once the app reads the body."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        chunks: list[bytes] = []
        while True:
            message = await receive()
            assert message["type"] == "http.request"
            chunks.append(message.get("body", b""))
            if not message.get("more_body", False):
                break
        await Response(b"Body: " + b"".join(chunks), media_type="text/plain")(scope, receive, send)

    protocol = get_connected_protocol(app, http_protocol_cls)
    expect_continue_request = b"\r\n".join(
        [
            b"POST / HTTP/1.1",
            b"Host: example.org",
            b"Expect: 100-continue",
            b"Content-Type: application/json",
            b"Content-Length: 18",
            b"",
            b'{"hello": "world"}',
        ]
    )
    protocol.data_received(expect_continue_request)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 100 Continue" in output
    assert b"HTTP/1.1 200 OK" in output
    assert b'Body: {"hello": "world"}' in output
async def test_100_continue_not_sent_when_body_not_consumed(
    http_protocol_cls: type[HTTPProtocol],
):
    """No interim 100 is sent when the app answers without reading the body."""
    protocol = get_connected_protocol(Response(b"", status_code=204), http_protocol_cls)
    expect_continue_request = b"\r\n".join(
        [
            b"POST / HTTP/1.1",
            b"Host: example.org",
            b"Expect: 100-continue",
            b"Content-Type: application/json",
            b"Content-Length: 18",
            b"",
            b'{"hello": "world"}',
        ]
    )
    protocol.data_received(expect_continue_request)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 100 Continue" not in output
    assert b"HTTP/1.1 204 No Content" in output
async def test_supported_upgrade_request(http_protocol_cls: type[HTTPProtocol]):
    """With wsproto enabled, the UPGRADE_REQUEST fixture is answered with a 426."""
    pytest.importorskip("wsproto")
    protocol = get_connected_protocol(
        Response("Hello, world", media_type="text/plain"), http_protocol_cls, ws="wsproto"
    )
    protocol.data_received(UPGRADE_REQUEST)
    assert b"HTTP/1.1 426 " in protocol.transport.buffer
async def test_unsupported_ws_upgrade_request(http_protocol_cls: type[HTTPProtocol]):
    """With ws="none" an upgrade request is served as ordinary HTTP."""
    protocol = get_connected_protocol(Response("Hello, world", media_type="text/plain"), http_protocol_cls, ws="none")
    protocol.data_received(UPGRADE_REQUEST)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 200 OK" in output
    assert b"Hello, world" in output
async def test_unsupported_ws_upgrade_request_warn_on_auto(
    caplog: pytest.LogCaptureFixture, http_protocol_cls: type[HTTPProtocol]
):
    """With ws="auto" but no WebSocket class, the upgrade falls back to HTTP and warns."""
    protocol = get_connected_protocol(Response("Hello, world", media_type="text/plain"), http_protocol_cls, ws="auto")
    # Simulate that no WebSocket library could be detected.
    protocol.ws_protocol_class = None
    protocol.data_received(UPGRADE_REQUEST)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 200 OK" in output
    assert b"Hello, world" in output

    warning_messages = [record.msg for record in caplog.records if record.levelname == "WARNING"]
    assert "Unsupported upgrade request." in warning_messages
    expected_warning = (
        "No supported WebSocket library detected. "
        "Please use \"pip install 'uvicorn[standard]'\", "
        "or install 'websockets' or 'wsproto' manually."
    )
    assert expected_warning in warning_messages
async def test_http2_upgrade_request(http_protocol_cls: type[HTTPProtocol], ws_protocol_cls: type[WSProtocol]):
    """An h2c upgrade request is ignored and served as plain HTTP/1.1."""
    protocol = get_connected_protocol(
        Response("Hello, world", media_type="text/plain"), http_protocol_cls, ws=ws_protocol_cls
    )
    protocol.data_received(UPGRADE_HTTP2_REQUEST)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 200 OK" in output
    assert b"Hello, world" in output
async def asgi3app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
    """No-op application in the ASGI 3 single-callable style."""
    return None
def asgi2app(scope: Scope):
    """No-op application in the ASGI 2 double-callable style.

    Returns a coroutine function that takes ``(receive, send)`` and does nothing.
    """

    async def instance(receive: ASGIReceiveCallable, send: ASGISendCallable):
        return None

    return instance
@pytest.mark.parametrize(
    "asgi2or3_app, expected_scopes",
    [
        (asgi3app, {"version": "3.0", "spec_version": "2.3"}),
        (asgi2app, {"version": "2.0", "spec_version": "2.3"}),
    ],
)
async def test_scopes(
    asgi2or3_app: ASGIApplication,
    expected_scopes: dict[str, str],
    http_protocol_cls: type[HTTPProtocol],
):
    """scope["asgi"] reports the ASGI version matching the app style."""
    protocol = get_connected_protocol(asgi2or3_app, http_protocol_cls)
    protocol.data_received(SIMPLE_GET_REQUEST)
    await protocol.loop.run_one()
    asgi_info = protocol.scope.get("asgi")
    assert expected_scopes == asgi_info
@pytest.mark.parametrize(
    "request_line",
    [
        pytest.param(b"G?T / HTTP/1.1", id="invalid-method"),
        pytest.param(b"GET /?x=y z HTTP/1.1", id="invalid-path"),
        pytest.param(b"GET / HTTP1.1", id="invalid-http-version"),
    ],
)
async def test_invalid_http_request(
    request_line: bytes, http_protocol_cls: type[HTTPProtocol], caplog: pytest.LogCaptureFixture
):
    """A malformed request line gets a 400 with an explanatory body.

    ``request_line`` is bytes (every parametrized value is a bytes literal) and
    is %-substituted into the bytes template INVALID_REQUEST_TEMPLATE.
    """
    app = Response("Hello, world", media_type="text/plain")
    request = INVALID_REQUEST_TEMPLATE % request_line

    caplog.set_level(logging.INFO, logger="uvicorn.error")
    logging.getLogger("uvicorn.error").propagate = True

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(request)
    assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer
    assert b"Invalid HTTP request received." in protocol.transport.buffer
@skip_if_no_httptools
def test_fragmentation(unused_tcp_port: int):
    """A request split across two TCP writes must still be parsed by httptools.

    Runs a real server in a background thread. The server is always stopped,
    and the client socket always closed, even when an assertion or I/O error
    occurs.
    """

    def receive_all(sock: socket.socket) -> bytes:
        """Read from ``sock`` until the peer closes the connection."""
        chunks: list[bytes] = []
        while chunk := sock.recv(1024):
            chunks.append(chunk)
        return b"".join(chunks)

    app = Response("Hello, world", media_type="text/plain")

    def send_fragmented_req(path: str) -> bytes:
        """Send a GET for ``path`` in two writes and return the raw response."""
        request = f"GET {path} HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n".encode()
        # Split offset into the raw request; for the path used below it lands
        # inside the request target.
        split = len(path) // 2
        # `with` guarantees the socket is closed even if connect/send/recv raises.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.connect(("127.0.0.1", unused_tcp_port))
            sock.sendall(request[:split])
            time.sleep(0.01)
            sock.sendall(request[split:])
            resp = receive_all(sock)
            # see https://github.com/kmonsoor/py-amqplib/issues/45
            # we skip the error on bsd systems if python is too slow
            try:
                sock.shutdown(socket.SHUT_RDWR)
            except Exception:  # pragma: no cover
                pass
        return resp

    config = Config(app=app, http="httptools", port=unused_tcp_port)
    server = Server(config=config)
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()
    try:
        # Poll for startup instead of sleeping a fixed amount of time.
        deadline = time.monotonic() + 10
        while not server.started:
            assert time.monotonic() < deadline, "uvicorn did not start in time"
            time.sleep(0.01)
        path = "/?param=" + "q" * 10
        response = send_fragmented_req(path)
    finally:
        server.should_exit = True
        thread.join()

    bad_response = b"HTTP/1.1 400 Bad Request"
    assert not response.startswith(bad_response)
    # An empty response (e.g. the connection was dropped) must not pass silently.
    assert response.startswith(b"HTTP/1.1 200 OK")
async def test_huge_headers_h11protocol_failure():
    """With its default settings h11 rejects the oversized header block with a 400."""
    protocol = get_connected_protocol(Response("Hello, world", media_type="text/plain"), H11Protocol)
    # The first fragment alone exceeds h11's default incomplete-event limit.
    protocol.data_received(GET_REQUEST_HUGE_HEADERS[0])
    output = protocol.transport.buffer
    assert b"HTTP/1.1 400 Bad Request" in output
    assert b"Connection: close" in output
    assert b"Invalid HTTP request received." in output
@skip_if_no_httptools
async def test_huge_headers_httptools_will_pass():
    """httptools accepts the 64 KiB header block that h11 rejects by default."""
    protocol = get_connected_protocol(Response("Hello, world", media_type="text/plain"), HttpToolsProtocol)
    for fragment in GET_REQUEST_HUGE_HEADERS:
        protocol.data_received(fragment)
    await protocol.loop.run_one()
    output = protocol.transport.buffer
    assert b"HTTP/1.1 200 OK" in output
    assert b"Hello, world" in output
async def test_huge_headers_h11protocol_failure_with_setting():
    """h11 still rejects the huge headers when ``h11_max_incomplete_event_size`` is 20 KiB.

    Raising the limit is not enough when it stays below what the first chunk of
    ``GET_REQUEST_HUGE_HEADERS`` needs (64 KiB is enough, see
    ``test_huge_headers_h11_max_incomplete``). h11 then replies with a 400.
    """
    app = Response("Hello, world", media_type="text/plain")
    protocol = get_connected_protocol(app, H11Protocol, h11_max_incomplete_event_size=20 * 1024)
    # This is NOT the default config: the limit is explicitly raised to 20 KiB, which is
    # still too small, so h11 sends back a 400 with ``Connection: close``.
    protocol.data_received(GET_REQUEST_HUGE_HEADERS[0])
    assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer
    assert b"Connection: close" in protocol.transport.buffer
    assert b"Invalid HTTP request received." in protocol.transport.buffer
@skip_if_no_httptools
async def test_huge_headers_httptools():
    """httptools serves a request whose headers are too big for h11's default config.

    NOTE(review): duplicates ``test_huge_headers_httptools_will_pass``.
    """
    proto = get_connected_protocol(Response("Hello, world", media_type="text/plain"), HttpToolsProtocol)

    # Deliver the two fragments in order; httptools has no problem with the size.
    for index in (0, 1):
        proto.data_received(GET_REQUEST_HUGE_HEADERS[index])
    await proto.loop.run_one()

    output = proto.transport.buffer
    for fragment in (b"HTTP/1.1 200 OK", b"Hello, world"):
        assert fragment in output
async def test_huge_headers_h11_max_incomplete():
    """With ``h11_max_incomplete_event_size`` raised to 64 KiB, h11 accepts the huge headers."""
    raised_limit = 64 * 1024
    hello_app = Response("Hello, world", media_type="text/plain")
    proto = get_connected_protocol(hello_app, H11Protocol, h11_max_incomplete_event_size=raised_limit)

    head, tail = GET_REQUEST_HUGE_HEADERS[0], GET_REQUEST_HUGE_HEADERS[1]
    proto.data_received(head)
    proto.data_received(tail)
    await proto.loop.run_one()

    written = proto.transport.buffer
    assert b"HTTP/1.1 200 OK" in written
    assert b"Hello, world" in written
async def test_return_close_header(http_protocol_cls: type[HTTPProtocol]):
    """A ``Connection: close`` request gets a normal 200 that also carries ``connection: close``."""
    plain_app = Response("Hello, world", media_type="text/plain")
    proto = get_connected_protocol(plain_app, http_protocol_cls)

    proto.data_received(CONNECTION_CLOSE_REQUEST)
    await proto.loop.run_one()

    output = proto.transport.buffer
    for fragment in (b"HTTP/1.1 200 OK", b"content-type: text/plain", b"content-length: 12"):
        assert fragment in output
    # NOTE: The header is compared case-insensitively because the H11 implementation doesn't
    # let Uvicorn lowercase it. See: https://github.com/python-hyper/h11/issues/156
    assert b"connection: close" in output.lower()
async def test_close_connection_with_multiple_requests(http_protocol_cls: type[HTTPProtocol]):
    """The response to ``REQUEST_AFTER_CONNECTION_CLOSE`` still advertises ``connection: close``.

    Presumably that payload is a ``Connection: close`` request followed by another
    request in the same data. TODO confirm against its definition.
    """
    text_app = Response("Hello, world", media_type="text/plain")
    conn = get_connected_protocol(text_app, http_protocol_cls)

    conn.data_received(REQUEST_AFTER_CONNECTION_CLOSE)
    await conn.loop.run_one()

    sent = conn.transport.buffer
    assert b"HTTP/1.1 200 OK" in sent
    assert b"content-type: text/plain" in sent
    assert b"content-length: 12" in sent
    # NOTE: `.lower()` is required because the H11 implementation doesn't allow Uvicorn
    # to lowercase this header. See: https://github.com/python-hyper/h11/issues/156
    lowered = sent.lower()
    assert b"connection: close" in lowered
async def test_close_connection_with_post_request(http_protocol_cls: type[HTTPProtocol]):
    """A POST carrying ``Connection: close`` still has its whole body delivered to the app.

    The app drains every ``http.request`` message and echoes the collected body back
    verbatim, so the response must contain exactly the bytes that were sent.
    """

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        body = b""
        more_body = True
        while more_body:
            message = await receive()
            assert message["type"] == "http.request"
            body += message.get("body", b"")
            more_body = message.get("more_body", False)
        response = Response(b"Body: " + body, media_type="text/plain")
        await response(scope, receive, send)

    protocol = get_connected_protocol(app, http_protocol_cls)
    protocol.data_received(CONNECTION_CLOSE_POST_REQUEST)
    await protocol.loop.run_one()
    assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
    # The request body is JSON, which uses double quotes. The echo is byte-for-byte, so
    # the expected text must use double quotes too (single quotes can never match).
    assert b'Body: {"hello": "world"}' in protocol.transport.buffer
async def test_iterator_headers(http_protocol_cls: type[HTTPProtocol]):
    """Response headers passed as a one-shot iterator (not a list) are still written out."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        header_iter = iter(((b"x-test-header", b"test value"),))
        await send({"type": "http.response.start", "status": 200, "headers": header_iter})
        await send({"type": "http.response.body", "body": b""})

    conn = get_connected_protocol(app, http_protocol_cls)
    conn.data_received(SIMPLE_GET_REQUEST)
    await conn.loop.run_one()

    written = conn.transport.buffer
    assert b"x-test-header: test value" in written
async def test_lifespan_state(http_protocol_cls: type[HTTPProtocol]):
    """Each request sees the lifespan state; only in-place value mutations carry over.

    Rebinding a key inside one request is not visible to the next request, but
    appending to a list stored in the state is.
    """
    pending_states = [{"a": 123, "b": [1]}, {"a": 123, "b": [1, 2]}]

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        assert "state" in scope
        wanted = pending_states.pop(0)
        state = scope["state"]
        assert state == wanted
        # Rebinding a key is not preserved for the next request...
        state["a"] = 456
        # ...whereas mutating the value object itself is.
        state["b"].append(2)
        return await Response("Hi!")(scope, receive, send)

    lifespan = LifespanOn(config=Config(app=app))
    # Seed the state directly instead of running the lifespan; the lifespan tests cover that.
    lifespan.state.update({"a": 123, "b": [1]})

    conn = get_connected_protocol(app, http_protocol_cls, lifespan=lifespan)
    for _ in range(2):
        conn.data_received(SIMPLE_GET_REQUEST)
        await conn.loop.run_one()
        written = conn.transport.buffer
        assert b"HTTP/1.1 200 OK" in written
        assert b"Hi!" in written

    assert not pending_states  # both expected states were consumed, one per request
async def test_header_upgrade_is_not_websocket_depend_installed(
    caplog: pytest.LogCaptureFixture, http_protocol_cls: type[HTTPProtocol]
):
    """An unsupported upgrade is logged and served as plain HTTP, without the missing-library hint.

    Uses the default ``ws`` setting of ``get_connected_protocol``.
    """
    caplog.set_level(logging.WARNING, logger="uvicorn.error")
    conn = get_connected_protocol(Response("Hello, world", media_type="text/plain"), http_protocol_cls)

    conn.data_received(UPGRADE_REQUEST_ERROR_FIELD)
    await conn.loop.run_one()

    logged = caplog.text
    missing_ws_hint = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually."  # noqa: E501
    assert "Unsupported upgrade request." in logged
    assert missing_ws_hint not in logged

    written = conn.transport.buffer
    assert b"HTTP/1.1 200 OK" in written
    assert b"Hello, world" in written
async def test_header_upgrade_is_websocket_depend_not_installed(
    caplog: pytest.LogCaptureFixture, http_protocol_cls: type[HTTPProtocol]
):
    """With ``ws="none"``, an unsupported upgrade also logs the missing-WebSocket-library hint.

    The request itself is still answered as ordinary HTTP.
    """
    caplog.set_level(logging.WARNING, logger="uvicorn.error")
    plain_app = Response("Hello, world", media_type="text/plain")
    conn = get_connected_protocol(plain_app, http_protocol_cls, ws="none")

    conn.data_received(UPGRADE_REQUEST_ERROR_FIELD)
    await conn.loop.run_one()

    missing_ws_hint = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually."  # noqa: E501
    for log_fragment in ("Unsupported upgrade request.", missing_ws_hint):
        assert log_fragment in caplog.text

    for body_fragment in (b"HTTP/1.1 200 OK", b"Hello, world"):
        assert body_fragment in conn.transport.buffer