Add WebSocketsSansIOProtocol (#2540)

This commit is contained in:
Marcelo Trylesinski 2025-06-24 12:16:26 +02:00 committed by GitHub
parent 5432729137
commit b9606269a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 448 additions and 21 deletions

View File

@ -91,7 +91,7 @@ Using Uvicorn with watchfiles will enable the following options (which are other
* `--loop <str>` - Set the event loop implementation. The uvloop implementation provides greater performance, but is not compatible with Windows or PyPy. **Options:** *'auto', 'asyncio', 'uvloop'.* **Default:** *'auto'*.
* `--http <str>` - Set the HTTP protocol implementation. The httptools implementation provides greater performance, but it not compatible with PyPy. **Options:** *'auto', 'h11', 'httptools'.* **Default:** *'auto'*.
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'wsproto'.* **Default:** *'auto'*.
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. There are two versions of `websockets` supported: `websockets` and `websockets-sansio`. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'websockets-sansio', 'wsproto'.* **Default:** *'auto'*.
* `--ws-max-size <int>` - Set the WebSockets max message size, in bytes. Only available with the `websockets` protocol. **Default:** *16777216* (16 MB).
* `--ws-max-queue <int>` - Set the maximum length of the WebSocket incoming message queue. Only available with the `websockets` protocol. **Default:** *32*.
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Only available with the `websockets` protocol. **Default:** *20.0*.

View File

@ -92,6 +92,9 @@ filterwarnings = [
"ignore:Uvicorn's native WSGI implementation is deprecated.*:DeprecationWarning",
"ignore: 'cgi' is deprecated and slated for removal in Python 3.13:DeprecationWarning",
"ignore: remove second argument of ws_handler:DeprecationWarning:websockets",
"ignore: websockets.legacy is deprecated.*:DeprecationWarning",
"ignore: websockets.server.WebSocketServerProtocol is deprecated.*:DeprecationWarning",
"ignore: websockets.client.connect is deprecated.*:DeprecationWarning",
]
[tool.coverage.run]

View File

@ -233,9 +233,9 @@ def unused_tcp_port() -> int:
marks=pytest.mark.skipif(not importlib.util.find_spec("wsproto"), reason="wsproto not installed."),
id="wsproto",
),
pytest.param("uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", id="websockets"),
pytest.param(
"uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
id="websockets",
"uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", id="websockets-sansio"
),
]
)

View File

@ -5,12 +5,12 @@ import logging
import socket
import sys
from collections.abc import Iterator
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import httpx
import pytest
import websockets
import websockets.client
from websockets.protocol import State
from tests.utils import run_server
from uvicorn import Config
@ -50,7 +50,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
await send({"type": "http.response.body", "body": b"", "more_body": False})
async def test_trace_logging(caplog: pytest.LogCaptureFixture, logging_config, unused_tcp_port: int):
async def test_trace_logging(caplog: pytest.LogCaptureFixture, logging_config: dict[str, Any], unused_tcp_port: int):
config = Config(
app=app,
log_level="trace",
@ -92,8 +92,8 @@ async def test_trace_logging_on_http_protocol(http_protocol_cls, caplog, logging
async def test_trace_logging_on_ws_protocol(
ws_protocol_cls: WSProtocol,
caplog,
logging_config,
caplog: pytest.LogCaptureFixture,
logging_config: dict[str, Any],
unused_tcp_port: int,
):
async def websocket_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
@ -105,9 +105,9 @@ async def test_trace_logging_on_ws_protocol(
elif message["type"] == "websocket.disconnect":
break
async def open_connection(url):
async def open_connection(url: str):
async with websockets.client.connect(url) as websocket:
return websocket.open
return websocket.state is State.OPEN
config = Config(
app=websocket_app,
@ -127,7 +127,9 @@ async def test_trace_logging_on_ws_protocol(
@pytest.mark.parametrize("use_colors", [(True), (False), (None)])
async def test_access_logging(use_colors: bool, caplog: pytest.LogCaptureFixture, logging_config, unused_tcp_port: int):
async def test_access_logging(
use_colors: bool, caplog: pytest.LogCaptureFixture, logging_config: dict[str, Any], unused_tcp_port: int
):
config = Config(app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port)
with caplog_for_logger(caplog, "uvicorn.access"):
async with run_server(config):
@ -141,7 +143,7 @@ async def test_access_logging(use_colors: bool, caplog: pytest.LogCaptureFixture
@pytest.mark.parametrize("use_colors", [(True), (False)])
async def test_default_logging(
use_colors: bool, caplog: pytest.LogCaptureFixture, logging_config, unused_tcp_port: int
use_colors: bool, caplog: pytest.LogCaptureFixture, logging_config: dict[str, Any], unused_tcp_port: int
):
config = Config(app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port)
with caplog_for_logger(caplog, "uvicorn.access"):

View File

@ -465,6 +465,7 @@ async def test_proxy_headers_websocket_x_forwarded_proto(
host, port = scope["client"]
await send({"type": "websocket.accept"})
await send({"type": "websocket.send", "text": f"{scheme}://{host}:{port}"})
await send({"type": "websocket.close"})
app_with_middleware = ProxyHeadersMiddleware(websocket_app, trusted_hosts="*")
config = Config(

View File

@ -597,12 +597,9 @@ async def test_connection_lost_before_handshake_complete(
await send_accept_task.wait()
disconnect_message = await receive() # type: ignore
response: httpx.Response | None = None
async def websocket_session(uri: str):
nonlocal response
async with httpx.AsyncClient() as client:
response = await client.get(
await client.get(
f"http://127.0.0.1:{unused_tcp_port}",
headers={
"upgrade": "websocket",
@ -619,9 +616,6 @@ async def test_connection_lost_before_handshake_complete(
send_accept_task.set()
await asyncio.sleep(0.1)
assert response is not None
assert response.status_code == 500, response.text
assert response.text == "Internal Server Error"
assert disconnect_message == {"type": "websocket.disconnect", "code": 1006}
await task
@ -916,6 +910,9 @@ async def test_server_reject_connection_with_body_nolength(
async def test_server_reject_connection_with_invalid_msg(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol":
pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.")
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "websocket"
assert "extensions" in scope and "websocket.http.response" in scope["extensions"]
@ -947,6 +944,9 @@ async def test_server_reject_connection_with_invalid_msg(
async def test_server_reject_connection_with_missing_body(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol":
pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.")
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "websocket"
assert "extensions" in scope and "websocket.http.response" in scope["extensions"]
@ -982,6 +982,8 @@ async def test_server_multiple_websocket_http_response_start_events(
The server should raise an exception if it sends multiple
websocket.http.response.start events.
"""
if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol":
pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.")
exception_message: str | None = None
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):

View File

@ -25,7 +25,7 @@ from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
from uvicorn.middleware.wsgi import WSGIMiddleware
HTTPProtocolType = Literal["auto", "h11", "httptools"]
WSProtocolType = Literal["auto", "none", "websockets", "wsproto"]
WSProtocolType = Literal["auto", "none", "websockets", "websockets-sansio", "wsproto"]
LifespanType = Literal["auto", "on", "off"]
LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"]
InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"]
@ -47,6 +47,7 @@ WS_PROTOCOLS: dict[WSProtocolType, str | None] = {
"auto": "uvicorn.protocols.websockets.auto:AutoWebSocketsProtocol",
"none": None,
"websockets": "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
"websockets-sansio": "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol",
"wsproto": "uvicorn.protocols.websockets.wsproto_impl:WSProtocol",
}
LIFESPAN: dict[LifespanType, str] = {

View File

@ -0,0 +1,417 @@
from __future__ import annotations
import asyncio
import logging
import sys
from asyncio.transports import BaseTransport, Transport
from http import HTTPStatus
from typing import Any, Literal, cast
from urllib.parse import unquote
from websockets.exceptions import InvalidState
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
from websockets.frames import Frame, Opcode
from websockets.http11 import Request
from websockets.server import ServerProtocol
from uvicorn._types import (
ASGIReceiveEvent,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
WebSocketResponseBodyEvent,
WebSocketResponseStartEvent,
WebSocketScope,
WebSocketSendEvent,
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import (
ClientDisconnected,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.server import ServerState
if sys.version_info >= (3, 11): # pragma: no cover
from typing import assert_never
else: # pragma: no cover
from typing_extensions import assert_never
class WebSocketsSansIOProtocol(asyncio.Protocol):
def __init__(
self,
config: Config,
server_state: ServerState,
app_state: dict[str, Any],
_loop: asyncio.AbstractEventLoop | None = None,
) -> None:
if not config.loaded:
config.load() # pragma: no cover
self.config = config
self.app = config.loaded_app
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.root_path = config.root_path
self.app_state = app_state
# Shared server state
self.connections = server_state.connections
self.tasks = server_state.tasks
self.default_headers = server_state.default_headers
# Connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.server: tuple[str, int] | None = None
self.client: tuple[str, int] | None = None
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
# WebSocket state
self.queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue()
self.handshake_initiated = False
self.handshake_complete = False
self.close_sent = False
self.initial_response: tuple[int, list[tuple[str, str]], bytes] | None = None
extensions = []
if self.config.ws_per_message_deflate:
extensions = [
ServerPerMessageDeflateFactory(
server_max_window_bits=12,
client_max_window_bits=12,
compress_settings={"memLevel": 5},
)
]
self.conn = ServerProtocol(
extensions=extensions,
max_size=self.config.ws_max_size,
logger=logging.getLogger("uvicorn.error"),
)
self.read_paused = False
self.writable = asyncio.Event()
self.writable.set()
# Buffers
self.bytes = b""
def connection_made(self, transport: BaseTransport) -> None:
"""Called when a connection is made."""
transport = cast(Transport, transport)
self.connections.add(self)
self.transport = transport
self.server = get_local_addr(transport)
self.client = get_remote_addr(transport)
self.scheme = "wss" if is_ssl(transport) else "ws"
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
def connection_lost(self, exc: Exception | None) -> None:
code = 1005 if self.handshake_complete else 1006
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
self.connections.remove(self)
if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
self.handshake_complete = True
if exc is None:
self.transport.close()
def eof_received(self) -> None:
pass
def shutdown(self) -> None:
if self.handshake_complete:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
self.conn.send_close(1012)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
else:
self.send_500_response()
self.transport.close()
def data_received(self, data: bytes) -> None:
self.conn.receive_data(data)
if self.conn.parser_exc is not None: # pragma: no cover
self.handle_parser_exception()
return
self.handle_events()
def handle_events(self) -> None:
for event in self.conn.events_received():
if isinstance(event, Request):
self.handle_connect(event)
if isinstance(event, Frame):
if event.opcode == Opcode.CONT:
self.handle_cont(event) # pragma: no cover
elif event.opcode == Opcode.TEXT:
self.handle_text(event)
elif event.opcode == Opcode.BINARY:
self.handle_bytes(event)
elif event.opcode == Opcode.PING:
self.handle_ping()
elif event.opcode == Opcode.PONG:
pass # pragma: no cover
elif event.opcode == Opcode.CLOSE:
self.handle_close(event)
else:
assert_never(event.opcode) # pragma: no cover
# Event handlers
def handle_connect(self, event: Request) -> None:
self.request = event
self.response = self.conn.accept(event)
self.handshake_initiated = True
if self.response.status_code != 101:
self.handshake_complete = True
self.close_sent = True
self.conn.send_response(self.response)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
self.transport.close()
return
headers = [
(key.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
for key, value in event.headers.raw_items()
]
raw_path, _, query_string = event.path.partition("?")
self.scope: WebSocketScope = {
"type": "websocket",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
"http_version": "1.1",
"scheme": self.scheme,
"server": self.server,
"client": self.client,
"root_path": self.root_path,
"path": unquote(raw_path),
"raw_path": raw_path.encode("ascii"),
"query_string": query_string.encode("ascii"),
"headers": headers,
"subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"),
"state": self.app_state.copy(),
"extensions": {"websocket.http.response": {}},
}
self.queue.put_nowait({"type": "websocket.connect"})
task = self.loop.create_task(self.run_asgi())
task.add_done_callback(self.on_task_complete)
self.tasks.add(task)
def handle_cont(self, event: Frame) -> None: # pragma: no cover
self.bytes += event.data
if event.fin:
self.send_receive_event_to_app()
def handle_text(self, event: Frame) -> None:
self.bytes = event.data
self.curr_msg_data_type: Literal["text", "bytes"] = "text"
if event.fin:
self.send_receive_event_to_app()
def handle_bytes(self, event: Frame) -> None:
self.bytes = event.data
self.curr_msg_data_type = "bytes"
if event.fin:
self.send_receive_event_to_app()
def send_receive_event_to_app(self) -> None:
if self.curr_msg_data_type == "text":
try:
self.queue.put_nowait({"type": "websocket.receive", "text": self.bytes.decode()})
except UnicodeDecodeError: # pragma: no cover
self.logger.exception("Invalid UTF-8 sequence received from client.")
self.conn.send_close(1007)
self.handle_parser_exception()
return
else:
self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes})
if not self.read_paused:
self.read_paused = True
self.transport.pause_reading()
def handle_ping(self) -> None:
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
def handle_close(self, event: Frame) -> None:
if not self.close_sent and not self.transport.is_closing():
assert self.conn.close_rcvd is not None
code = self.conn.close_rcvd.code
reason = self.conn.close_rcvd.reason
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
self.transport.close()
def handle_parser_exception(self) -> None: # pragma: no cover
assert self.conn.close_sent is not None
code = self.conn.close_sent.code
reason = self.conn.close_sent.reason
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
self.close_sent = True
self.transport.close()
def on_task_complete(self, task: asyncio.Task[None]) -> None:
self.tasks.discard(task)
async def run_asgi(self) -> None:
try:
result = await self.app(self.scope, self.receive, self.send)
except ClientDisconnected:
self.transport.close() # pragma: no cover
except BaseException:
self.logger.exception("Exception in ASGI application\n")
self.send_500_response()
self.transport.close()
else:
if not self.handshake_complete:
self.logger.error("ASGI callable returned without completing handshake.")
self.send_500_response()
self.transport.close()
elif result is not None:
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
self.transport.close()
def send_500_response(self) -> None:
if self.initial_response or self.handshake_complete:
return
response = self.conn.reject(500, "Internal Server Error")
self.conn.send_response(response)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
async def send(self, message: ASGISendEvent) -> None:
await self.writable.wait()
message_type = message["type"]
if not self.handshake_complete and self.initial_response is None:
if message_type == "websocket.accept":
message = cast(WebSocketAcceptEvent, message)
self.logger.info(
'%s - "WebSocket %s" [accepted]',
self.scope["client"],
get_path_with_query_string(self.scope),
)
headers = [
(name.decode("latin-1").lower(), value.decode("latin-1").lower())
for name, value in (self.default_headers + list(message.get("headers", [])))
]
accepted_subprotocol = message.get("subprotocol")
if accepted_subprotocol:
headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol))
self.response.headers.update(headers)
if not self.transport.is_closing():
self.handshake_complete = True
self.conn.send_response(self.response)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
elif message_type == "websocket.close":
message = cast(WebSocketCloseEvent, message)
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
get_path_with_query_string(self.scope),
)
response = self.conn.reject(HTTPStatus.FORBIDDEN, "")
self.conn.send_response(response)
output = self.conn.data_to_send()
self.close_sent = True
self.handshake_complete = True
self.transport.write(b"".join(output))
self.transport.close()
elif message_type == "websocket.http.response.start" and self.initial_response is None:
message = cast(WebSocketResponseStartEvent, message)
if not (100 <= message["status"] < 600):
raise RuntimeError("Invalid HTTP status code '%d' in response." % message["status"])
self.logger.info(
'%s - "WebSocket %s" %d',
self.scope["client"],
get_path_with_query_string(self.scope),
message["status"],
)
headers = [
(name.decode("latin-1"), value.decode("latin-1"))
for name, value in list(message.get("headers", []))
]
self.initial_response = (message["status"], headers, b"")
else:
msg = (
"Expected ASGI message 'websocket.accept', 'websocket.close' "
"or 'websocket.http.response.start' "
"but got '%s'."
)
raise RuntimeError(msg % message_type)
elif not self.close_sent and self.initial_response is None:
try:
if message_type == "websocket.send":
message = cast(WebSocketSendEvent, message)
bytes_data = message.get("bytes")
text_data = message.get("text")
if text_data:
self.conn.send_text(text_data.encode())
elif bytes_data:
self.conn.send_binary(bytes_data)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
elif message_type == "websocket.close" and not self.transport.is_closing():
message = cast(WebSocketCloseEvent, message)
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
self.conn.send_close(code, reason)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
self.close_sent = True
self.transport.close()
else:
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
raise RuntimeError(msg % message_type)
except InvalidState:
raise ClientDisconnected()
elif self.initial_response is not None:
if message_type == "websocket.http.response.body":
message = cast(WebSocketResponseBodyEvent, message)
body = self.initial_response[2] + message["body"]
self.initial_response = self.initial_response[:2] + (body,)
if not message.get("more_body", False):
response = self.conn.reject(self.initial_response[0], body.decode())
response.headers.update(self.initial_response[1])
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.conn.send_response(response)
output = self.conn.data_to_send()
self.close_sent = True
self.transport.write(b"".join(output))
self.transport.close()
else: # pragma: no cover
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
raise RuntimeError(msg % message_type)
else:
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
raise RuntimeError(msg % message_type)
async def receive(self) -> ASGIReceiveEvent:
message = await self.queue.get()
if self.read_paused and self.queue.empty():
self.read_paused = False
self.transport.resume_reading()
return message

View File

@ -23,9 +23,10 @@ if TYPE_CHECKING:
from uvicorn.protocols.http.h11_impl import H11Protocol
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol]
Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketsSansIOProtocol]
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.