208 lines
7.1 KiB
Python
208 lines
7.1 KiB
Python
from __future__ import annotations
|
|
|
|
import ssl
|
|
from collections.abc import Callable
|
|
from typing import TypeAlias
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from tests.utils import run_server
|
|
from uvicorn.config import Config
|
|
|
|
# Type of the zero-argument callable that the factories in this module receive as
# their `default_ssl_context_factory` parameter; calling it yields an `ssl.SSLContext`.
DefaultFactory: TypeAlias = Callable[[], ssl.SSLContext]
|
|
|
|
|
|
async def app(scope, receive, send):
    """Minimal ASGI callable: check the scope is HTTP, then send an empty 204 response.

    `receive` is accepted but never awaited. Any scope whose ``"type"`` is not
    ``"http"`` trips the assertion before anything is sent.
    """
    assert scope["type"] == "http"
    start_message = {"type": "http.response.start", "status": 204, "headers": []}
    body_message = {"type": "http.response.body", "body": b"", "more_body": False}
    # Response start must go out before the (empty, final) body message.
    for message in (start_message, body_message):
        await send(message)
|
|
|
|
|
|
@pytest.mark.anyio
async def test_run(
    tls_ca_ssl_context,
    tls_certificate_server_cert_path,
    tls_certificate_private_key_path,
    tls_ca_certificate_pem_path,
    unused_tcp_port: int,
):
    """Run the server with `ssl_keyfile`, `ssl_certfile` and `ssl_ca_certs` set and expect a 204 over HTTPS."""
    url = f"https://127.0.0.1:{unused_tcp_port}"
    server_config = Config(
        app=app,
        port=unused_tcp_port,
        loop="asyncio",
        limit_max_requests=1,
        ssl_certfile=tls_certificate_server_cert_path,
        ssl_keyfile=tls_certificate_private_key_path,
        ssl_ca_certs=tls_ca_certificate_pem_path,
    )
    # The client verifies the server certificate against the fixture-provided CA context.
    async with run_server(server_config), httpx.AsyncClient(verify=tls_ca_ssl_context) as client:
        reply = await client.get(url)
    assert reply.status_code == 204
|
|
|
|
|
|
@pytest.mark.anyio
async def test_run_chain(
    tls_ca_ssl_context,
    tls_certificate_key_and_chain_path,
    tls_ca_certificate_pem_path,
    unused_tcp_port: int,
):
    """Run the server with only `ssl_certfile` (the key-and-chain fixture) plus `ssl_ca_certs`; expect a 204."""
    url = f"https://127.0.0.1:{unused_tcp_port}"
    server_config = Config(
        app=app,
        port=unused_tcp_port,
        loop="asyncio",
        limit_max_requests=1,
        # No `ssl_keyfile` here: the single fixture path is passed as the certfile.
        ssl_certfile=tls_certificate_key_and_chain_path,
        ssl_ca_certs=tls_ca_certificate_pem_path,
    )
    async with run_server(server_config), httpx.AsyncClient(verify=tls_ca_ssl_context) as client:
        reply = await client.get(url)
    assert reply.status_code == 204
|
|
|
|
|
|
@pytest.mark.anyio
async def test_run_chain_only(tls_ca_ssl_context, tls_certificate_key_and_chain_path, unused_tcp_port: int):
    """Run the server with `ssl_certfile` as the only SSL option (no key file, no CA certs); expect a 204."""
    url = f"https://127.0.0.1:{unused_tcp_port}"
    server_config = Config(
        app=app,
        port=unused_tcp_port,
        loop="asyncio",
        limit_max_requests=1,
        ssl_certfile=tls_certificate_key_and_chain_path,
    )
    async with run_server(server_config), httpx.AsyncClient(verify=tls_ca_ssl_context) as client:
        reply = await client.get(url)
    assert reply.status_code == 204
|
|
|
|
|
|
@pytest.mark.anyio
async def test_run_password(
    tls_ca_ssl_context,
    tls_certificate_server_cert_path,
    tls_ca_certificate_pem_path,
    tls_certificate_private_key_encrypted_path,
    unused_tcp_port: int,
):
    """Run the server with the encrypted-key fixture and `ssl_keyfile_password`; expect a 204 over HTTPS."""
    url = f"https://127.0.0.1:{unused_tcp_port}"
    server_config = Config(
        app=app,
        port=unused_tcp_port,
        loop="asyncio",
        limit_max_requests=1,
        ssl_certfile=tls_certificate_server_cert_path,
        ssl_keyfile=tls_certificate_private_key_encrypted_path,
        # NOTE(review): presumably must match the passphrase the fixture used to encrypt the key — confirm in conftest.
        ssl_keyfile_password="uvicorn password for the win",
        ssl_ca_certs=tls_ca_certificate_pem_path,
    )
    async with run_server(server_config), httpx.AsyncClient(verify=tls_ca_ssl_context) as client:
        reply = await client.get(url)
    assert reply.status_code == 204
|
|
|
|
|
|
@pytest.mark.anyio
async def test_run_ssl_context_factory_default(
    tls_ca_ssl_context: ssl.SSLContext,
    tls_certificate_server_cert_path: str,
    tls_certificate_private_key_path: str,
    unused_tcp_port: int,
) -> None:
    """A pass-through factory that returns the default factory's context must still give a working HTTPS server."""

    def use_default(config: Config, default_ssl_context_factory: DefaultFactory) -> ssl.SSLContext:
        # Pure delegation: no customisation on top of the default context.
        return default_ssl_context_factory()

    url = f"https://127.0.0.1:{unused_tcp_port}"
    server_config = Config(
        app=app,
        port=unused_tcp_port,
        loop="asyncio",
        limit_max_requests=1,
        ssl_certfile=tls_certificate_server_cert_path,
        ssl_keyfile=tls_certificate_private_key_path,
        ssl_context_factory=use_default,
    )
    async with run_server(server_config), httpx.AsyncClient(verify=tls_ca_ssl_context) as client:
        reply = await client.get(url)
    assert reply.status_code == 204
|
|
|
|
|
|
@pytest.mark.anyio
async def test_run_ssl_context_factory_custom(
    tls_ca_ssl_context: ssl.SSLContext,
    tls_certificate_server_cert_path: str,
    tls_certificate_private_key_path: str,
    unused_tcp_port: int,
) -> None:
    """A factory that constructs its own `SSLContext` must work with no `ssl_keyfile`/`ssl_certfile` on the Config."""

    def build_context(config: Config, default_ssl_context_factory: DefaultFactory) -> ssl.SSLContext:
        # Ignores both arguments and loads the certificate/key pair directly from the fixtures.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(
            certfile=tls_certificate_server_cert_path,
            keyfile=tls_certificate_private_key_path,
        )
        return context

    url = f"https://127.0.0.1:{unused_tcp_port}"
    server_config = Config(
        app=app,
        port=unused_tcp_port,
        loop="asyncio",
        limit_max_requests=1,
        ssl_context_factory=build_context,
    )
    async with run_server(server_config), httpx.AsyncClient(verify=tls_ca_ssl_context) as client:
        reply = await client.get(url)
    assert reply.status_code == 204
|
|
|
|
|
|
def test_ssl_context_factory_mutates_default(
    tls_certificate_server_cert_path: str,
    tls_certificate_private_key_path: str,
) -> None:
    """A factory may take the default context and tweak it; here it raises the minimum TLS version to 1.3."""

    def require_tls13(config: Config, default_ssl_context_factory: DefaultFactory) -> ssl.SSLContext:
        context = default_ssl_context_factory()
        context.minimum_version = ssl.TLSVersion.TLSv1_3
        return context

    server_config = Config(
        app=app,
        ssl_certfile=tls_certificate_server_cert_path,
        ssl_keyfile=tls_certificate_private_key_path,
        ssl_context_factory=require_tls13,
    )
    server_config.load()

    # The loaded context must be the mutated one, not a fresh default.
    assert server_config.is_ssl
    assert isinstance(server_config.ssl, ssl.SSLContext)
    assert server_config.ssl.minimum_version == ssl.TLSVersion.TLSv1_3
|
|
|
|
|
|
def test_default_ssl_context_factory_requires_ssl_certfile() -> None:
    """Invoking the default factory when no `ssl_certfile` is configured must raise a descriptive `RuntimeError`."""

    def call_default(config: Config, default_ssl_context_factory: DefaultFactory) -> ssl.SSLContext:
        return default_ssl_context_factory()

    # Deliberately no ssl_certfile / ssl_keyfile on the Config.
    server_config = Config(ssl_context_factory=call_default, app=app)
    with pytest.raises(RuntimeError, match="requires `ssl_certfile`"):
        server_config.load()
|
|
|
|
|
|
def test_ssl_context_factory_must_return_ssl_context() -> None:
    """`Config.load()` must raise `TypeError` when the factory returns something that is not an `ssl.SSLContext`."""

    def returns_string(config: Config, default_ssl_context_factory: DefaultFactory) -> object:
        return "not an SSLContext"

    # The factory's return annotation is intentionally wrong for Config, hence the type-ignore.
    server_config = Config(ssl_context_factory=returns_string, app=app)  # type: ignore[arg-type]
    with pytest.raises(TypeError, match="must return an `ssl.SSLContext`"):
        server_config.load()
|
|
|
|
|
|
def test_is_ssl_true_when_only_factory_set() -> None:
    """`Config.is_ssl` must be `True` when `ssl_context_factory` is the only SSL-related option given."""

    def make_context(config: Config, default_ssl_context_factory: DefaultFactory) -> ssl.SSLContext:
        # Never invoked: the test does not call `load()`, so the body is excluded from coverage.
        return ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)  # pragma: no cover

    server_config = Config(ssl_context_factory=make_context, app=app)
    assert server_config.is_ssl is True
|