Drop cast in ASGI types (#2875)

This commit is contained in:
Marcelo Trylesinski 2026-03-28 11:20:47 +01:00 committed by GitHub
parent 1cb8e747e2
commit 5211880320
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 56 additions and 117 deletions

View File

@ -5,7 +5,7 @@ import contextvars
import http
import logging
from collections.abc import Callable
from typing import Any, Literal, cast
from typing import Any, Literal
from urllib.parse import unquote
import h11
@ -452,8 +452,6 @@ class RequestResponseCycle:
# ASGI interface
async def send(self, message: ASGISendEvent) -> None:
message_type = message["type"]
if self.flow.write_paused and not self.disconnected:
await self.flow.drain() # pragma: full coverage
@ -462,10 +460,8 @@ class RequestResponseCycle:
if not self.response_started:
# Sending response status line and headers
if message_type != "http.response.start":
msg = "Expected ASGI message 'http.response.start', but got '%s'."
raise RuntimeError(msg % message_type)
message = cast("HTTPResponseStartEvent", message)
if message["type"] != "http.response.start":
raise RuntimeError(f"Expected ASGI message 'http.response.start', but got '{message['type']}'.")
self.response_started = True
self.waiting_for_100_continue = False
@ -494,10 +490,8 @@ class RequestResponseCycle:
elif not self.response_complete:
# Sending response body
if message_type != "http.response.body":
msg = "Expected ASGI message 'http.response.body', but got '%s'."
raise RuntimeError(msg % message_type)
message = cast("HTTPResponseBodyEvent", message)
if message["type"] != "http.response.body":
raise RuntimeError(f"Expected ASGI message 'http.response.body', but got '{message['type']}'.")
body = message.get("body", b"")
more_body = message.get("more_body", False)
@ -516,8 +510,7 @@ class RequestResponseCycle:
else:
# Response already sent
msg = "Unexpected ASGI message '%s' sent, after response already completed."
raise RuntimeError(msg % message_type)
raise RuntimeError(f"Unexpected ASGI message '{message['type']}' sent, after response already completed.")
if self.response_complete:
if self.conn.our_state is h11.MUST_CLOSE or not self.keep_alive:
@ -541,10 +534,6 @@ class RequestResponseCycle:
if self.disconnected or self.response_complete:
return {"type": "http.disconnect"}
message: HTTPRequestEvent = {
"type": "http.request",
"body": bytes(self.body),
"more_body": self.more_body,
}
message: HTTPRequestEvent = {"type": "http.request", "body": bytes(self.body), "more_body": self.more_body}
self.body = bytearray()
return message

View File

@ -9,7 +9,7 @@ import urllib
from asyncio.events import TimerHandle
from collections import deque
from collections.abc import Callable
from typing import Any, Literal, cast
from typing import Any, Literal
import httptools
@ -18,7 +18,6 @@ from uvicorn._types import (
ASGIReceiveEvent,
ASGISendEvent,
HTTPRequestEvent,
HTTPResponseStartEvent,
HTTPScope,
)
from uvicorn.config import Config
@ -455,8 +454,6 @@ class RequestResponseCycle:
# ASGI interface
async def send(self, message: ASGISendEvent) -> None:
message_type = message["type"]
if self.flow.write_paused and not self.disconnected:
await self.flow.drain() # pragma: full coverage
@ -465,10 +462,8 @@ class RequestResponseCycle:
if not self.response_started:
# Sending response status line and headers
if message_type != "http.response.start":
msg = "Expected ASGI message 'http.response.start', but got '%s'."
raise RuntimeError(msg % message_type)
message = cast("HTTPResponseStartEvent", message)
if message["type"] != "http.response.start":
raise RuntimeError(f"Expected ASGI message 'http.response.start', but got '{message['type']}'.")
self.response_started = True
self.waiting_for_100_continue = False
@ -519,11 +514,10 @@ class RequestResponseCycle:
elif not self.response_complete:
# Sending response body
if message_type != "http.response.body":
msg = "Expected ASGI message 'http.response.body', but got '%s'."
raise RuntimeError(msg % message_type)
if message["type"] != "http.response.body":
raise RuntimeError(f"Expected ASGI message 'http.response.body', but got '{message['type']}'.")
body = cast(bytes, message.get("body", b""))
body = message.get("body", b"")
more_body = message.get("more_body", False)
# Write response body
@ -557,8 +551,7 @@ class RequestResponseCycle:
else:
# Response already sent
msg = "Unexpected ASGI message '%s' sent, after response already completed."
raise RuntimeError(msg % message_type)
raise RuntimeError(f"Unexpected ASGI message '{message['type']}' sent, after response already completed.")
async def receive(self) -> ASGIReceiveEvent:
if self.waiting_for_100_continue and not self.transport.is_closing():

View File

@ -20,15 +20,10 @@ from websockets.typing import Subprotocol
from uvicorn._types import (
ASGI3Application,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
WebSocketConnectEvent,
WebSocketDisconnectEvent,
WebSocketReceiveEvent,
WebSocketResponseBodyEvent,
WebSocketResponseStartEvent,
WebSocketScope,
WebSocketSendEvent,
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
@ -262,11 +257,8 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.transport.close()
async def asgi_send(self, message: ASGISendEvent) -> None:
message_type = message["type"]
if not self.handshake_started_event.is_set():
if message_type == "websocket.accept":
message = cast("WebSocketAcceptEvent", message)
if message["type"] == "websocket.accept":
self.logger.info(
'%s - "WebSocket %s" [accepted]',
get_client_addr(self.scope),
@ -283,8 +275,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
)
self.handshake_started_event.set()
elif message_type == "websocket.close":
message = cast("WebSocketCloseEvent", message)
elif message["type"] == "websocket.close":
self.logger.info(
'%s - "WebSocket %s" 403',
get_client_addr(self.scope),
@ -294,8 +285,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.handshake_started_event.set()
self.closed_event.set()
elif message_type == "websocket.http.response.start":
message = cast("WebSocketResponseStartEvent", message)
elif message["type"] == "websocket.http.response.start":
self.logger.info(
'%s - "WebSocket %s" %d',
get_client_addr(self.scope),
@ -311,50 +301,48 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.handshake_started_event.set()
else:
msg = (
raise RuntimeError(
"Expected ASGI message 'websocket.accept', 'websocket.close', "
"or 'websocket.http.response.start' but got '%s'."
f"or 'websocket.http.response.start' but got '{message['type']}'."
)
raise RuntimeError(msg % message_type)
elif not self.closed_event.is_set() and self.initial_response is None:
await self.handshake_completed_event.wait()
try:
if message_type == "websocket.send":
message = cast("WebSocketSendEvent", message)
if message["type"] == "websocket.send":
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
await self.send(data) # type: ignore[arg-type]
elif message_type == "websocket.close":
message = cast("WebSocketCloseEvent", message)
elif message["type"] == "websocket.close":
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
await self.close(code, reason)
self.closed_event.set()
else:
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
raise RuntimeError(msg % message_type)
raise RuntimeError(
f"Expected ASGI message 'websocket.send' or 'websocket.close', but got '{message['type']}'."
)
except ConnectionClosed as exc:
raise ClientDisconnected from exc
elif self.initial_response is not None:
if message_type == "websocket.http.response.body":
message = cast("WebSocketResponseBodyEvent", message)
if message["type"] == "websocket.http.response.body":
body = self.initial_response[2] + message["body"]
self.initial_response = self.initial_response[:2] + (body,)
if not message.get("more_body", False):
self.closed_event.set()
else:
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
raise RuntimeError(msg % message_type)
raise RuntimeError(f"Expected ASGI message 'websocket.http.response.body' but got '{message['type']}'.")
else:
msg = "Unexpected ASGI message '%s', after sending 'websocket.close' or response already completed."
raise RuntimeError(msg % message_type)
raise RuntimeError(
f"Unexpected ASGI message '{message['type']}', after sending 'websocket.close' "
"or response already completed."
)
async def asgi_receive(self) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent:
if not self.connect_sent:

View File

@ -17,12 +17,7 @@ from websockets.server import ServerProtocol
from uvicorn._types import (
ASGIReceiveEvent,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
WebSocketResponseBodyEvent,
WebSocketResponseStartEvent,
WebSocketScope,
WebSocketSendEvent,
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
@ -295,11 +290,8 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
async def send(self, message: ASGISendEvent) -> None:
await self.writable.wait()
message_type = message["type"]
if not self.handshake_complete and self.initial_response is None:
if message_type == "websocket.accept":
message = cast(WebSocketAcceptEvent, message)
if message["type"] == "websocket.accept":
self.logger.info(
'%s - "WebSocket %s" [accepted]',
get_client_addr(self.scope),
@ -320,8 +312,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
elif message_type == "websocket.close":
message = cast(WebSocketCloseEvent, message)
elif message["type"] == "websocket.close":
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.logger.info(
'%s - "WebSocket %s" 403',
@ -335,8 +326,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
self.handshake_complete = True
self.transport.write(b"".join(output))
self.transport.close()
elif message_type == "websocket.http.response.start" and self.initial_response is None:
message = cast(WebSocketResponseStartEvent, message)
elif message["type"] == "websocket.http.response.start" and self.initial_response is None:
if not (100 <= message["status"] < 600):
raise RuntimeError("Invalid HTTP status code '%d' in response." % message["status"])
self.logger.info(
@ -351,17 +341,14 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
]
self.initial_response = (message["status"], headers, b"")
else:
msg = (
raise RuntimeError(
"Expected ASGI message 'websocket.accept', 'websocket.close' "
"or 'websocket.http.response.start' "
"but got '%s'."
f"or 'websocket.http.response.start' but got '{message['type']}'."
)
raise RuntimeError(msg % message_type)
elif not self.close_sent and self.initial_response is None:
try:
if message_type == "websocket.send":
message = cast(WebSocketSendEvent, message)
if message["type"] == "websocket.send":
bytes_data = message.get("bytes")
text_data = message.get("text")
if bytes_data is not None:
@ -371,9 +358,8 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
elif message_type == "websocket.close":
elif message["type"] == "websocket.close":
if not self.transport.is_closing():
message = cast(WebSocketCloseEvent, message)
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
@ -383,13 +369,13 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
self.close_sent = True
self.transport.close()
else:
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
raise RuntimeError(msg % message_type)
raise RuntimeError(
f"Expected ASGI message 'websocket.send' or 'websocket.close', but got '{message['type']}'."
)
except InvalidState:
raise ClientDisconnected()
elif self.initial_response is not None:
if message_type == "websocket.http.response.body":
message = cast(WebSocketResponseBodyEvent, message)
if message["type"] == "websocket.http.response.body":
body = self.initial_response[2] + message["body"]
self.initial_response = self.initial_response[:2] + (body,)
if not message.get("more_body", False):
@ -402,12 +388,10 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
self.transport.write(b"".join(output))
self.transport.close()
else: # pragma: no cover
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
raise RuntimeError(msg % message_type)
raise RuntimeError(f"Expected ASGI message 'websocket.http.response.body' but got '{message['type']}'.")
else:
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
raise RuntimeError(msg % message_type)
raise RuntimeError(f"Unexpected ASGI message '{message['type']}', after sending 'websocket.close'.")
async def receive(self) -> ASGIReceiveEvent:
message = await self.queue.get()

View File

@ -14,13 +14,8 @@ from wsproto.utilities import LocalProtocolError, RemoteProtocolError
from uvicorn._types import (
ASGI3Application,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
WebSocketEvent,
WebSocketResponseBodyEvent,
WebSocketResponseStartEvent,
WebSocketScope,
WebSocketSendEvent,
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
@ -249,11 +244,8 @@ class WSProtocol(asyncio.Protocol):
async def send(self, message: ASGISendEvent) -> None:
await self.writable.wait()
message_type = message["type"]
if not self.handshake_complete:
if message_type == "websocket.accept":
message = cast(WebSocketAcceptEvent, message)
if message["type"] == "websocket.accept":
self.logger.info(
'%s - "WebSocket %s" [accepted]',
get_client_addr(self.scope),
@ -275,7 +267,7 @@ class WSProtocol(asyncio.Protocol):
)
self.transport.write(output)
elif message_type == "websocket.close":
elif message["type"] == "websocket.close":
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.logger.info(
'%s - "WebSocket %s" 403',
@ -289,8 +281,7 @@ class WSProtocol(asyncio.Protocol):
self.transport.write(output)
self.transport.close()
elif message_type == "websocket.http.response.start":
message = cast(WebSocketResponseStartEvent, message)
elif message["type"] == "websocket.http.response.start":
# ensure status code is in the valid range
if not (100 <= message["status"] < 600):
msg = "Invalid HTTP status code '%d' in response."
@ -312,17 +303,14 @@ class WSProtocol(asyncio.Protocol):
self.response_started = True
else:
msg = (
raise RuntimeError(
"Expected ASGI message 'websocket.accept', 'websocket.close' "
"or 'websocket.http.response.start' "
"but got '%s'."
f"or 'websocket.http.response.start' but got '{message['type']}'."
)
raise RuntimeError(msg % message_type)
elif not self.close_sent and not self.response_started:
try:
if message_type == "websocket.send":
message = cast(WebSocketSendEvent, message)
if message["type"] == "websocket.send":
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
@ -330,8 +318,7 @@ class WSProtocol(asyncio.Protocol):
if not self.transport.is_closing():
self.transport.write(output)
elif message_type == "websocket.close":
message = cast(WebSocketCloseEvent, message)
elif message["type"] == "websocket.close":
self.close_sent = True
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
@ -342,13 +329,13 @@ class WSProtocol(asyncio.Protocol):
self.transport.close()
else:
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
raise RuntimeError(msg % message_type)
raise RuntimeError(
f"Expected ASGI message 'websocket.send' or 'websocket.close', but got '{message['type']}'."
)
except LocalProtocolError as exc:
raise ClientDisconnected from exc
elif self.response_started:
if message_type == "websocket.http.response.body":
message = cast("WebSocketResponseBodyEvent", message)
if message["type"] == "websocket.http.response.body":
body_finished = not message.get("more_body", False)
reject_data = events.RejectData(data=message["body"], body_finished=body_finished)
output = self.conn.send(reject_data)
@ -360,12 +347,10 @@ class WSProtocol(asyncio.Protocol):
self.transport.close()
else:
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
raise RuntimeError(msg % message_type)
raise RuntimeError(f"Expected ASGI message 'websocket.http.response.body' but got '{message['type']}'.")
else:
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
raise RuntimeError(msg % message_type)
raise RuntimeError(f"Unexpected ASGI message '{message['type']}', after sending 'websocket.close'.")
async def receive(self) -> WebSocketEvent:
message = await self.queue.get()