306 lines
8.9 KiB
Python
306 lines
8.9 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import threading
|
|
import time
|
|
import typing
|
|
|
|
import pytest
|
|
import trustme
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives.serialization import (
|
|
BestAvailableEncryption,
|
|
Encoding,
|
|
PrivateFormat,
|
|
load_pem_private_key,
|
|
)
|
|
from uvicorn.config import Config
|
|
from uvicorn.server import Server
|
|
|
|
from httpx import URL
|
|
from tests.concurrency import sleep
|
|
|
|
if typing.TYPE_CHECKING: # pragma: no cover
|
|
from httpx._transports.asgi import _Receive, _Send
|
|
|
|
# Environment variables that can influence TLS trust, proxy selection or TLS
# key logging. The autouse ``clean_environ`` fixture removes them from
# ``os.environ`` for each test so the host environment cannot leak in.
ENVIRONMENT_VARIABLES = {
    # TLS trust-store overrides.
    "SSL_CERT_FILE",
    "SSL_CERT_DIR",
    # Proxy configuration.
    "HTTP_PROXY",
    "HTTPS_PROXY",
    "ALL_PROXY",
    "NO_PROXY",
    # TLS session key logging.
    "SSLKEYLOGFILE",
}
|
|
|
|
|
|
@pytest.fixture(scope="function", autouse=True)
def clean_environ():
    """Keep ``os.environ`` clean for every test without having to mock it.

    Before each test, every variable listed in ``ENVIRONMENT_VARIABLES`` is
    removed from ``os.environ``. The match is case-insensitive, so both
    ``HTTP_PROXY`` and ``http_proxy`` are dropped. After the test the original
    environment is restored in full.
    """
    original_environ = os.environ.copy()
    os.environ.clear()
    os.environ.update(
        {
            k: v
            for k, v in original_environ.items()
            # ENVIRONMENT_VARIABLES holds upper-case names only, so normalise
            # the key with .upper(); this also catches lower-case spellings
            # such as ``http_proxy``, which a .lower() check can never match.
            if k.upper() not in ENVIRONMENT_VARIABLES
        }
    )
    yield
    os.environ.clear()
    os.environ.update(original_environ)
|
|
|
|
|
|
_Scope = typing.Dict[str, typing.Any]
|
|
|
|
|
|
async def app(scope: _Scope, receive: "_Receive", send: "_Send") -> None:
    """ASGI entry point for the test server.

    Routes on the request path prefix. Prefixes are tried in order, the first
    match wins, and any unmatched path falls back to ``hello_world``. Only
    HTTP scopes are accepted.
    """
    assert scope["type"] == "http"

    # Ordered (prefix, handler) pairs; the order mirrors the matching priority.
    routes = (
        ("/slow_response", slow_response),
        ("/status", status_code),
        ("/echo_body", echo_body),
        ("/echo_binary", echo_binary),
        ("/echo_headers", echo_headers),
        ("/redirect_301", redirect_301),
        ("/json", hello_world_json),
    )
    path = scope["path"]
    handler = next(
        (endpoint for prefix, endpoint in routes if path.startswith(prefix)),
        hello_world,
    )
    await handler(scope, receive, send)
|
|
|
|
|
|
async def hello_world(scope: _Scope, receive: "_Receive", send: "_Send") -> None:
    """Respond 200 with the plain-text body ``Hello, world!``."""
    start_message = {
        "type": "http.response.start",
        "status": 200,
        "headers": [[b"content-type", b"text/plain"]],
    }
    body_message = {"type": "http.response.body", "body": b"Hello, world!"}
    await send(start_message)
    await send(body_message)
|
|
|
|
|
|
async def hello_world_json(scope: _Scope, receive: "_Receive", send: "_Send") -> None:
    """Respond 200 with the JSON document ``{"Hello": "world!"}``."""
    json_headers = [[b"content-type", b"application/json"]]
    await send({"type": "http.response.start", "status": 200, "headers": json_headers})
    payload = b'{"Hello": "world!"}'
    await send({"type": "http.response.body", "body": payload})
|
|
|
|
|
|
async def slow_response(scope: _Scope, receive: "_Receive", send: "_Send") -> None:
    """Send 200 headers, pause for one second, then send the body.

    The pause between the start and body messages lets tests trigger a
    read timeout on the client side.
    """
    text_headers = [[b"content-type", b"text/plain"]]
    await send({"type": "http.response.start", "status": 200, "headers": text_headers})
    await sleep(1.0)
    await send({"type": "http.response.body", "body": b"Hello, world!"})
|
|
|
|
|
|
async def status_code(scope: _Scope, receive: "_Receive", send: "_Send") -> None:
    """Respond with the status code taken from a ``/status/<code>`` path.

    The code is parsed with ``int()``, so a non-numeric remainder raises
    ``ValueError``.
    """
    # Named ``code`` so the local does not shadow this function's own name.
    raw_code = scope["path"].replace("/status/", "")
    code = int(raw_code)
    await send(
        {
            "type": "http.response.start",
            "status": code,
            "headers": [[b"content-type", b"text/plain"]],
        }
    )
    await send({"type": "http.response.body", "body": b"Hello, world!"})
|
|
|
|
|
|
async def echo_body(scope: _Scope, receive: "_Receive", send: "_Send") -> None:
    """Respond 200 (text/plain) with the complete request body echoed back.

    Request messages are consumed until one arrives without a truthy
    ``more_body`` flag. Missing ``body`` keys count as empty.
    """
    chunks = []
    while True:
        message = await receive()
        chunks.append(message.get("body", b""))
        if not message.get("more_body", False):
            break
    echoed = b"".join(chunks)

    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [[b"content-type", b"text/plain"]],
        }
    )
    await send({"type": "http.response.body", "body": echoed})
|
|
|
|
|
|
async def echo_binary(scope: _Scope, receive: "_Receive", send: "_Send") -> None:
    """Respond 200 (application/octet-stream) with the request body echoed back.

    Request messages are consumed until one arrives without a truthy
    ``more_body`` flag. Missing ``body`` keys count as empty.
    """
    chunks = []
    while True:
        message = await receive()
        chunks.append(message.get("body", b""))
        if not message.get("more_body", False):
            break
    echoed = b"".join(chunks)

    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [[b"content-type", b"application/octet-stream"]],
        }
    )
    await send({"type": "http.response.body", "body": echoed})
|
|
|
|
|
|
async def echo_headers(scope: _Scope, receive: "_Receive", send: "_Send") -> None:
    """Respond 200 with a JSON object of the request headers.

    Header names are passed through ``bytes.capitalize()`` and decoded, and
    values are decoded. When a name repeats, the last value wins.
    """
    headers = {}
    for raw_name, raw_value in scope.get("headers", []):
        headers[raw_name.capitalize().decode()] = raw_value.decode()
    payload = json.dumps(headers).encode()

    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [[b"content-type", b"application/json"]],
        }
    )
    await send({"type": "http.response.body", "body": payload})
|
|
|
|
|
|
async def redirect_301(scope: _Scope, receive: "_Receive", send: "_Send") -> None:
    """Respond 301 with ``Location: /`` and a body message carrying no ``body`` key."""
    redirect_headers = [[b"location", b"/"]]
    await send({"type": "http.response.start", "status": 301, "headers": redirect_headers})
    await send({"type": "http.response.body"})
|
|
|
|
|
|
@pytest.fixture(scope="session")
def cert_authority():
    """Session-wide throwaway certificate authority created by ``trustme``."""
    authority = trustme.CA()
    return authority
|
|
|
|
|
|
@pytest.fixture(scope="session")
def ca_cert_pem_file(cert_authority):
    """Yield the temporary-file path produced by the CA's ``cert_pem.tempfile()``.

    The ``tempfile()`` context stays open until this session fixture is torn down.
    """
    ca_pem = cert_authority.cert_pem
    with ca_pem.tempfile() as ca_path:
        yield ca_path
|
|
|
|
|
|
@pytest.fixture(scope="session")
def localhost_cert(cert_authority):
    """Session-wide leaf certificate for ``localhost`` issued by the test CA."""
    hostname = "localhost"
    return cert_authority.issue_cert(hostname)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def cert_pem_file(localhost_cert):
    """Yield a temporary-file path holding the first PEM of the localhost cert chain.

    The ``tempfile()`` context stays open until this session fixture is torn down.
    """
    leaf_pem = localhost_cert.cert_chain_pems[0]
    with leaf_pem.tempfile() as leaf_path:
        yield leaf_path
|
|
|
|
|
|
@pytest.fixture(scope="session")
def cert_private_key_file(localhost_cert):
    """Yield a temporary-file path holding the localhost cert's unencrypted private key.

    The ``tempfile()`` context stays open until this session fixture is torn down.
    """
    key_pem = localhost_cert.private_key_pem
    with key_pem.tempfile() as key_path:
        yield key_path
|
|
|
|
|
|
@pytest.fixture(scope="session")
def cert_encrypted_private_key_file(localhost_cert):
    """Yield a temporary-file path holding the localhost key encrypted with ``b"password"``.

    The unencrypted PEM key is loaded and then re-serialised as PEM in the
    TraditionalOpenSSL format, using the best available encryption.
    """
    plain_pem = localhost_cert.private_key_pem.bytes()
    key = load_pem_private_key(plain_pem, password=None, backend=default_backend())

    encryption = BestAvailableEncryption(password=b"password")
    encrypted_pem = key.private_bytes(
        Encoding.PEM, PrivateFormat.TraditionalOpenSSL, encryption
    )

    with trustme.Blob(encrypted_pem).tempfile() as encrypted_path:
        yield encrypted_path
|
|
|
|
|
|
class TestServer(Server):
    """uvicorn ``Server`` adapted for running in a background test thread.

    Signal-handler installation is disabled (it only works in the main
    thread), and the server can be restarted in place via ``restart()``.
    """

    @property
    def url(self) -> URL:
        """Base URL from the configured host and port; ``https`` when SSL is on."""
        protocol = "https" if self.config.is_ssl else "http"
        return URL(f"{protocol}://{self.config.host}:{self.config.port}/")

    def install_signal_handlers(self) -> None:
        # Disable the default installation of handlers for signals such as SIGTERM,
        # because it can only be done in the main thread.
        pass

    async def serve(self, sockets=None):
        """Run the uvicorn server alongside a task that services restart requests."""
        self.restart_requested = asyncio.Event()
        # Remember the loop this server runs on, so that restart() -- which may
        # be called from another thread -- can hand the event to it safely.
        self._server_loop = asyncio.get_running_loop()

        tasks = {
            self._server_loop.create_task(super().serve(sockets=sockets)),
            self._server_loop.create_task(self.watch_restarts()),
        }
        await asyncio.wait(tasks)

    async def restart(self) -> None:  # pragma: no cover
        """Request a shutdown/startup cycle and wait until the server is started again.

        This coroutine may be called from a different thread than the one the
        server is running on, and from an async environment that's not asyncio.
        For this reason, we use an event to coordinate with the server instead
        of calling shutdown()/startup() directly.
        """
        self.started = False
        # asyncio.Event is not thread-safe, so setting it directly from a
        # foreign thread can race with the server's loop. Schedule set() on
        # the server's own loop instead. call_soon_threadsafe is a plain
        # synchronous call, so it is usable from any async framework.
        self._server_loop.call_soon_threadsafe(self.restart_requested.set)
        while not self.started:
            await sleep(0.2)

    async def watch_restarts(self) -> None:  # pragma: no cover
        """Poll for restart requests every 0.1 s until ``should_exit`` is set."""
        while True:
            if self.should_exit:
                return

            try:
                await asyncio.wait_for(self.restart_requested.wait(), timeout=0.1)
            except asyncio.TimeoutError:
                continue

            self.restart_requested.clear()
            await self.shutdown()
            await self.startup()
|
|
|
|
|
|
def serve_in_thread(server: "TestServer") -> "typing.Iterator[TestServer]":
    """Run ``server`` on a background thread for the lifetime of this generator.

    Waits until ``server.started`` is true, then yields the server. On exit it
    sets ``server.should_exit`` and joins the thread.

    Raises:
        RuntimeError: if the server thread exits before the server reports
            that it has started (for example, because startup failed). This
            replaces busy-waiting forever.
    """
    thread = threading.Thread(target=server.run)
    thread.start()
    try:
        while not server.started:
            # is_alive() is checked first: once the thread is gone, ``started``
            # can no longer change, so a false value means startup never happened.
            if not thread.is_alive() and not server.started:
                raise RuntimeError("Test server thread exited before the server started.")
            time.sleep(1e-3)
        yield server
    finally:
        server.should_exit = True
        thread.join()
|
|
|
|
|
|
@pytest.fixture(scope="session")
def server() -> typing.Iterator[TestServer]:
    """Session-wide plain-HTTP test server running ``app`` in a background thread."""
    http_config = Config(app=app, lifespan="off", loop="asyncio")
    test_server = TestServer(config=http_config)
    yield from serve_in_thread(test_server)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def https_server(cert_pem_file, cert_private_key_file):
    """Session-wide HTTPS test server on port 8001 using the localhost cert and key."""
    tls_config = Config(
        app=app,
        lifespan="off",
        ssl_certfile=cert_pem_file,
        ssl_keyfile=cert_private_key_file,
        port=8001,
        loop="asyncio",
    )
    yield from serve_in_thread(TestServer(config=tls_config))
|