Use pytestmark to simplify test suite (#2334)

* Use pytestmark

* Fix linter
This commit is contained in:
Marcelo Trylesinski 2024-05-14 18:34:43 -04:00 committed by GitHub
parent 0efd3835da
commit 14bdf047f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 50 additions and 169 deletions

View File

@ -13,6 +13,7 @@ import websockets.client
from tests.utils import run_server
from uvicorn import Config
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
if typing.TYPE_CHECKING:
import sys
@ -27,9 +28,11 @@ if typing.TYPE_CHECKING:
WSProtocol: TypeAlias = "type[WebSocketProtocol | _WSProtocol]"
pytestmark = pytest.mark.anyio
@contextlib.contextmanager
def caplog_for_logger(caplog, logger_name):
def caplog_for_logger(caplog: pytest.LogCaptureFixture, logger_name: str) -> typing.Iterator[pytest.LogCaptureFixture]:
logger = logging.getLogger(logger_name)
logger.propagate, old_propagate = False, logger.propagate
logger.addHandler(caplog.handler)
@ -40,14 +43,13 @@ def caplog_for_logger(caplog, logger_name):
logger.propagate = old_propagate
async def app(scope, receive, send):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})
@pytest.mark.anyio
async def test_trace_logging(caplog, logging_config, unused_tcp_port: int):
async def test_trace_logging(caplog: pytest.LogCaptureFixture, logging_config, unused_tcp_port: int):
config = Config(
app=app,
log_level="trace",
@ -69,7 +71,6 @@ async def test_trace_logging(caplog, logging_config, unused_tcp_port: int):
assert "ASGI [2] Completed" in messages.pop(0)
@pytest.mark.anyio
async def test_trace_logging_on_http_protocol(http_protocol_cls, caplog, logging_config, unused_tcp_port: int):
config = Config(
app=app,
@ -88,14 +89,13 @@ async def test_trace_logging_on_http_protocol(http_protocol_cls, caplog, logging
assert any(" - HTTP connection lost" in message for message in messages)
@pytest.mark.anyio
async def test_trace_logging_on_ws_protocol(
ws_protocol_cls: WSProtocol,
caplog,
logging_config,
unused_tcp_port: int,
):
async def websocket_app(scope, receive, send):
async def websocket_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "websocket"
while True:
message = await receive()
@ -125,9 +125,8 @@ async def test_trace_logging_on_ws_protocol(
assert any(" - WebSocket connection lost" in message for message in messages)
@pytest.mark.anyio
@pytest.mark.parametrize("use_colors", [(True), (False), (None)])
async def test_access_logging(use_colors, caplog, logging_config, unused_tcp_port: int):
async def test_access_logging(use_colors: bool, caplog: pytest.LogCaptureFixture, logging_config, unused_tcp_port: int):
config = Config(app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port)
with caplog_for_logger(caplog, "uvicorn.access"):
async with run_server(config):
@ -139,9 +138,10 @@ async def test_access_logging(use_colors, caplog, logging_config, unused_tcp_por
assert '"GET / HTTP/1.1" 204' in messages.pop()
@pytest.mark.anyio
@pytest.mark.parametrize("use_colors", [(True), (False)])
async def test_default_logging(use_colors, caplog, logging_config, unused_tcp_port: int):
async def test_default_logging(
use_colors: bool, caplog: pytest.LogCaptureFixture, logging_config, unused_tcp_port: int
):
config = Config(app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port)
with caplog_for_logger(caplog, "uvicorn.access"):
async with run_server(config):
@ -158,9 +158,10 @@ async def test_default_logging(use_colors, caplog, logging_config, unused_tcp_po
assert "Shutting down" in messages.pop(0)
@pytest.mark.anyio
@pytest.mark.skipif(sys.platform == "win32", reason="require unix-like system")
async def test_running_log_using_uds(caplog, short_socket_name, unused_tcp_port: int): # pragma: py-win32
async def test_running_log_using_uds(
caplog: pytest.LogCaptureFixture, short_socket_name: str, unused_tcp_port: int
): # pragma: py-win32
config = Config(app=app, uds=short_socket_name, port=unused_tcp_port)
with caplog_for_logger(caplog, "uvicorn.access"):
async with run_server(config):
@ -170,9 +171,8 @@ async def test_running_log_using_uds(caplog, short_socket_name, unused_tcp_port:
assert f"Uvicorn running on unix socket {short_socket_name} (Press CTRL+C to quit)" in messages
@pytest.mark.anyio
@pytest.mark.skipif(sys.platform == "win32", reason="require unix-like system")
async def test_running_log_using_fd(caplog, unused_tcp_port: int): # pragma: py-win32
async def test_running_log_using_fd(caplog: pytest.LogCaptureFixture, unused_tcp_port: int): # pragma: py-win32
with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
fd = sock.fileno()
config = Config(app=app, fd=fd, port=unused_tcp_port)
@ -184,9 +184,8 @@ async def test_running_log_using_fd(caplog, unused_tcp_port: int): # pragma: py
assert f"Uvicorn running on socket {sockname} (Press CTRL+C to quit)" in messages
@pytest.mark.anyio
async def test_unknown_status_code(caplog, unused_tcp_port: int):
async def app(scope, receive, send):
async def test_unknown_status_code(caplog: pytest.LogCaptureFixture, unused_tcp_port: int):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"
await send({"type": "http.response.start", "status": 599, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})
@ -202,11 +201,10 @@ async def test_unknown_status_code(caplog, unused_tcp_port: int):
assert '"GET / HTTP/1.1" 599' in messages.pop()
@pytest.mark.anyio
async def test_server_start_with_port_zero(caplog: pytest.LogCaptureFixture):
config = Config(app=app, port=0)
async with run_server(config) as server:
server = server.servers[0]
async with run_server(config) as _server:
server = _server.servers[0]
sock = server.sockets[0]
host, port = sock.getsockname()
messages = [record.message for record in caplog.records if "uvicorn" in record.name]

View File

@ -39,6 +39,8 @@ if TYPE_CHECKING:
HTTPProtocol: TypeAlias = "type[HttpToolsProtocol | H11Protocol]"
WSProtocol: TypeAlias = "type[WebSocketProtocol | _WSProtocol]"
pytestmark = pytest.mark.anyio
WEBSOCKET_PROTOCOLS = WS_PROTOCOLS.keys()
@ -239,7 +241,6 @@ def get_connected_protocol(
return protocol
@pytest.mark.anyio
async def test_get_request(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")
@ -250,7 +251,6 @@ async def test_get_request(http_protocol_cls: HTTPProtocol):
assert b"Hello, world" in protocol.transport.buffer
@pytest.mark.anyio
@pytest.mark.parametrize("path", ["/", "/?foo", "/?foo=bar", "/?foo=bar&baz=1"])
async def test_request_logging(path: str, http_protocol_cls: HTTPProtocol, caplog: pytest.LogCaptureFixture):
get_request_with_query_string = b"\r\n".join(
@ -267,7 +267,6 @@ async def test_request_logging(path: str, http_protocol_cls: HTTPProtocol, caplo
assert f'"GET {path} HTTP/1.1" 200' in caplog.records[0].message
@pytest.mark.anyio
async def test_head_request(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")
@ -278,7 +277,6 @@ async def test_head_request(http_protocol_cls: HTTPProtocol):
assert b"Hello, world" not in protocol.transport.buffer
@pytest.mark.anyio
async def test_post_request(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
body = b""
@ -298,7 +296,6 @@ async def test_post_request(http_protocol_cls: HTTPProtocol):
assert b'Body: {"hello": "world"}' in protocol.transport.buffer
@pytest.mark.anyio
async def test_keepalive(http_protocol_cls: HTTPProtocol):
app = Response(b"", status_code=204)
@ -310,7 +307,6 @@ async def test_keepalive(http_protocol_cls: HTTPProtocol):
assert not protocol.transport.is_closing()
@pytest.mark.anyio
async def test_keepalive_timeout(http_protocol_cls: HTTPProtocol):
app = Response(b"", status_code=204)
@ -325,7 +321,6 @@ async def test_keepalive_timeout(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_keepalive_timeout_with_pipelined_requests(
http_protocol_cls: HTTPProtocol,
):
@ -351,7 +346,6 @@ async def test_keepalive_timeout_with_pipelined_requests(
assert protocol.timeout_keep_alive_task is not None
@pytest.mark.anyio
async def test_close(http_protocol_cls: HTTPProtocol):
app = Response(b"", status_code=204, headers={"connection": "close"})
@ -362,7 +356,6 @@ async def test_close(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_chunked_encoding(http_protocol_cls: HTTPProtocol):
app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"})
@ -374,7 +367,6 @@ async def test_chunked_encoding(http_protocol_cls: HTTPProtocol):
assert not protocol.transport.is_closing()
@pytest.mark.anyio
async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol):
app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"})
@ -386,7 +378,6 @@ async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol):
assert not protocol.transport.is_closing()
@pytest.mark.anyio
async def test_chunked_encoding_head_request(
http_protocol_cls: HTTPProtocol,
):
@ -399,7 +390,6 @@ async def test_chunked_encoding_head_request(
assert not protocol.transport.is_closing()
@pytest.mark.anyio
async def test_pipelined_requests(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")
@ -421,7 +411,6 @@ async def test_pipelined_requests(http_protocol_cls: HTTPProtocol):
protocol.transport.clear_buffer()
@pytest.mark.anyio
async def test_undersized_request(http_protocol_cls: HTTPProtocol):
app = Response(b"xxx", headers={"content-length": "10"})
@ -431,7 +420,6 @@ async def test_undersized_request(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_oversized_request(http_protocol_cls: HTTPProtocol):
app = Response(b"xxx" * 20, headers={"content-length": "10"})
@ -441,7 +429,6 @@ async def test_oversized_request(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_large_post_request(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")
@ -452,7 +439,6 @@ async def test_large_post_request(http_protocol_cls: HTTPProtocol):
assert not protocol.transport.read_paused
@pytest.mark.anyio
async def test_invalid_http(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")
@ -461,7 +447,6 @@ async def test_invalid_http(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_app_exception(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
raise Exception()
@ -473,7 +458,6 @@ async def test_app_exception(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_exception_during_response(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
@ -487,7 +471,6 @@ async def test_exception_during_response(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_no_response_returned(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): ...
@ -498,7 +481,6 @@ async def test_no_response_returned(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_partial_response_returned(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
@ -510,7 +492,6 @@ async def test_partial_response_returned(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_duplicate_start_message(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
@ -523,7 +504,6 @@ async def test_duplicate_start_message(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_missing_start_message(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.body", "body": b""})
@ -535,7 +515,6 @@ async def test_missing_start_message(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_message_after_body_complete(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
@ -549,7 +528,6 @@ async def test_message_after_body_complete(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_value_returned(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
@ -563,7 +541,6 @@ async def test_value_returned(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_early_disconnect(http_protocol_cls: HTTPProtocol):
got_disconnect_event = False
@ -585,7 +562,6 @@ async def test_early_disconnect(http_protocol_cls: HTTPProtocol):
assert got_disconnect_event
@pytest.mark.anyio
async def test_early_response(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")
@ -597,7 +573,6 @@ async def test_early_response(http_protocol_cls: HTTPProtocol):
assert not protocol.transport.is_closing()
@pytest.mark.anyio
async def test_read_after_response(http_protocol_cls: HTTPProtocol):
message_after_response = None
@ -615,7 +590,6 @@ async def test_read_after_response(http_protocol_cls: HTTPProtocol):
assert message_after_response == {"type": "http.disconnect"}
@pytest.mark.anyio
async def test_http10_request(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"
@ -630,7 +604,6 @@ async def test_http10_request(http_protocol_cls: HTTPProtocol):
assert b"Version: 1.0" in protocol.transport.buffer
@pytest.mark.anyio
async def test_root_path(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"
@ -646,7 +619,6 @@ async def test_root_path(http_protocol_cls: HTTPProtocol):
assert b"root_path=/app path=/app/" in protocol.transport.buffer
@pytest.mark.anyio
async def test_raw_path(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"
@ -664,7 +636,6 @@ async def test_raw_path(http_protocol_cls: HTTPProtocol):
assert b"Done" in protocol.transport.buffer
@pytest.mark.anyio
async def test_max_concurrency(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")
@ -686,7 +657,6 @@ async def test_max_concurrency(http_protocol_cls: HTTPProtocol):
)
@pytest.mark.anyio
async def test_shutdown_during_request(http_protocol_cls: HTTPProtocol):
app = Response(b"", status_code=204)
@ -698,7 +668,6 @@ async def test_shutdown_during_request(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_shutdown_during_idle(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")
@ -708,7 +677,6 @@ async def test_shutdown_during_idle(http_protocol_cls: HTTPProtocol):
assert protocol.transport.is_closing()
@pytest.mark.anyio
async def test_100_continue_sent_when_body_consumed(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
body = b""
@ -740,7 +708,6 @@ async def test_100_continue_sent_when_body_consumed(http_protocol_cls: HTTPProto
assert b'Body: {"hello": "world"}' in protocol.transport.buffer
@pytest.mark.anyio
async def test_100_continue_not_sent_when_body_not_consumed(
http_protocol_cls: HTTPProtocol,
):
@ -764,7 +731,6 @@ async def test_100_continue_not_sent_when_body_not_consumed(
assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer
@pytest.mark.anyio
async def test_supported_upgrade_request(http_protocol_cls: HTTPProtocol):
pytest.importorskip("wsproto")
@ -775,7 +741,6 @@ async def test_supported_upgrade_request(http_protocol_cls: HTTPProtocol):
assert b"HTTP/1.1 426 " in protocol.transport.buffer
@pytest.mark.anyio
async def test_unsupported_ws_upgrade_request(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")
@ -786,7 +751,6 @@ async def test_unsupported_ws_upgrade_request(http_protocol_cls: HTTPProtocol):
assert b"Hello, world" in protocol.transport.buffer
@pytest.mark.anyio
async def test_unsupported_ws_upgrade_request_warn_on_auto(
caplog: pytest.LogCaptureFixture, http_protocol_cls: HTTPProtocol
):
@ -804,7 +768,6 @@ async def test_unsupported_ws_upgrade_request_warn_on_auto(
assert msg in warnings
@pytest.mark.anyio
async def test_http2_upgrade_request(http_protocol_cls: HTTPProtocol, ws_protocol_cls: WSProtocol):
app = Response("Hello, world", media_type="text/plain")
@ -826,7 +789,6 @@ def asgi2app(scope: Scope):
return asgi
@pytest.mark.anyio
@pytest.mark.parametrize(
"asgi2or3_app, expected_scopes",
[
@ -845,7 +807,6 @@ async def test_scopes(
assert expected_scopes == protocol.scope.get("asgi")
@pytest.mark.anyio
@pytest.mark.parametrize(
"request_line",
[
@ -915,7 +876,6 @@ def test_fragmentation(unused_tcp_port: int):
t.join()
@pytest.mark.anyio
async def test_huge_headers_h11protocol_failure():
app = Response("Hello, world", media_type="text/plain")
@ -928,7 +888,6 @@ async def test_huge_headers_h11protocol_failure():
assert b"Invalid HTTP request received." in protocol.transport.buffer
@pytest.mark.anyio
@skip_if_no_httptools
async def test_huge_headers_httptools_will_pass():
app = Response("Hello, world", media_type="text/plain")
@ -943,7 +902,6 @@ async def test_huge_headers_httptools_will_pass():
assert b"Hello, world" in protocol.transport.buffer
@pytest.mark.anyio
async def test_huge_headers_h11protocol_failure_with_setting():
app = Response("Hello, world", media_type="text/plain")
@ -956,7 +914,6 @@ async def test_huge_headers_h11protocol_failure_with_setting():
assert b"Invalid HTTP request received." in protocol.transport.buffer
@pytest.mark.anyio
@skip_if_no_httptools
async def test_huge_headers_httptools():
app = Response("Hello, world", media_type="text/plain")
@ -971,7 +928,6 @@ async def test_huge_headers_httptools():
assert b"Hello, world" in protocol.transport.buffer
@pytest.mark.anyio
async def test_huge_headers_h11_max_incomplete():
app = Response("Hello, world", media_type="text/plain")
@ -983,7 +939,6 @@ async def test_huge_headers_h11_max_incomplete():
assert b"Hello, world" in protocol.transport.buffer
@pytest.mark.anyio
async def test_return_close_header(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")
@ -998,7 +953,6 @@ async def test_return_close_header(http_protocol_cls: HTTPProtocol):
assert b"connection: close" in protocol.transport.buffer.lower()
@pytest.mark.anyio
async def test_iterator_headers(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
headers = iter([(b"x-test-header", b"test value")])
@ -1011,7 +965,6 @@ async def test_iterator_headers(http_protocol_cls: HTTPProtocol):
assert b"x-test-header: test value" in protocol.transport.buffer
@pytest.mark.anyio
async def test_lifespan_state(http_protocol_cls: HTTPProtocol):
expected_states = [{"a": 123, "b": [1]}, {"a": 123, "b": [1, 2]}]

View File

@ -48,6 +48,8 @@ if typing.TYPE_CHECKING:
HTTPProtocol: TypeAlias = "type[H11Protocol | HttpToolsProtocol]"
WSProtocol: TypeAlias = "type[_WSProtocol | WebSocketProtocol]"
pytestmark = pytest.mark.anyio
class WebSocketResponse:
def __init__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
@ -84,7 +86,6 @@ async def wsresponse(url):
return await client.get(url, headers=headers)
@pytest.mark.anyio
async def test_invalid_upgrade(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
def app(scope: Scope):
return None
@ -117,7 +118,6 @@ async def test_invalid_upgrade(ws_protocol_cls: WSProtocol, http_protocol_cls: H
)
@pytest.mark.anyio
async def test_accept_connection(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
class App(WebSocketResponse):
async def websocket_connect(self, message):
@ -139,7 +139,6 @@ async def test_accept_connection(ws_protocol_cls: WSProtocol, http_protocol_cls:
assert is_open
@pytest.mark.anyio
async def test_shutdown(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
class App(WebSocketResponse):
async def websocket_connect(self, message):
@ -158,7 +157,6 @@ async def test_shutdown(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProt
await server.shutdown()
@pytest.mark.anyio
async def test_supports_permessage_deflate_extension(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -183,7 +181,6 @@ async def test_supports_permessage_deflate_extension(
assert "permessage-deflate" in extension_names
@pytest.mark.anyio
async def test_can_disable_permessage_deflate_extension(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -211,7 +208,6 @@ async def test_can_disable_permessage_deflate_extension(
assert "permessage-deflate" not in extension_names
@pytest.mark.anyio
async def test_close_connection(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
class App(WebSocketResponse):
async def websocket_connect(self, message):
@ -236,7 +232,6 @@ async def test_close_connection(ws_protocol_cls: WSProtocol, http_protocol_cls:
assert not is_open
@pytest.mark.anyio
async def test_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
class App(WebSocketResponse):
async def websocket_connect(self, message):
@ -262,7 +257,6 @@ async def test_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProto
assert is_open
@pytest.mark.anyio
async def test_extra_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
class App(WebSocketResponse):
async def websocket_connect(self, message):
@ -284,7 +278,6 @@ async def test_extra_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTT
assert extra_headers.get("extra") == "header"
@pytest.mark.anyio
async def test_path_and_raw_path(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
class App(WebSocketResponse):
async def websocket_connect(self, message):
@ -310,7 +303,6 @@ async def test_path_and_raw_path(ws_protocol_cls: WSProtocol, http_protocol_cls:
assert is_open
@pytest.mark.anyio
async def test_send_text_data_to_client(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -335,7 +327,6 @@ async def test_send_text_data_to_client(
assert data == "123"
@pytest.mark.anyio
async def test_send_binary_data_to_client(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -360,7 +351,6 @@ async def test_send_binary_data_to_client(
assert data == b"123"
@pytest.mark.anyio
async def test_send_and_close_connection(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -393,7 +383,6 @@ async def test_send_and_close_connection(
assert not is_open
@pytest.mark.anyio
async def test_send_text_data_to_server(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -422,7 +411,6 @@ async def test_send_text_data_to_server(
assert data == "abc"
@pytest.mark.anyio
async def test_send_binary_data_to_server(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -451,7 +439,6 @@ async def test_send_binary_data_to_server(
assert data == b"abc"
@pytest.mark.anyio
async def test_send_after_protocol_close(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -486,7 +473,6 @@ async def test_send_after_protocol_close(
assert not is_open
@pytest.mark.anyio
async def test_missing_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
pass
@ -507,7 +493,6 @@ async def test_missing_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls:
assert exc_info.value.status_code == 500
@pytest.mark.anyio
async def test_send_before_handshake(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -530,7 +515,6 @@ async def test_send_before_handshake(
assert exc_info.value.status_code == 500
@pytest.mark.anyio
async def test_duplicate_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "websocket.accept"})
@ -553,7 +537,6 @@ async def test_duplicate_handshake(ws_protocol_cls: WSProtocol, http_protocol_cl
assert exc_info.value.code == 1006
@pytest.mark.anyio
async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
"""
The ASGI callable should return 'None'. If it doesn't, make sure that
@ -581,7 +564,6 @@ async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls:
assert exc_info.value.code == 1006
@pytest.mark.anyio
@pytest.mark.parametrize("code", [None, 1000, 1001])
@pytest.mark.parametrize(
"reason",
@ -633,7 +615,6 @@ async def test_app_close(
assert exc_info.value.reason == (reason or "")
@pytest.mark.anyio
async def test_client_close(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
while True:
@ -661,7 +642,6 @@ async def test_client_close(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTP
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
@pytest.mark.anyio
async def test_client_connection_lost(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -695,7 +675,6 @@ async def test_client_connection_lost(
assert got_disconnect_event_before_shutdown is True
@pytest.mark.anyio
async def test_client_connection_lost_on_send(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -729,7 +708,6 @@ async def test_client_connection_lost_on_send(
assert got_disconnect_event is True
@pytest.mark.anyio
async def test_connection_lost_before_handshake_complete(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -778,7 +756,6 @@ async def test_connection_lost_before_handshake_complete(
await task
@pytest.mark.anyio
async def test_send_close_on_server_shutdown(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -823,7 +800,6 @@ async def test_send_close_on_server_shutdown(
task.cancel()
@pytest.mark.anyio
@pytest.mark.parametrize("subprotocol", ["proto1", "proto2"])
async def test_subprotocols(
ws_protocol_cls: WSProtocol,
@ -857,7 +833,6 @@ MAX_WS_BYTES = 1024 * 1024 * 16
MAX_WS_BYTES_PLUS1 = MAX_WS_BYTES + 1
@pytest.mark.anyio
@pytest.mark.parametrize(
"client_size_sent, server_size_max, expected_result",
[
@ -911,7 +886,6 @@ async def test_send_binary_data_to_server_bigger_than_default_on_websockets(
assert e.value.code == expected_result
@pytest.mark.anyio
async def test_server_reject_connection(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -952,7 +926,6 @@ async def test_server_reject_connection(
assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}
@pytest.mark.anyio
async def test_server_reject_connection_with_response(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -990,7 +963,6 @@ async def test_server_reject_connection_with_response(
assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}
@pytest.mark.anyio
async def test_server_reject_connection_with_multibody_response(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -1043,7 +1015,6 @@ async def test_server_reject_connection_with_multibody_response(
assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}
@pytest.mark.anyio
async def test_server_reject_connection_with_invalid_status(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -1085,7 +1056,6 @@ async def test_server_reject_connection_with_invalid_status(
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
@pytest.mark.anyio
async def test_server_reject_connection_with_body_nolength(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -1130,7 +1100,6 @@ async def test_server_reject_connection_with_body_nolength(
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
@pytest.mark.anyio
async def test_server_reject_connection_with_invalid_msg(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -1168,7 +1137,6 @@ async def test_server_reject_connection_with_invalid_msg(
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
@pytest.mark.anyio
async def test_server_reject_connection_with_missing_body(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -1205,7 +1173,6 @@ async def test_server_reject_connection_with_missing_body(
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
@pytest.mark.anyio
async def test_server_multiple_websocket_http_response_start_events(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -1257,7 +1224,6 @@ async def test_server_multiple_websocket_http_response_start_events(
)
@pytest.mark.anyio
async def test_server_can_read_messages_in_buffer_after_close(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -1299,7 +1265,6 @@ async def test_server_can_read_messages_in_buffer_after_close(
assert disconnect_message == {"type": "websocket.disconnect", "code": 1000}
@pytest.mark.anyio
async def test_default_server_headers(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -1323,7 +1288,6 @@ async def test_default_server_headers(
assert headers.get("server") == "uvicorn" and "date" in headers
@pytest.mark.anyio
async def test_no_server_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
class App(WebSocketResponse):
async def websocket_connect(self, message):
@ -1346,7 +1310,6 @@ async def test_no_server_headers(ws_protocol_cls: WSProtocol, http_protocol_cls:
assert "server" not in headers
@pytest.mark.anyio
@skip_if_no_wsproto
async def test_no_date_header_on_wsproto(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
class App(WebSocketResponse):
@ -1370,7 +1333,6 @@ async def test_no_date_header_on_wsproto(http_protocol_cls: HTTPProtocol, unused
assert "date" not in headers
@pytest.mark.anyio
async def test_multiple_server_header(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
@ -1402,7 +1364,6 @@ async def test_multiple_server_header(
assert headers.get_all("Server") == ["uvicorn", "over-ridden", "another-value"]
@pytest.mark.anyio
async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
expected_states = [
{"a": 123, "b": [1]},

View File

@ -32,10 +32,6 @@ async def app(scope, receive, send):
pass # pragma: no cover
# TODO: Add pypy to our testing matrix, and assert we get the correct classes
# dependent on the platform we're running the tests under.
def test_loop_auto():
auto_loop_setup()
policy = asyncio.get_event_loop_policy()

View File

@ -1,17 +1,21 @@
from __future__ import annotations
import httpx
import pytest
from tests.utils import run_server
from uvicorn import Config
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
pytestmark = pytest.mark.anyio
async def app(scope, receive, send):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
assert scope["type"] == "http"
await send({"type": "http.response.start", "status": 200, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})
@pytest.mark.anyio
async def test_default_default_headers(unused_tcp_port: int):
config = Config(app=app, loop="asyncio", limit_max_requests=1, port=unused_tcp_port)
async with run_server(config):
@ -20,79 +24,45 @@ async def test_default_default_headers(unused_tcp_port: int):
assert response.headers["server"] == "uvicorn" and response.headers["date"]
@pytest.mark.anyio
async def test_override_server_header(unused_tcp_port: int):
config = Config(
app=app,
loop="asyncio",
limit_max_requests=1,
headers=[("Server", "over-ridden")],
port=unused_tcp_port,
)
headers: list[tuple[str, str]] = [("Server", "over-ridden")]
config = Config(app=app, loop="asyncio", limit_max_requests=1, headers=headers, port=unused_tcp_port)
async with run_server(config):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")
assert response.headers["server"] == "over-ridden" and response.headers["date"]
@pytest.mark.anyio
async def test_disable_default_server_header(unused_tcp_port: int):
config = Config(
app=app,
loop="asyncio",
limit_max_requests=1,
server_header=False,
port=unused_tcp_port,
)
config = Config(app=app, loop="asyncio", limit_max_requests=1, server_header=False, port=unused_tcp_port)
async with run_server(config):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")
assert "server" not in response.headers
@pytest.mark.anyio
async def test_override_server_header_multiple_times(unused_tcp_port: int):
config = Config(
app=app,
loop="asyncio",
limit_max_requests=1,
headers=[("Server", "over-ridden"), ("Server", "another-value")],
port=unused_tcp_port,
)
headers: list[tuple[str, str]] = [("Server", "over-ridden"), ("Server", "another-value")]
config = Config(app=app, loop="asyncio", limit_max_requests=1, headers=headers, port=unused_tcp_port)
async with run_server(config):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")
assert response.headers["server"] == "over-ridden, another-value" and response.headers["date"]
@pytest.mark.anyio
async def test_add_additional_header(unused_tcp_port: int):
config = Config(
app=app,
loop="asyncio",
limit_max_requests=1,
headers=[("X-Additional", "new-value")],
port=unused_tcp_port,
)
headers: list[tuple[str, str]] = [("X-Additional", "new-value")]
config = Config(app=app, loop="asyncio", limit_max_requests=1, headers=headers, port=unused_tcp_port)
async with run_server(config):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")
assert (
response.headers["x-additional"] == "new-value"
and response.headers["server"] == "uvicorn"
and response.headers["date"]
)
assert response.headers["x-additional"] == "new-value"
assert response.headers["server"] == "uvicorn"
assert response.headers["date"]
@pytest.mark.anyio
async def test_disable_default_date_header(unused_tcp_port: int):
config = Config(
app=app,
loop="asyncio",
limit_max_requests=1,
date_header=False,
port=unused_tcp_port,
)
config = Config(app=app, loop="asyncio", limit_max_requests=1, date_header=False, port=unused_tcp_port)
async with run_server(config):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{unused_tcp_port}")

View File

@ -7,17 +7,20 @@ import pytest
from tests.utils import run_server
from uvicorn import Server
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
from uvicorn.config import Config
from uvicorn.main import run
pytestmark = pytest.mark.anyio
async def app(scope, receive, send):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
assert scope["type"] == "http"
await send({"type": "http.response.start", "status": 204, "headers": []})
await send({"type": "http.response.body", "body": b"", "more_body": False})
def _has_ipv6(host):
def _has_ipv6(host: str):
sock = None
has_ipv6 = False
if socket.has_ipv6:
@ -32,7 +35,6 @@ def _has_ipv6(host):
return has_ipv6
@pytest.mark.anyio
@pytest.mark.parametrize(
"host, url",
[
@ -54,7 +56,6 @@ async def test_run(host, url: str, unused_tcp_port: int):
assert response.status_code == 204
@pytest.mark.anyio
async def test_run_multiprocess(unused_tcp_port: int):
config = Config(app=app, loop="asyncio", workers=2, limit_max_requests=1, port=unused_tcp_port)
async with run_server(config):
@ -63,7 +64,6 @@ async def test_run_multiprocess(unused_tcp_port: int):
assert response.status_code == 204
@pytest.mark.anyio
async def test_run_reload(unused_tcp_port: int):
config = Config(app=app, loop="asyncio", reload=True, limit_max_requests=1, port=unused_tcp_port)
async with run_server(config):
@ -107,7 +107,6 @@ def test_run_match_config_params() -> None:
assert config_params == run_params
@pytest.mark.anyio
async def test_exit_on_create_server_with_invalid_host() -> None:
with pytest.raises(SystemExit) as exc_info:
config = Config(app=app, host="illegal_host")

View File

@ -1,13 +1,17 @@
from __future__ import annotations
import asyncio
import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager, contextmanager
from pathlib import Path
from socket import socket
from uvicorn import Config, Server
@asynccontextmanager
async def run_server(config: Config, sockets=None):
async def run_server(config: Config, sockets: list[socket] | None = None) -> AsyncIterator[Server]:
server = Server(config=config)
task = asyncio.create_task(server.serve(sockets=sockets))
await asyncio.sleep(0.1)