Add WebSocketsSansIOProtocol (#2540)
This commit is contained in:
parent
5432729137
commit
b9606269a7
@ -91,7 +91,7 @@ Using Uvicorn with watchfiles will enable the following options (which are other
|
||||
|
||||
* `--loop <str>` - Set the event loop implementation. The uvloop implementation provides greater performance, but is not compatible with Windows or PyPy. **Options:** *'auto', 'asyncio', 'uvloop'.* **Default:** *'auto'*.
|
||||
* `--http <str>` - Set the HTTP protocol implementation. The httptools implementation provides greater performance, but it not compatible with PyPy. **Options:** *'auto', 'h11', 'httptools'.* **Default:** *'auto'*.
|
||||
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'wsproto'.* **Default:** *'auto'*.
|
||||
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. There are two versions of `websockets` supported: `websockets` and `websockets-sansio`. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'websockets-sansio', 'wsproto'.* **Default:** *'auto'*.
|
||||
* `--ws-max-size <int>` - Set the WebSockets max message size, in bytes. Only available with the `websockets` protocol. **Default:** *16777216* (16 MB).
|
||||
* `--ws-max-queue <int>` - Set the maximum length of the WebSocket incoming message queue. Only available with the `websockets` protocol. **Default:** *32*.
|
||||
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Only available with the `websockets` protocol. **Default:** *20.0*.
|
||||
|
||||
@ -92,6 +92,9 @@ filterwarnings = [
|
||||
"ignore:Uvicorn's native WSGI implementation is deprecated.*:DeprecationWarning",
|
||||
"ignore: 'cgi' is deprecated and slated for removal in Python 3.13:DeprecationWarning",
|
||||
"ignore: remove second argument of ws_handler:DeprecationWarning:websockets",
|
||||
"ignore: websockets.legacy is deprecated.*:DeprecationWarning",
|
||||
"ignore: websockets.server.WebSocketServerProtocol is deprecated.*:DeprecationWarning",
|
||||
"ignore: websockets.client.connect is deprecated.*:DeprecationWarning",
|
||||
]
|
||||
|
||||
[tool.coverage.run]
|
||||
|
||||
@ -233,9 +233,9 @@ def unused_tcp_port() -> int:
|
||||
marks=pytest.mark.skipif(not importlib.util.find_spec("wsproto"), reason="wsproto not installed."),
|
||||
id="wsproto",
|
||||
),
|
||||
pytest.param("uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", id="websockets"),
|
||||
pytest.param(
|
||||
"uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
|
||||
id="websockets",
|
||||
"uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", id="websockets-sansio"
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@ -5,12 +5,12 @@ import logging
|
||||
import socket
|
||||
import sys
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import websockets
|
||||
import websockets.client
|
||||
from websockets.protocol import State
|
||||
|
||||
from tests.utils import run_server
|
||||
from uvicorn import Config
|
||||
@ -50,7 +50,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
|
||||
async def test_trace_logging(caplog: pytest.LogCaptureFixture, logging_config, unused_tcp_port: int):
|
||||
async def test_trace_logging(caplog: pytest.LogCaptureFixture, logging_config: dict[str, Any], unused_tcp_port: int):
|
||||
config = Config(
|
||||
app=app,
|
||||
log_level="trace",
|
||||
@ -92,8 +92,8 @@ async def test_trace_logging_on_http_protocol(http_protocol_cls, caplog, logging
|
||||
|
||||
async def test_trace_logging_on_ws_protocol(
|
||||
ws_protocol_cls: WSProtocol,
|
||||
caplog,
|
||||
logging_config,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
logging_config: dict[str, Any],
|
||||
unused_tcp_port: int,
|
||||
):
|
||||
async def websocket_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
@ -105,9 +105,9 @@ async def test_trace_logging_on_ws_protocol(
|
||||
elif message["type"] == "websocket.disconnect":
|
||||
break
|
||||
|
||||
async def open_connection(url):
|
||||
async def open_connection(url: str):
|
||||
async with websockets.client.connect(url) as websocket:
|
||||
return websocket.open
|
||||
return websocket.state is State.OPEN
|
||||
|
||||
config = Config(
|
||||
app=websocket_app,
|
||||
@ -127,7 +127,9 @@ async def test_trace_logging_on_ws_protocol(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_colors", [(True), (False), (None)])
|
||||
async def test_access_logging(use_colors: bool, caplog: pytest.LogCaptureFixture, logging_config, unused_tcp_port: int):
|
||||
async def test_access_logging(
|
||||
use_colors: bool, caplog: pytest.LogCaptureFixture, logging_config: dict[str, Any], unused_tcp_port: int
|
||||
):
|
||||
config = Config(app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port)
|
||||
with caplog_for_logger(caplog, "uvicorn.access"):
|
||||
async with run_server(config):
|
||||
@ -141,7 +143,7 @@ async def test_access_logging(use_colors: bool, caplog: pytest.LogCaptureFixture
|
||||
|
||||
@pytest.mark.parametrize("use_colors", [(True), (False)])
|
||||
async def test_default_logging(
|
||||
use_colors: bool, caplog: pytest.LogCaptureFixture, logging_config, unused_tcp_port: int
|
||||
use_colors: bool, caplog: pytest.LogCaptureFixture, logging_config: dict[str, Any], unused_tcp_port: int
|
||||
):
|
||||
config = Config(app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port)
|
||||
with caplog_for_logger(caplog, "uvicorn.access"):
|
||||
|
||||
@ -465,6 +465,7 @@ async def test_proxy_headers_websocket_x_forwarded_proto(
|
||||
host, port = scope["client"]
|
||||
await send({"type": "websocket.accept"})
|
||||
await send({"type": "websocket.send", "text": f"{scheme}://{host}:{port}"})
|
||||
await send({"type": "websocket.close"})
|
||||
|
||||
app_with_middleware = ProxyHeadersMiddleware(websocket_app, trusted_hosts="*")
|
||||
config = Config(
|
||||
|
||||
@ -597,12 +597,9 @@ async def test_connection_lost_before_handshake_complete(
|
||||
await send_accept_task.wait()
|
||||
disconnect_message = await receive() # type: ignore
|
||||
|
||||
response: httpx.Response | None = None
|
||||
|
||||
async def websocket_session(uri: str):
|
||||
nonlocal response
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
await client.get(
|
||||
f"http://127.0.0.1:{unused_tcp_port}",
|
||||
headers={
|
||||
"upgrade": "websocket",
|
||||
@ -619,9 +616,6 @@ async def test_connection_lost_before_handshake_complete(
|
||||
send_accept_task.set()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert response is not None
|
||||
assert response.status_code == 500, response.text
|
||||
assert response.text == "Internal Server Error"
|
||||
assert disconnect_message == {"type": "websocket.disconnect", "code": 1006}
|
||||
await task
|
||||
|
||||
@ -916,6 +910,9 @@ async def test_server_reject_connection_with_body_nolength(
|
||||
async def test_server_reject_connection_with_invalid_msg(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol":
|
||||
pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.")
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
assert scope["type"] == "websocket"
|
||||
assert "extensions" in scope and "websocket.http.response" in scope["extensions"]
|
||||
@ -947,6 +944,9 @@ async def test_server_reject_connection_with_invalid_msg(
|
||||
async def test_server_reject_connection_with_missing_body(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol":
|
||||
pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.")
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
assert scope["type"] == "websocket"
|
||||
assert "extensions" in scope and "websocket.http.response" in scope["extensions"]
|
||||
@ -982,6 +982,8 @@ async def test_server_multiple_websocket_http_response_start_events(
|
||||
The server should raise an exception if it sends multiple
|
||||
websocket.http.response.start events.
|
||||
"""
|
||||
if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol":
|
||||
pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.")
|
||||
exception_message: str | None = None
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
|
||||
@ -25,7 +25,7 @@ from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
||||
from uvicorn.middleware.wsgi import WSGIMiddleware
|
||||
|
||||
HTTPProtocolType = Literal["auto", "h11", "httptools"]
|
||||
WSProtocolType = Literal["auto", "none", "websockets", "wsproto"]
|
||||
WSProtocolType = Literal["auto", "none", "websockets", "websockets-sansio", "wsproto"]
|
||||
LifespanType = Literal["auto", "on", "off"]
|
||||
LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"]
|
||||
InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"]
|
||||
@ -47,6 +47,7 @@ WS_PROTOCOLS: dict[WSProtocolType, str | None] = {
|
||||
"auto": "uvicorn.protocols.websockets.auto:AutoWebSocketsProtocol",
|
||||
"none": None,
|
||||
"websockets": "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
|
||||
"websockets-sansio": "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol",
|
||||
"wsproto": "uvicorn.protocols.websockets.wsproto_impl:WSProtocol",
|
||||
}
|
||||
LIFESPAN: dict[LifespanType, str] = {
|
||||
|
||||
417
uvicorn/protocols/websockets/websockets_sansio_impl.py
Normal file
417
uvicorn/protocols/websockets/websockets_sansio_impl.py
Normal file
@ -0,0 +1,417 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from asyncio.transports import BaseTransport, Transport
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Literal, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
from websockets.exceptions import InvalidState
|
||||
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
|
||||
from websockets.frames import Frame, Opcode
|
||||
from websockets.http11 import Request
|
||||
from websockets.server import ServerProtocol
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGIReceiveEvent,
|
||||
ASGISendEvent,
|
||||
WebSocketAcceptEvent,
|
||||
WebSocketCloseEvent,
|
||||
WebSocketResponseBodyEvent,
|
||||
WebSocketResponseStartEvent,
|
||||
WebSocketScope,
|
||||
WebSocketSendEvent,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.utils import (
|
||||
ClientDisconnected,
|
||||
get_local_addr,
|
||||
get_path_with_query_string,
|
||||
get_remote_addr,
|
||||
is_ssl,
|
||||
)
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
if sys.version_info >= (3, 11): # pragma: no cover
|
||||
from typing import assert_never
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import assert_never
|
||||
|
||||
|
||||
class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load() # pragma: no cover
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.root_path = config.root_path
|
||||
self.app_state = app_state
|
||||
|
||||
# Shared server state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
self.default_headers = server_state.default_headers
|
||||
|
||||
# Connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
|
||||
|
||||
# WebSocket state
|
||||
self.queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue()
|
||||
self.handshake_initiated = False
|
||||
self.handshake_complete = False
|
||||
self.close_sent = False
|
||||
self.initial_response: tuple[int, list[tuple[str, str]], bytes] | None = None
|
||||
|
||||
extensions = []
|
||||
if self.config.ws_per_message_deflate:
|
||||
extensions = [
|
||||
ServerPerMessageDeflateFactory(
|
||||
server_max_window_bits=12,
|
||||
client_max_window_bits=12,
|
||||
compress_settings={"memLevel": 5},
|
||||
)
|
||||
]
|
||||
self.conn = ServerProtocol(
|
||||
extensions=extensions,
|
||||
max_size=self.config.ws_max_size,
|
||||
logger=logging.getLogger("uvicorn.error"),
|
||||
)
|
||||
|
||||
self.read_paused = False
|
||||
self.writable = asyncio.Event()
|
||||
self.writable.set()
|
||||
|
||||
# Buffers
|
||||
self.bytes = b""
|
||||
|
||||
def connection_made(self, transport: BaseTransport) -> None:
|
||||
"""Called when a connection is made."""
|
||||
transport = cast(Transport, transport)
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "wss" if is_ssl(transport) else "ws"
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
code = 1005 if self.handshake_complete else 1006
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
|
||||
self.connections.remove(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
|
||||
|
||||
self.handshake_complete = True
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
|
||||
def eof_received(self) -> None:
|
||||
pass
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self.handshake_complete:
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
|
||||
self.conn.send_close(1012)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
else:
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self.conn.receive_data(data)
|
||||
if self.conn.parser_exc is not None: # pragma: no cover
|
||||
self.handle_parser_exception()
|
||||
return
|
||||
self.handle_events()
|
||||
|
||||
def handle_events(self) -> None:
|
||||
for event in self.conn.events_received():
|
||||
if isinstance(event, Request):
|
||||
self.handle_connect(event)
|
||||
if isinstance(event, Frame):
|
||||
if event.opcode == Opcode.CONT:
|
||||
self.handle_cont(event) # pragma: no cover
|
||||
elif event.opcode == Opcode.TEXT:
|
||||
self.handle_text(event)
|
||||
elif event.opcode == Opcode.BINARY:
|
||||
self.handle_bytes(event)
|
||||
elif event.opcode == Opcode.PING:
|
||||
self.handle_ping()
|
||||
elif event.opcode == Opcode.PONG:
|
||||
pass # pragma: no cover
|
||||
elif event.opcode == Opcode.CLOSE:
|
||||
self.handle_close(event)
|
||||
else:
|
||||
assert_never(event.opcode) # pragma: no cover
|
||||
|
||||
# Event handlers
|
||||
|
||||
def handle_connect(self, event: Request) -> None:
|
||||
self.request = event
|
||||
self.response = self.conn.accept(event)
|
||||
self.handshake_initiated = True
|
||||
if self.response.status_code != 101:
|
||||
self.handshake_complete = True
|
||||
self.close_sent = True
|
||||
self.conn.send_response(self.response)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
return
|
||||
|
||||
headers = [
|
||||
(key.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
|
||||
for key, value in event.headers.raw_items()
|
||||
]
|
||||
raw_path, _, query_string = event.path.partition("?")
|
||||
self.scope: WebSocketScope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
|
||||
"http_version": "1.1",
|
||||
"scheme": self.scheme,
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"root_path": self.root_path,
|
||||
"path": unquote(raw_path),
|
||||
"raw_path": raw_path.encode("ascii"),
|
||||
"query_string": query_string.encode("ascii"),
|
||||
"headers": headers,
|
||||
"subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"),
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
self.queue.put_nowait({"type": "websocket.connect"})
|
||||
task = self.loop.create_task(self.run_asgi())
|
||||
task.add_done_callback(self.on_task_complete)
|
||||
self.tasks.add(task)
|
||||
|
||||
def handle_cont(self, event: Frame) -> None: # pragma: no cover
|
||||
self.bytes += event.data
|
||||
if event.fin:
|
||||
self.send_receive_event_to_app()
|
||||
|
||||
def handle_text(self, event: Frame) -> None:
|
||||
self.bytes = event.data
|
||||
self.curr_msg_data_type: Literal["text", "bytes"] = "text"
|
||||
if event.fin:
|
||||
self.send_receive_event_to_app()
|
||||
|
||||
def handle_bytes(self, event: Frame) -> None:
|
||||
self.bytes = event.data
|
||||
self.curr_msg_data_type = "bytes"
|
||||
if event.fin:
|
||||
self.send_receive_event_to_app()
|
||||
|
||||
def send_receive_event_to_app(self) -> None:
|
||||
if self.curr_msg_data_type == "text":
|
||||
try:
|
||||
self.queue.put_nowait({"type": "websocket.receive", "text": self.bytes.decode()})
|
||||
except UnicodeDecodeError: # pragma: no cover
|
||||
self.logger.exception("Invalid UTF-8 sequence received from client.")
|
||||
self.conn.send_close(1007)
|
||||
self.handle_parser_exception()
|
||||
return
|
||||
else:
|
||||
self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes})
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self.transport.pause_reading()
|
||||
|
||||
def handle_ping(self) -> None:
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
def handle_close(self, event: Frame) -> None:
|
||||
if not self.close_sent and not self.transport.is_closing():
|
||||
assert self.conn.close_rcvd is not None
|
||||
code = self.conn.close_rcvd.code
|
||||
reason = self.conn.close_rcvd.reason
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
|
||||
def handle_parser_exception(self) -> None: # pragma: no cover
|
||||
assert self.conn.close_sent is not None
|
||||
code = self.conn.close_sent.code
|
||||
reason = self.conn.close_sent.reason
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
|
||||
def on_task_complete(self, task: asyncio.Task[None]) -> None:
|
||||
self.tasks.discard(task)
|
||||
|
||||
async def run_asgi(self) -> None:
|
||||
try:
|
||||
result = await self.app(self.scope, self.receive, self.send)
|
||||
except ClientDisconnected:
|
||||
self.transport.close() # pragma: no cover
|
||||
except BaseException:
|
||||
self.logger.exception("Exception in ASGI application\n")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
else:
|
||||
if not self.handshake_complete:
|
||||
self.logger.error("ASGI callable returned without completing handshake.")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
elif result is not None:
|
||||
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
|
||||
self.transport.close()
|
||||
|
||||
def send_500_response(self) -> None:
|
||||
if self.initial_response or self.handshake_complete:
|
||||
return
|
||||
response = self.conn.reject(500, "Internal Server Error")
|
||||
self.conn.send_response(response)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
await self.writable.wait()
|
||||
|
||||
message_type = message["type"]
|
||||
|
||||
if not self.handshake_complete and self.initial_response is None:
|
||||
if message_type == "websocket.accept":
|
||||
message = cast(WebSocketAcceptEvent, message)
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
headers = [
|
||||
(name.decode("latin-1").lower(), value.decode("latin-1").lower())
|
||||
for name, value in (self.default_headers + list(message.get("headers", [])))
|
||||
]
|
||||
accepted_subprotocol = message.get("subprotocol")
|
||||
if accepted_subprotocol:
|
||||
headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol))
|
||||
self.response.headers.update(headers)
|
||||
|
||||
if not self.transport.is_closing():
|
||||
self.handshake_complete = True
|
||||
self.conn.send_response(self.response)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = cast(WebSocketCloseEvent, message)
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
response = self.conn.reject(HTTPStatus.FORBIDDEN, "")
|
||||
self.conn.send_response(response)
|
||||
output = self.conn.data_to_send()
|
||||
self.close_sent = True
|
||||
self.handshake_complete = True
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
elif message_type == "websocket.http.response.start" and self.initial_response is None:
|
||||
message = cast(WebSocketResponseStartEvent, message)
|
||||
if not (100 <= message["status"] < 600):
|
||||
raise RuntimeError("Invalid HTTP status code '%d' in response." % message["status"])
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" %d',
|
||||
self.scope["client"],
|
||||
get_path_with_query_string(self.scope),
|
||||
message["status"],
|
||||
)
|
||||
headers = [
|
||||
(name.decode("latin-1"), value.decode("latin-1"))
|
||||
for name, value in list(message.get("headers", []))
|
||||
]
|
||||
self.initial_response = (message["status"], headers, b"")
|
||||
else:
|
||||
msg = (
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close' "
|
||||
"or 'websocket.http.response.start' "
|
||||
"but got '%s'."
|
||||
)
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
elif not self.close_sent and self.initial_response is None:
|
||||
try:
|
||||
if message_type == "websocket.send":
|
||||
message = cast(WebSocketSendEvent, message)
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
if text_data:
|
||||
self.conn.send_text(text_data.encode())
|
||||
elif bytes_data:
|
||||
self.conn.send_binary(bytes_data)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
elif message_type == "websocket.close" and not self.transport.is_closing():
|
||||
message = cast(WebSocketCloseEvent, message)
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
self.conn.send_close(code, reason)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
except InvalidState:
|
||||
raise ClientDisconnected()
|
||||
elif self.initial_response is not None:
|
||||
if message_type == "websocket.http.response.body":
|
||||
message = cast(WebSocketResponseBodyEvent, message)
|
||||
body = self.initial_response[2] + message["body"]
|
||||
self.initial_response = self.initial_response[:2] + (body,)
|
||||
if not message.get("more_body", False):
|
||||
response = self.conn.reject(self.initial_response[0], body.decode())
|
||||
response.headers.update(self.initial_response[1])
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.conn.send_response(response)
|
||||
output = self.conn.data_to_send()
|
||||
self.close_sent = True
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
else: # pragma: no cover
|
||||
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
else:
|
||||
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
message = await self.queue.get()
|
||||
if self.read_paused and self.queue.empty():
|
||||
self.read_paused = False
|
||||
self.transport.resume_reading()
|
||||
return message
|
||||
@ -23,9 +23,10 @@ if TYPE_CHECKING:
|
||||
from uvicorn.protocols.http.h11_impl import H11Protocol
|
||||
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
|
||||
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
|
||||
from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol
|
||||
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
|
||||
|
||||
Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol]
|
||||
Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketsSansIOProtocol]
|
||||
|
||||
HANDLED_SIGNALS = (
|
||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user