# Listing metadata: 1347 lines, 54 KiB, Python
from __future__ import annotations

import asyncio
from copy import deepcopy
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict

import httpx
import pytest
import websockets
import websockets.client
import websockets.exceptions
from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory
from websockets.frames import Opcode
from websockets.typing import Subprotocol

from tests.response import Response
from tests.utils import run_server
from uvicorn._types import (
    ASGIReceiveCallable,
    ASGIReceiveEvent,
    ASGISendCallable,
    Scope,
    WebSocketCloseEvent,
    WebSocketConnectEvent,
    WebSocketDisconnectEvent,
    WebSocketReceiveEvent,
    WebSocketResponseStartEvent,
)
from uvicorn.config import Config
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol
|
|
|
|
# wsproto is an optional dependency: record whether it can be imported so that
# wsproto-specific tests can be skipped when it is missing.
try:
    from uvicorn.protocols.websockets.wsproto_impl import WSProtocol as _WSProtocol

    skip_if_no_wsproto = pytest.mark.skipif(False, reason="wsproto is installed.")
except ModuleNotFoundError:  # pragma: no cover
    skip_if_no_wsproto = pytest.mark.skipif(True, reason="wsproto is not installed.")

if TYPE_CHECKING:
    from uvicorn.protocols.http.h11_impl import H11Protocol
    from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
    from uvicorn.protocols.websockets.wsproto_impl import WSProtocol as _WSProtocol

    # String aliases used only in annotations of the test signatures below.
    HTTPProtocol: TypeAlias = "type[H11Protocol | HttpToolsProtocol]"
    WSProtocol: TypeAlias = "type[_WSProtocol | WebSocketProtocol]"
    KeepaliveWSProtocol: TypeAlias = "type[_WSProtocol | WebSocketsSansIOProtocol]"

# Apply the anyio marker to every test in this module.
pytestmark = pytest.mark.anyio
|
|
|
|
|
|
class WebSocketResponse:
    """Minimal class-based ASGI WebSocket application used as a test base.

    Instances are awaitable: awaiting one runs :meth:`asgi`, which dispatches
    every received event to a method named after the event type with dots
    replaced by underscores (``websocket.connect`` -> ``websocket_connect``).
    Subclasses implement only the handlers they care about.
    """

    def __init__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        self.scope = scope
        self.receive = receive
        self.send = send

    def __await__(self):
        # Lets a server call ``await App(scope, receive, send)``.
        return self.asgi().__await__()

    async def asgi(self):
        """Receive events until ``websocket.disconnect``, calling any matching handler."""
        while True:
            message = await self.receive()
            message_type = message["type"].replace(".", "_")
            handler = getattr(self, message_type, None)
            if handler is not None:
                await handler(message)
            # The disconnect handler (if any) runs before the loop ends.
            if message_type == "websocket_disconnect":
                break
|
|
|
|
|
|
async def wsresponse(url: str):
    """Issue a WebSocket upgrade request with ``httpx`` and return the HTTP response.

    The ``ws:`` scheme of ``url`` is rewritten to ``http:`` so the raw handshake
    reply can be inspected by the caller.
    """
    # Rewrite only the first occurrence, i.e. the scheme, not later text in the URL.
    url = url.replace("ws:", "http:", 1)
    headers = {
        "connection": "upgrade",
        "upgrade": "websocket",
        "Sec-WebSocket-Key": "x3JJHMbDL1EzLkh9GBhXDw==",
        "Sec-WebSocket-Version": "13",
    }
    async with httpx.AsyncClient() as client:
        return await client.get(url, headers=headers)
|
|
|
|
|
|
async def test_invalid_upgrade(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """An upgrade request with version ``11`` and no key header is answered with 426 or 400."""

    def app(scope: Scope):
        return None

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, port=unused_tcp_port)
    async with run_server(config):
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"http://127.0.0.1:{unused_tcp_port}",
                headers={
                    "upgrade": "websocket",
                    "connection": "upgrade",
                    "sec-webSocket-version": "11",
                },
            )
        # A 426 (wsproto 0.13, empty body) is accepted as-is; anything else must
        # be a 400 carrying one of the known error texts.
        if response.status_code != 426:
            assert response.status_code == 400
            assert response.text.lower().strip().rstrip(".") in [
                "missing sec-websocket-key header",
                "missing sec-websocket-version header",  # websockets
                "missing or empty sec-websocket-key header",  # wsproto
                "failed to open a websocket connection: missing sec-websocket-key header",
                "failed to open a websocket connection: missing or empty sec-websocket-key header",
                "failed to open a websocket connection: missing sec-websocket-key header; 'sec-websocket-key'",
            ]
|
|
|
|
|
|
async def test_accept_connection(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """The client connection is open once the app replies with ``websocket.accept``."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})

    async def open_connection(url: str):
        async with websockets.client.connect(url) as websocket:
            return websocket.open

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
        assert is_open
|
|
|
|
|
|
async def test_shutdown(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """Call ``server.shutdown()`` while a client connection is still open."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config) as server:
        async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}"):
            # Attempt shutdown while connection is still open
            await server.shutdown()
|
|
|
|
|
|
async def test_supports_permessage_deflate_extension(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """``permessage-deflate`` is negotiated when the client offers it."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})

    async def open_connection(url: str):
        extension_factories = [ClientPerMessageDeflateFactory()]
        async with websockets.client.connect(url, extensions=extension_factories) as websocket:
            return [extension.name for extension in websocket.extensions]

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        extension_names = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
        assert "permessage-deflate" in extension_names
|
|
|
|
|
|
async def test_can_disable_permessage_deflate_extension(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """With ``ws_per_message_deflate=False`` the extension is not negotiated."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})

    async def open_connection(url: str):
        # enable per-message deflate on the client, so that we can check the server
        # won't support it when it's disabled.
        extension_factories = [ClientPerMessageDeflateFactory()]
        async with websockets.client.connect(url, extensions=extension_factories) as websocket:
            return [extension.name for extension in websocket.extensions]

    config = Config(
        app=App,
        ws=ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        ws_per_message_deflate=False,
        port=unused_tcp_port,
    )
    async with run_server(config):
        extension_names = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
        assert "permessage-deflate" not in extension_names
|
|
|
|
|
|
async def test_close_connection(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """Replying ``websocket.close`` to the connect event makes the client handshake fail."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.close"})

    async def open_connection(url: str):
        try:
            await websockets.client.connect(url)
        except websockets.exceptions.InvalidHandshake:
            return False
        return True  # pragma: no cover

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
        assert not is_open
|
|
|
|
|
|
async def test_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """Request headers, including a non-ASCII value, reach the app as bytes in the scope."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            headers = self.scope.get("headers")
            headers = dict(headers)  # type: ignore
            assert headers[b"host"].startswith(b"127.0.0.1")
            assert headers[b"username"] == bytes("abraão", "utf-8")
            await self.send({"type": "websocket.accept"})

    async def open_connection(url: str):
        async with websockets.client.connect(url, extra_headers=[("username", "abraão")]) as websocket:
            return websocket.open

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
        assert is_open
|
|
|
|
|
|
async def test_extra_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """Headers supplied with ``websocket.accept`` appear in the handshake response."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept", "headers": [(b"extra", b"header")]})

    async def open_connection(url: str):
        async with websockets.client.connect(url) as websocket:
            return websocket.response_headers

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        extra_headers = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
        assert extra_headers.get("extra") == "header"
|
|
|
|
|
|
async def test_path_and_raw_path(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """``scope["path"]`` is percent-decoded while ``scope["raw_path"]`` keeps the encoding."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            path = self.scope.get("path")
            raw_path = self.scope.get("raw_path")
            assert path == "/one/two"
            assert raw_path == b"/one%2Ftwo"
            await self.send({"type": "websocket.accept"})

    async def open_connection(url: str):
        async with websockets.client.connect(url) as websocket:
            return websocket.open

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}/one%2Ftwo")
        assert is_open
|
|
|
|
|
|
async def test_send_text_data_to_client(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """A ``websocket.send`` event with ``text`` is received by the client as ``str``."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})
            await self.send({"type": "websocket.send", "text": "123"})

    async def get_data(url: str):
        async with websockets.client.connect(url) as websocket:
            return await websocket.recv()

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        data = await get_data(f"ws://127.0.0.1:{unused_tcp_port}")
        assert data == "123"
|
|
|
|
|
|
async def test_send_binary_data_to_client(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """A ``websocket.send`` event with ``bytes`` is received by the client as ``bytes``."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})
            await self.send({"type": "websocket.send", "bytes": b"123"})

    async def get_data(url: str):
        async with websockets.client.connect(url) as websocket:
            return await websocket.recv()

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        data = await get_data(f"ws://127.0.0.1:{unused_tcp_port}")
        assert data == b"123"
|
|
|
|
|
|
async def test_send_and_close_connection(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """The client receives the message sent before ``websocket.close`` and then sees the close."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})
            await self.send({"type": "websocket.send", "text": "123"})
            await self.send({"type": "websocket.close"})

    async def get_data(url: str):
        async with websockets.client.connect(url) as websocket:
            data = await websocket.recv()
            is_open = True
            try:
                await websocket.recv()
            # Only a closed connection counts; any other error should fail the test.
            except websockets.exceptions.ConnectionClosed:
                is_open = False
            return (data, is_open)

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        (data, is_open) = await get_data(f"ws://127.0.0.1:{unused_tcp_port}")
        assert data == "123"
        assert not is_open
|
|
|
|
|
|
async def test_send_text_data_to_server(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """Text sent by the client arrives under ``text`` and is echoed back unchanged."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})

        async def websocket_receive(self, message: WebSocketReceiveEvent):
            _text = message.get("text")
            assert _text is not None
            await self.send({"type": "websocket.send", "text": _text})

    async def send_text(url: str):
        async with websockets.client.connect(url) as websocket:
            await websocket.send("abc")
            return await websocket.recv()

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        data = await send_text(f"ws://127.0.0.1:{unused_tcp_port}")
        assert data == "abc"
|
|
|
|
|
|
async def test_send_binary_data_to_server(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """Bytes sent by the client arrive under ``bytes`` and are echoed back unchanged."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})

        async def websocket_receive(self, message: WebSocketReceiveEvent):
            _bytes = message.get("bytes")
            assert _bytes is not None
            await self.send({"type": "websocket.send", "bytes": _bytes})

    # Named for what it sends: this helper transmits a binary frame.
    async def send_bytes(url: str):
        async with websockets.client.connect(url) as websocket:
            await websocket.send(b"abc")
            return await websocket.recv()

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        data = await send_bytes(f"ws://127.0.0.1:{unused_tcp_port}")
        assert data == b"abc"
|
|
|
|
|
|
async def test_send_after_protocol_close(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """Sending after ``websocket.close`` raises inside the app; the client still gets the first message."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})
            await self.send({"type": "websocket.send", "text": "123"})
            await self.send({"type": "websocket.close"})
            # The connection is already closed, so a further send must fail.
            with pytest.raises(Exception):
                await self.send({"type": "websocket.send", "text": "123"})

    async def get_data(url: str):
        async with websockets.client.connect(url) as websocket:
            data = await websocket.recv()
            is_open = True
            try:
                await websocket.recv()
            # Only a closed connection counts; any other error should fail the test.
            except websockets.exceptions.ConnectionClosed:
                is_open = False
            return (data, is_open)

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        (data, is_open) = await get_data(f"ws://127.0.0.1:{unused_tcp_port}")
        assert data == "123"
        assert not is_open
|
|
|
|
|
|
async def test_missing_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """An app that returns without sending any event makes the client handshake fail with 500."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        pass

    async def connect(url: str):
        await websockets.client.connect(url)

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
            await connect(f"ws://127.0.0.1:{unused_tcp_port}")
        assert exc_info.value.status_code == 500
|
|
|
|
|
|
async def test_send_before_handshake(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """Sending ``websocket.send`` before accepting makes the client handshake fail with 500."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        await send({"type": "websocket.send", "text": "123"})

    async def connect(url: str):
        await websockets.client.connect(url)

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
            await connect(f"ws://127.0.0.1:{unused_tcp_port}")
        assert exc_info.value.status_code == 500
|
|
|
|
|
|
async def test_duplicate_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """Sending ``websocket.accept`` twice closes the connection; the client sees code 1006."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        await send({"type": "websocket.accept"})
        await send({"type": "websocket.accept"})

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
            with pytest.raises(websockets.exceptions.ConnectionClosed):
                _ = await websocket.recv()
        assert websocket.close_code == 1006
|
|
|
|
|
|
async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """
    The ASGI callable should return 'None'. If it doesn't, make sure that
    the connection is closed with an error condition.
    """

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        await send({"type": "websocket.accept"})
        return 123

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
            with pytest.raises(websockets.exceptions.ConnectionClosed):
                _ = await websocket.recv()
        assert websocket.close_code == 1006
|
|
|
|
|
|
async def test_close_transport_on_asgi_return(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """The ASGI callable should send the `websocket.close` event.

    If it returns without doing so, the connection must still be closed:
    the client's pending ``recv()`` fails and it reports close code 1006.
    """

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        message = await receive()
        if message["type"] == "websocket.connect":
            await send({"type": "websocket.accept"})

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
            with pytest.raises(websockets.exceptions.ConnectionClosed):
                await websocket.recv()
        assert websocket.close_code == 1006
|
|
|
|
|
|
@pytest.mark.parametrize("code", [None, 1000, 1001])
@pytest.mark.parametrize("reason", [None, "test", False], ids=["none_as_reason", "normal_reason", "without_reason"])
async def test_app_close(
    ws_protocol_cls: WSProtocol,
    http_protocol_cls: HTTPProtocol,
    unused_tcp_port: int,
    code: int | None,
    reason: str | None,
):
    """The close code and reason chosen by the app reach the client.

    ``code=None`` omits the ``code`` key and ``reason=False`` omits the ``reason``
    key from the close event; the client is then expected to see 1000 and ``""``.
    """

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        while True:
            message = await receive()
            if message["type"] == "websocket.connect":
                await send({"type": "websocket.accept"})
            elif message["type"] == "websocket.receive":
                reply: WebSocketCloseEvent = {"type": "websocket.close"}

                if code is not None:
                    reply["code"] = code

                # ``False`` is the "leave the key out" sentinel; ``None`` is sent as-is.
                if reason is not False:
                    reply["reason"] = reason

                await send(reply)
            elif message["type"] == "websocket.disconnect":
                break

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
            await websocket.ping()
            await websocket.send("abc")
            with pytest.raises(websockets.exceptions.ConnectionClosed):
                await websocket.recv()
        assert websocket.close_code == (code or 1000)
        assert websocket.close_reason == (reason or "")
|
|
|
|
|
|
async def test_client_close(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """The ``websocket.disconnect`` event carries the code and reason the client closed with."""
    disconnect_message: WebSocketDisconnectEvent | None = None

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        nonlocal disconnect_message
        while True:
            message = await receive()
            if message["type"] == "websocket.connect":
                await send({"type": "websocket.accept"})
            elif message["type"] == "websocket.receive":
                pass
            elif message["type"] == "websocket.disconnect":
                disconnect_message = message
                break

    async def websocket_session(url: str):
        async with websockets.client.connect(url) as websocket:
            await websocket.ping()
            await websocket.send("abc")
            await websocket.close(code=1001, reason="custom reason")

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

    assert disconnect_message == {"type": "websocket.disconnect", "code": 1001, "reason": "custom reason"}
|
|
|
|
|
|
async def test_client_connection_lost(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """Closing the client transport abruptly delivers ``websocket.disconnect`` before shutdown."""
    got_disconnect_event = False

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        nonlocal got_disconnect_event
        while True:
            message = await receive()
            if message["type"] == "websocket.connect":
                await send({"type": "websocket.accept"})
            elif message["type"] == "websocket.disconnect":
                break

        # Reached only after the loop saw a disconnect event.
        got_disconnect_event = True

    config = Config(
        app=app,
        ws=ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        ws_ping_interval=0.0,
        port=unused_tcp_port,
    )
    async with run_server(config):
        async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
            websocket.transport.close()
            await asyncio.sleep(0.1)
            # Snapshot the flag while the server is still running.
            got_disconnect_event_before_shutdown = got_disconnect_event

    assert got_disconnect_event_before_shutdown is True
|
|
|
|
|
|
async def test_client_connection_lost_on_send(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """Sending to a client that has already gone away raises ``OSError`` in the app."""
    disconnect = asyncio.Event()
    got_disconnect_event = False

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        nonlocal got_disconnect_event
        message = await receive()
        if message["type"] == "websocket.connect":
            await send({"type": "websocket.accept"})
        try:
            # Hold the send back until the test signals that the client is gone.
            await disconnect.wait()
            await send({"type": "websocket.send", "text": "123"})
        except OSError:
            got_disconnect_event = True

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        url = f"ws://127.0.0.1:{unused_tcp_port}"
        async with websockets.client.connect(url):
            await asyncio.sleep(0.1)
        # The client context has exited, so the app's send now targets a closed connection.
        disconnect.set()

    assert got_disconnect_event is True
|
|
|
|
|
|
async def test_connection_lost_before_handshake_complete(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """An app that never accepts receives ``websocket.disconnect`` with code 1006."""
    send_accept_task = asyncio.Event()
    disconnect_message: WebSocketDisconnectEvent = {}  # type: ignore

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        nonlocal disconnect_message
        message = await receive()
        if message["type"] == "websocket.connect":
            # Stall the handshake until the test releases it.
            await send_accept_task.wait()
        disconnect_message = await receive()  # type: ignore

    async def websocket_session(uri: str):
        # The upgrade request is sent as plain HTTP, so swap the scheme of ``uri``.
        async with httpx.AsyncClient() as client:
            await client.get(
                uri.replace("ws:", "http:", 1),
                headers={
                    "upgrade": "websocket",
                    "connection": "upgrade",
                    "sec-websocket-version": "13",
                    "sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==",
                },
            )

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        task = asyncio.create_task(websocket_session(f"ws://127.0.0.1:{unused_tcp_port}"))
        await asyncio.sleep(0.1)
        send_accept_task.set()
        await asyncio.sleep(0.1)

    assert disconnect_message == {"type": "websocket.disconnect", "code": 1006}
    await task
|
|
|
|
|
|
async def test_send_close_on_server_shutdown(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """On server shutdown, both the client and the app observe close code 1012."""
    disconnect_message: WebSocketDisconnectEvent = {}  # type: ignore
    server_shutdown_event = asyncio.Event()

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        nonlocal disconnect_message
        while True:
            message = await receive()
            if message["type"] == "websocket.connect":
                await send({"type": "websocket.accept"})
            elif message["type"] == "websocket.disconnect":
                disconnect_message = message
                break

    websocket: websockets.client.WebSocketClientProtocol | None = None

    async def websocket_session(uri: str):
        nonlocal websocket
        async with websockets.client.connect(uri) as ws_connection:
            # Expose the connection so the test can inspect its close code later.
            websocket = ws_connection
            await server_shutdown_event.wait()

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        task = asyncio.create_task(websocket_session(f"ws://127.0.0.1:{unused_tcp_port}"))
        await asyncio.sleep(0.1)
        # Captured before leaving the server context, to show no disconnect arrived earlier.
        disconnect_message_before_shutdown = disconnect_message
        server_shutdown_event.set()

    assert websocket is not None
    assert websocket.close_code == 1012
    assert disconnect_message_before_shutdown == {}
    assert disconnect_message == {"type": "websocket.disconnect", "code": 1012}
    task.cancel()
|
|
|
|
|
|
@pytest.mark.parametrize("subprotocol", ["proto1", "proto2"])
async def test_subprotocols(
    ws_protocol_cls: WSProtocol,
    http_protocol_cls: HTTPProtocol,
    subprotocol: str,
    unused_tcp_port: int,
):
    """The subprotocol named in ``websocket.accept`` is the one the client ends up with."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept", "subprotocol": subprotocol})

    async def get_subprotocol(url: str):
        async with websockets.client.connect(
            url, subprotocols=[Subprotocol("proto1"), Subprotocol("proto2")]
        ) as websocket:
            return websocket.subprotocol

    config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        accepted_subprotocol = await get_subprotocol(f"ws://127.0.0.1:{unused_tcp_port}")
        assert accepted_subprotocol == subprotocol
|
|
|
|
|
|
# 16 MiB message size used by the size-limit tests, and one byte beyond it.
MAX_WS_BYTES = 1024 * 1024 * 16
MAX_WS_BYTES_PLUS1 = MAX_WS_BYTES + 1
|
|
|
|
|
|
@pytest.mark.parametrize(
    "client_size_sent, server_size_max, expected_result",
    [
        (MAX_WS_BYTES, MAX_WS_BYTES, 0),
        (MAX_WS_BYTES_PLUS1, MAX_WS_BYTES, 1009),
        (10, 10, 0),
        (11, 10, 1009),
    ],
    ids=[
        "max=defaults sent=defaults",
        "max=defaults sent=defaults+1",
        "max=10 sent=10",
        "max=10 sent=11",
    ],
)
async def test_send_binary_data_to_server_bigger_than_default_on_websockets(
    http_protocol_cls: HTTPProtocol,
    client_size_sent: int,
    server_size_max: int,
    expected_result: int,
    unused_tcp_port: int,
):
    """Check ``ws_max_size`` enforcement for ``WebSocketProtocol``.

    ``expected_result == 0`` means the payload must be echoed back; any other
    value is the close code the client must observe after an oversized message.
    """

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})

        async def websocket_receive(self, message: WebSocketReceiveEvent):
            _bytes = message.get("bytes")
            assert _bytes is not None
            await self.send({"type": "websocket.send", "bytes": _bytes})

    config = Config(
        app=App,
        ws=WebSocketProtocol,
        http=http_protocol_cls,
        lifespan="off",
        ws_max_size=server_size_max,
        port=unused_tcp_port,
    )
    async with run_server(config):
        async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}", max_size=client_size_sent) as ws:
            await ws.send(b"\x01" * client_size_sent)
            if expected_result == 0:
                data = await ws.recv()
                assert data == b"\x01" * client_size_sent
            else:
                with pytest.raises(websockets.exceptions.ConnectionClosedError):
                    await ws.recv()
                assert ws.close_code == expected_result
|
|
|
|
|
|
async def test_fragmented_message_exceeding_max_size(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """Stream non-FIN fragments past `ws_max_size` - the server must close with 1009."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})

    config = Config(
        app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", ws_max_size=2048, port=unused_tcp_port
    )
    async with run_server(config):
        # Use the same client entry point as the rest of this module; its
        # connection object provides the low-level ``write_frame`` used below.
        async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as ws:
            payload = b"A" * 1024
            with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info:
                # First fragment opens the message; FIN is never set on any frame.
                await ws.write_frame(False, Opcode.BINARY, payload)
                for _ in range(63):  # 64 KiB total, well past 2 KiB budget
                    await ws.write_frame(False, Opcode.CONT, payload)
                await ws.recv()
            assert exc_info.value.rcvd is not None
            assert exc_info.value.rcvd.code == 1009
|
|
|
|
|
|
async def test_server_reject_connection(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """Rejecting with ``websocket.close`` gives the client a 403 and the app a 1006 disconnect."""
    disconnected_message: ASGIReceiveEvent = {}  # type: ignore

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        nonlocal disconnected_message
        assert scope["type"] == "websocket"

        # Pull up first recv message.
        message = await receive()
        assert message["type"] == "websocket.connect"

        # Reject the connection.
        await send({"type": "websocket.close"})
        # -- At this point websockets' recv() is unusable. --

        # This doesn't raise `TypeError`:
        # See https://github.com/Kludex/uvicorn/issues/244
        disconnected_message = await receive()

    async def websocket_session(url: str):
        with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
            async with websockets.client.connect(url):
                pass  # pragma: no cover
        assert exc_info.value.status_code == 403

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

    assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}
|
|
|
|
|
|
class EmptyDict(TypedDict):
    """TypedDict with no keys, used to type an initially empty ``{}`` placeholder."""
|
|
|
|
|
|
async def test_server_reject_connection_with_response(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """Rejecting the handshake with an HTTP response delivers its status and body to the client."""
    disconnected_message: WebSocketDisconnectEvent | EmptyDict = {}

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        nonlocal disconnected_message
        assert scope["type"] == "websocket"
        # Checked separately so a failure shows which condition did not hold.
        assert "extensions" in scope
        assert "websocket.http.response" in scope["extensions"]

        # Pull up first recv message.
        message = await receive()
        assert message["type"] == "websocket.connect"

        # Reject the connection with a response
        response = Response(b"goodbye", status_code=400)
        await response(scope, receive, send)
        disconnected_message = await receive()

    async def websocket_session(url: str):
        response = await wsresponse(url)
        assert response.status_code == 400
        assert response.content == b"goodbye"

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

    assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}
|
|
|
|
|
|
async def test_server_reject_connection_with_multibody_response(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """Reject the handshake with a response whose body is sent in two chunks.

    The client must see a single 400 response whose content is both chunks
    concatenated, and the application must then receive a 1006 disconnect.
    """
    captured: dict[str, ASGIReceiveEvent] = {}

    async def chunked_denial_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        assert scope["type"] == "websocket"
        assert "extensions" in scope
        assert "websocket.http.response" in scope["extensions"]

        # The first event handed to the application is the connect event.
        connect_event = await receive()
        assert connect_event["type"] == "websocket.connect"

        start_event: WebSocketResponseStartEvent = {
            "type": "websocket.http.response.start",
            "status": 400,
            "headers": [
                (b"Content-Length", b"20"),
                (b"Content-Type", b"text/plain"),
            ],
        }
        await send(start_event)
        # Two body events: the first announces that more body follows.
        await send({"type": "websocket.http.response.body", "body": b"x" * 10, "more_body": True})
        await send({"type": "websocket.http.response.body", "body": b"y" * 10})
        captured["disconnect"] = await receive()

    server_config = Config(
        app=chunked_denial_app,
        ws=ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        port=unused_tcp_port,
    )
    async with run_server(server_config):
        reply = await wsresponse(f"ws://127.0.0.1:{unused_tcp_port}")
        assert reply.status_code == 400
        expected_body = (b"x" * 10) + (b"y" * 10)
        assert reply.content == expected_body

    assert captured.get("disconnect") == {"type": "websocket.disconnect", "code": 1006}
|
|
|
|
|
|
async def test_server_reject_connection_with_invalid_status(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """An invalid status in the rejection response must surface as a 500 to the client.

    Even though the application's response start event is faulty (status 700),
    the server is expected to send a well-formed "Internal Server Error" reply.
    """

    async def bad_status_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        assert scope["type"] == "websocket"
        assert "extensions" in scope
        assert "websocket.http.response" in scope["extensions"]

        # The first event handed to the application is the connect event.
        connect_event = await receive()
        assert connect_event["type"] == "websocket.connect"

        faulty_start = {
            "type": "websocket.http.response.start",
            "status": 700,  # invalid status code
            "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
        }
        await send(faulty_start)

    server_config = Config(
        app=bad_status_app,
        ws=ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        port=unused_tcp_port,
    )
    async with run_server(server_config):
        reply = await wsresponse(f"ws://127.0.0.1:{unused_tcp_port}")
        assert reply.status_code == 500
        assert reply.content == b"Internal Server Error"
        assert reply.headers["content-length"] == "21"
        assert reply.headers["connection"] == "close"
|
|
|
|
|
|
async def test_server_reject_connection_with_body_nolength(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """Check that a rejection response may carry a body without a Content-Length header.

    The application answers the handshake with a 403 and the body ``b"hardbody"``
    but an empty header list; the test then checks which framing header the
    server implementation added on its own.
    """

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        assert scope["type"] == "websocket"
        assert "extensions" in scope
        assert "websocket.http.response" in scope["extensions"]

        # Pull up first recv message.
        message = await receive()
        assert message["type"] == "websocket.connect"

        await send({"type": "websocket.http.response.start", "status": 403, "headers": []})
        await send({"type": "websocket.http.response.body", "body": b"hardbody"})

    async def websocket_session(url: str):
        response = await wsresponse(url)
        assert response.status_code == 403
        assert response.content == b"hardbody"
        # Compare by class name instead of against `_WSProtocol`: that name is only
        # bound when wsproto can be imported, so referencing it here would raise
        # NameError for the other implementations when wsproto is not installed.
        if ws_protocol_cls.__name__ == "WSProtocol":
            # wsproto automatically makes the message chunked
            assert response.headers["transfer-encoding"] == "chunked"
        else:
            # websockets automatically adds a content-length
            assert response.headers["content-length"] == "8"

    config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
    async with run_server(config):
        await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
|
|
|
|
|
|
async def test_server_reject_connection_with_invalid_msg(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """Sending the response start event twice must still reject the client with 404."""
    if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol":
        pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.")

    async def double_start_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        assert scope["type"] == "websocket"
        assert "extensions" in scope
        assert "websocket.http.response" in scope["extensions"]

        # The first event handed to the application is the connect event.
        connect_event = await receive()
        assert connect_event["type"] == "websocket.connect"

        start_event: WebSocketResponseStartEvent = {
            "type": "websocket.http.response.start",
            "status": 404,
            "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
        }
        await send(start_event)
        # send invalid message. This will raise an exception here
        await send(start_event)

    server_config = Config(
        app=double_start_app,
        ws=ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        port=unused_tcp_port,
    )
    async with run_server(server_config):
        with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
            async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}"):
                pass  # pragma: no cover
        assert exc_info.value.status_code == 404
|
|
|
|
|
|
async def test_server_reject_connection_with_missing_body(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """A response start event with no body event must still reject the client with 404."""
    if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol":
        pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.")

    async def start_only_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        assert scope["type"] == "websocket"
        assert "extensions" in scope
        assert "websocket.http.response" in scope["extensions"]

        # The first event handed to the application is the connect event.
        connect_event = await receive()
        assert connect_event["type"] == "websocket.connect"

        start_event = {
            "type": "websocket.http.response.start",
            "status": 404,
            "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
        }
        await send(start_event)
        # no further message

    server_config = Config(
        app=start_only_app,
        ws=ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        port=unused_tcp_port,
    )
    async with run_server(server_config):
        with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
            async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}"):
                pass  # pragma: no cover
        assert exc_info.value.status_code == 404
|
|
|
|
|
|
async def test_server_multiple_websocket_http_response_start_events(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """
    The server should raise an exception if it sends multiple
    websocket.http.response.start events.
    """
    if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol":
        pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.")
    raised_messages: list[str] = []

    async def double_start_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        assert scope["type"] == "websocket"
        assert "extensions" in scope
        assert "websocket.http.response" in scope["extensions"]

        # The first event handed to the application is the connect event.
        connect_event = await receive()
        assert connect_event["type"] == "websocket.connect"

        start_event: WebSocketResponseStartEvent = {
            "type": "websocket.http.response.start",
            "status": 404,
            "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
        }
        await send(start_event)
        # The second start event is the invalid one; record the error text it triggers.
        try:
            await send(start_event)
        except Exception as exc:
            raised_messages.append(str(exc))

    server_config = Config(
        app=double_start_app,
        ws=ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        port=unused_tcp_port,
    )
    async with run_server(server_config):
        with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
            async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}"):
                pass  # pragma: no cover
        assert exc_info.value.status_code == 404

    assert raised_messages == [
        "Expected ASGI message 'websocket.http.response.body' but got 'websocket.http.response.start'."
    ]
|
|
|
|
|
|
async def test_server_can_read_messages_in_buffer_after_close(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """Frames sent before the client's close frame must still reach the application.

    The application delays after accepting, so the client has already sent its
    three frames and closed by the time the server starts reading.
    """
    received_frames: list[bytes] = []
    disconnect_message: WebSocketDisconnectEvent | EmptyDict = {}

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})
            # Ensure server doesn't start reading frames from read buffer until
            # after client has sent close frame, but server is still able to
            # read these frames
            await asyncio.sleep(0.2)

        async def websocket_receive(self, message: WebSocketReceiveEvent):
            payload = message.get("bytes")
            assert payload is not None
            received_frames.append(payload)

        async def websocket_disconnect(self, message: WebSocketDisconnectEvent):
            nonlocal disconnect_message
            disconnect_message = message

    server_config = Config(
        app=App,
        ws=ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        port=unused_tcp_port,
    )
    async with run_server(server_config):
        async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
            for _ in range(3):
                await websocket.send(b"abc")

    assert received_frames == [b"abc", b"abc", b"abc"]
    assert disconnect_message == {"type": "websocket.disconnect", "code": 1000, "reason": ""}
|
|
|
|
|
|
async def test_default_server_headers(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """With default settings the handshake response carries `server: uvicorn` and a `date` header."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})

    server_config = Config(
        app=App,
        ws=ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        port=unused_tcp_port,
    )
    async with run_server(server_config):
        async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
            handshake_headers = websocket.response_headers
        assert handshake_headers.get("server") == "uvicorn"
        assert "date" in handshake_headers
|
|
|
|
|
|
async def test_no_server_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """With `server_header=False` the handshake response must not contain a `server` header."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})

    server_config = Config(
        app=App,
        ws=ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        server_header=False,
        port=unused_tcp_port,
    )
    async with run_server(server_config):
        async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
            handshake_headers = websocket.response_headers
        assert "server" not in handshake_headers
|
|
|
|
|
|
@skip_if_no_wsproto
async def test_no_date_header_on_wsproto(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """With `date_header=False` the wsproto handshake response must not contain a `date` header."""

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept"})

    server_config = Config(
        app=App,
        ws=_WSProtocol,
        http=http_protocol_cls,
        lifespan="off",
        date_header=False,
        port=unused_tcp_port,
    )
    async with run_server(server_config):
        async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
            handshake_headers = websocket.response_headers
        assert "date" not in handshake_headers
|
|
|
|
|
|
async def test_multiple_server_header(
    ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """`Server` headers supplied on accept are sent in addition to the default one, in order."""
    extra_server_headers = [
        (b"Server", b"over-ridden"),
        (b"Server", b"another-value"),
    ]

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            await self.send({"type": "websocket.accept", "headers": extra_server_headers})

    server_config = Config(
        app=App,
        ws=ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        port=unused_tcp_port,
    )
    async with run_server(server_config):
        async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
            handshake_headers = websocket.response_headers
        assert handshake_headers.get_all("Server") == ["uvicorn", "over-ridden", "another-value"]
|
|
|
|
|
|
async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
    """Check how lifespan state is exposed to successive websocket connections.

    Each connection snapshots ``scope["state"]``, then rebinds key ``"a"`` and
    appends to the list under ``"b"``. The expected snapshots show that the
    rebinding is not seen by the next connection while the in-place list
    mutation is.
    """
    observed_states: list[dict[str, Any]] = []

    async def lifespan_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        startup = await receive()
        assert startup["type"] == "lifespan.startup"
        assert "state" in scope
        scope["state"]["a"] = 123
        scope["state"]["b"] = [1]
        await send({"type": "lifespan.startup.complete"})

        shutdown = await receive()
        assert shutdown["type"] == "lifespan.shutdown"
        await send({"type": "lifespan.shutdown.complete"})

    class App(WebSocketResponse):
        async def websocket_connect(self, message: WebSocketConnectEvent):
            assert "state" in self.scope
            # Snapshot before mutating so later changes do not alter the record.
            observed_states.append(deepcopy(self.scope["state"]))
            self.scope["state"]["a"] = 456
            self.scope["state"]["b"].append(2)
            await self.send({"type": "websocket.accept"})

    async def app_wrapper(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        # Route lifespan events to the lifespan app, everything else to the websocket app.
        if scope["type"] != "lifespan":
            return await App(scope, receive, send)
        return await lifespan_app(scope, receive, send)

    async def open_connection(url: str):
        async with websockets.client.connect(url) as websocket:
            return websocket.open

    server_config = Config(
        app=app_wrapper,
        ws=ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="on",
        port=unused_tcp_port,
    )
    async with run_server(server_config):
        # Two sequential connections, each of which must open successfully.
        for _ in range(2):
            is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
            assert is_open

    expected_states: list[dict[str, Any]] = [
        {"a": 123, "b": [1]},
        {"a": 123, "b": [1, 2]},
    ]
    assert expected_states == observed_states
|
|
|
|
|
|
@pytest.fixture(
    params=[
        pytest.param(
            "uvicorn.protocols.websockets.wsproto_impl:WSProtocol",
            marks=skip_if_no_wsproto,
            id="wsproto",
        ),
        pytest.param(
            "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol",
            id="websockets-sansio",
        ),
    ]
)
def keepalive_ws_protocol_cls(request: pytest.FixtureRequest):
    """Resolve each parametrized import string to the object it names.

    The wsproto parameter carries the ``skip_if_no_wsproto`` mark.
    """
    # Imported locally, as in the original fixture.
    from uvicorn.importer import import_from_string

    import_path = request.param
    return import_from_string(import_path)
|
|
|
|
|
|
async def test_server_keepalive_ping_pong(
    keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """The connection must stay open while the client answers the server's keepalive pings.

    The server pings every 0.1s with a 5s timeout. The test waits until the
    protocol has recorded a sent ping, sleeps a little longer, and then checks
    that the server-side transport is not closing.
    """

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        # Accept the connection, then idle until the client disconnects.
        while True:
            message = await receive()
            if message["type"] == "websocket.connect":
                await send({"type": "websocket.accept"})
            elif message["type"] == "websocket.disconnect":
                break

    config = Config(
        app=app,
        ws=keepalive_ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        ws_ping_interval=0.1,
        ws_ping_timeout=5.0,
        port=unused_tcp_port,
    )
    async with run_server(config) as server:
        # The websockets client auto-responds to ping frames, keeping the connection alive.
        async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
            protocol = next(iter(server.server_state.connections))
            # Check against the configured class rather than `_WSProtocol`: that name
            # is only bound when wsproto can be imported, so referencing it would raise
            # NameError for the websockets-sansio parameter when wsproto is missing.
            assert isinstance(protocol, keepalive_ws_protocol_cls)

            # Wait until the server sends at least one keepalive ping, then
            # sleep past the timeout window and ensure the connection stays open.
            # This verifies that the client answered the ping without depending
            # on clock granularity for the measured RTT.
            async def ping_sent() -> None:
                while protocol.ping_sent_at == 0.0:
                    await asyncio.sleep(0.05)

            await asyncio.wait_for(ping_sent(), timeout=5.0)
            await asyncio.sleep(0.2)
            assert not protocol.transport.is_closing()
|
|
|
|
|
|
async def test_server_keepalive_ping_timeout(
    keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """An unanswered keepalive ping must make the server close with code 1011.

    The client's transport writes are discarded, so no pong ever reaches the
    server; the client is then expected to receive a close frame with the
    reason "keepalive ping timeout".
    """

    async def idle_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        # Accept the connection, then idle until the disconnect event arrives.
        while True:
            event = await receive()
            event_type = event["type"]
            if event_type == "websocket.connect":
                await send({"type": "websocket.accept"})
            elif event_type == "websocket.disconnect":
                break

    server_config = Config(
        app=idle_app,
        ws=keepalive_ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        ws_ping_interval=0.1,
        ws_ping_timeout=0.1,
        log_level="trace",
        port=unused_tcp_port,
    )

    def discard_write(data):
        return None

    async with run_server(server_config):
        async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None) as websocket:
            # Swallow outgoing pong frames so the server's ping never gets ack'd.
            websocket.transport.write = discard_write  # type: ignore[method-assign]
            with pytest.raises(websockets.exceptions.ConnectionClosedError) as exc_info:
                await asyncio.wait_for(websocket.recv(), timeout=1)
            close_frame = exc_info.value.rcvd
            assert close_frame is not None
            assert close_frame.code == 1011
            assert close_frame.reason == "keepalive ping timeout"
|
|
|
|
|
|
async def test_server_keepalive_disabled(
    keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
    """With `ws_ping_interval=None` the protocol must not have a ping timer."""

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        # Accept the connection, then idle until the client disconnects.
        while True:
            message = await receive()
            if message["type"] == "websocket.connect":
                await send({"type": "websocket.accept"})
            elif message["type"] == "websocket.disconnect":
                break

    config = Config(
        app=app,
        ws=keepalive_ws_protocol_cls,
        http=http_protocol_cls,
        lifespan="off",
        ws_ping_interval=None,
        port=unused_tcp_port,
    )
    async with run_server(config) as server:
        async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
            protocol = next(iter(server.server_state.connections))
            # Check against the configured class rather than `_WSProtocol`: that name
            # is only bound when wsproto can be imported, so referencing it would raise
            # NameError for the websockets-sansio parameter when wsproto is missing.
            assert isinstance(protocol, keepalive_ws_protocol_cls)
            assert protocol.ping_timer is None
|