Drop cast in ASGI types (#2875)
This commit is contained in:
parent
1cb8e747e2
commit
5211880320
@ -5,7 +5,7 @@ import contextvars
|
||||
import http
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import unquote
|
||||
|
||||
import h11
|
||||
@ -452,8 +452,6 @@ class RequestResponseCycle:
|
||||
|
||||
# ASGI interface
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
if self.flow.write_paused and not self.disconnected:
|
||||
await self.flow.drain() # pragma: full coverage
|
||||
|
||||
@ -462,10 +460,8 @@ class RequestResponseCycle:
|
||||
|
||||
if not self.response_started:
|
||||
# Sending response status line and headers
|
||||
if message_type != "http.response.start":
|
||||
msg = "Expected ASGI message 'http.response.start', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
message = cast("HTTPResponseStartEvent", message)
|
||||
if message["type"] != "http.response.start":
|
||||
raise RuntimeError(f"Expected ASGI message 'http.response.start', but got '{message['type']}'.")
|
||||
|
||||
self.response_started = True
|
||||
self.waiting_for_100_continue = False
|
||||
@ -494,10 +490,8 @@ class RequestResponseCycle:
|
||||
|
||||
elif not self.response_complete:
|
||||
# Sending response body
|
||||
if message_type != "http.response.body":
|
||||
msg = "Expected ASGI message 'http.response.body', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
message = cast("HTTPResponseBodyEvent", message)
|
||||
if message["type"] != "http.response.body":
|
||||
raise RuntimeError(f"Expected ASGI message 'http.response.body', but got '{message['type']}'.")
|
||||
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
@ -516,8 +510,7 @@ class RequestResponseCycle:
|
||||
|
||||
else:
|
||||
# Response already sent
|
||||
msg = "Unexpected ASGI message '%s' sent, after response already completed."
|
||||
raise RuntimeError(msg % message_type)
|
||||
raise RuntimeError(f"Unexpected ASGI message '{message['type']}' sent, after response already completed.")
|
||||
|
||||
if self.response_complete:
|
||||
if self.conn.our_state is h11.MUST_CLOSE or not self.keep_alive:
|
||||
@ -541,10 +534,6 @@ class RequestResponseCycle:
|
||||
if self.disconnected or self.response_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
message: HTTPRequestEvent = {
|
||||
"type": "http.request",
|
||||
"body": bytes(self.body),
|
||||
"more_body": self.more_body,
|
||||
}
|
||||
message: HTTPRequestEvent = {"type": "http.request", "body": bytes(self.body), "more_body": self.more_body}
|
||||
self.body = bytearray()
|
||||
return message
|
||||
|
||||
@ -9,7 +9,7 @@ import urllib
|
||||
from asyncio.events import TimerHandle
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, Literal
|
||||
|
||||
import httptools
|
||||
|
||||
@ -18,7 +18,6 @@ from uvicorn._types import (
|
||||
ASGIReceiveEvent,
|
||||
ASGISendEvent,
|
||||
HTTPRequestEvent,
|
||||
HTTPResponseStartEvent,
|
||||
HTTPScope,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
@ -455,8 +454,6 @@ class RequestResponseCycle:
|
||||
|
||||
# ASGI interface
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
if self.flow.write_paused and not self.disconnected:
|
||||
await self.flow.drain() # pragma: full coverage
|
||||
|
||||
@ -465,10 +462,8 @@ class RequestResponseCycle:
|
||||
|
||||
if not self.response_started:
|
||||
# Sending response status line and headers
|
||||
if message_type != "http.response.start":
|
||||
msg = "Expected ASGI message 'http.response.start', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
message = cast("HTTPResponseStartEvent", message)
|
||||
if message["type"] != "http.response.start":
|
||||
raise RuntimeError(f"Expected ASGI message 'http.response.start', but got '{message['type']}'.")
|
||||
|
||||
self.response_started = True
|
||||
self.waiting_for_100_continue = False
|
||||
@ -519,11 +514,10 @@ class RequestResponseCycle:
|
||||
|
||||
elif not self.response_complete:
|
||||
# Sending response body
|
||||
if message_type != "http.response.body":
|
||||
msg = "Expected ASGI message 'http.response.body', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
if message["type"] != "http.response.body":
|
||||
raise RuntimeError(f"Expected ASGI message 'http.response.body', but got '{message['type']}'.")
|
||||
|
||||
body = cast(bytes, message.get("body", b""))
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
# Write response body
|
||||
@ -557,8 +551,7 @@ class RequestResponseCycle:
|
||||
|
||||
else:
|
||||
# Response already sent
|
||||
msg = "Unexpected ASGI message '%s' sent, after response already completed."
|
||||
raise RuntimeError(msg % message_type)
|
||||
raise RuntimeError(f"Unexpected ASGI message '{message['type']}' sent, after response already completed.")
|
||||
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
if self.waiting_for_100_continue and not self.transport.is_closing():
|
||||
|
||||
@ -20,15 +20,10 @@ from websockets.typing import Subprotocol
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGISendEvent,
|
||||
WebSocketAcceptEvent,
|
||||
WebSocketCloseEvent,
|
||||
WebSocketConnectEvent,
|
||||
WebSocketDisconnectEvent,
|
||||
WebSocketReceiveEvent,
|
||||
WebSocketResponseBodyEvent,
|
||||
WebSocketResponseStartEvent,
|
||||
WebSocketScope,
|
||||
WebSocketSendEvent,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
@ -262,11 +257,8 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
self.transport.close()
|
||||
|
||||
async def asgi_send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
if not self.handshake_started_event.is_set():
|
||||
if message_type == "websocket.accept":
|
||||
message = cast("WebSocketAcceptEvent", message)
|
||||
if message["type"] == "websocket.accept":
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
get_client_addr(self.scope),
|
||||
@ -283,8 +275,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
)
|
||||
self.handshake_started_event.set()
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = cast("WebSocketCloseEvent", message)
|
||||
elif message["type"] == "websocket.close":
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
get_client_addr(self.scope),
|
||||
@ -294,8 +285,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
self.handshake_started_event.set()
|
||||
self.closed_event.set()
|
||||
|
||||
elif message_type == "websocket.http.response.start":
|
||||
message = cast("WebSocketResponseStartEvent", message)
|
||||
elif message["type"] == "websocket.http.response.start":
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" %d',
|
||||
get_client_addr(self.scope),
|
||||
@ -311,50 +301,48 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
self.handshake_started_event.set()
|
||||
|
||||
else:
|
||||
msg = (
|
||||
raise RuntimeError(
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close', "
|
||||
"or 'websocket.http.response.start' but got '%s'."
|
||||
f"or 'websocket.http.response.start' but got '{message['type']}'."
|
||||
)
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
elif not self.closed_event.is_set() and self.initial_response is None:
|
||||
await self.handshake_completed_event.wait()
|
||||
|
||||
try:
|
||||
if message_type == "websocket.send":
|
||||
message = cast("WebSocketSendEvent", message)
|
||||
if message["type"] == "websocket.send":
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
data = text_data if bytes_data is None else bytes_data
|
||||
await self.send(data) # type: ignore[arg-type]
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = cast("WebSocketCloseEvent", message)
|
||||
elif message["type"] == "websocket.close":
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
await self.close(code, reason)
|
||||
self.closed_event.set()
|
||||
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
raise RuntimeError(
|
||||
f"Expected ASGI message 'websocket.send' or 'websocket.close', but got '{message['type']}'."
|
||||
)
|
||||
except ConnectionClosed as exc:
|
||||
raise ClientDisconnected from exc
|
||||
|
||||
elif self.initial_response is not None:
|
||||
if message_type == "websocket.http.response.body":
|
||||
message = cast("WebSocketResponseBodyEvent", message)
|
||||
if message["type"] == "websocket.http.response.body":
|
||||
body = self.initial_response[2] + message["body"]
|
||||
self.initial_response = self.initial_response[:2] + (body,)
|
||||
if not message.get("more_body", False):
|
||||
self.closed_event.set()
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
raise RuntimeError(f"Expected ASGI message 'websocket.http.response.body' but got '{message['type']}'.")
|
||||
|
||||
else:
|
||||
msg = "Unexpected ASGI message '%s', after sending 'websocket.close' or response already completed."
|
||||
raise RuntimeError(msg % message_type)
|
||||
raise RuntimeError(
|
||||
f"Unexpected ASGI message '{message['type']}', after sending 'websocket.close' "
|
||||
"or response already completed."
|
||||
)
|
||||
|
||||
async def asgi_receive(self) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent:
|
||||
if not self.connect_sent:
|
||||
|
||||
@ -17,12 +17,7 @@ from websockets.server import ServerProtocol
|
||||
from uvicorn._types import (
|
||||
ASGIReceiveEvent,
|
||||
ASGISendEvent,
|
||||
WebSocketAcceptEvent,
|
||||
WebSocketCloseEvent,
|
||||
WebSocketResponseBodyEvent,
|
||||
WebSocketResponseStartEvent,
|
||||
WebSocketScope,
|
||||
WebSocketSendEvent,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
@ -295,11 +290,8 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
await self.writable.wait()
|
||||
|
||||
message_type = message["type"]
|
||||
|
||||
if not self.handshake_complete and self.initial_response is None:
|
||||
if message_type == "websocket.accept":
|
||||
message = cast(WebSocketAcceptEvent, message)
|
||||
if message["type"] == "websocket.accept":
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
get_client_addr(self.scope),
|
||||
@ -320,8 +312,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = cast(WebSocketCloseEvent, message)
|
||||
elif message["type"] == "websocket.close":
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
@ -335,8 +326,7 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
self.handshake_complete = True
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
elif message_type == "websocket.http.response.start" and self.initial_response is None:
|
||||
message = cast(WebSocketResponseStartEvent, message)
|
||||
elif message["type"] == "websocket.http.response.start" and self.initial_response is None:
|
||||
if not (100 <= message["status"] < 600):
|
||||
raise RuntimeError("Invalid HTTP status code '%d' in response." % message["status"])
|
||||
self.logger.info(
|
||||
@ -351,17 +341,14 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
]
|
||||
self.initial_response = (message["status"], headers, b"")
|
||||
else:
|
||||
msg = (
|
||||
raise RuntimeError(
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close' "
|
||||
"or 'websocket.http.response.start' "
|
||||
"but got '%s'."
|
||||
f"or 'websocket.http.response.start' but got '{message['type']}'."
|
||||
)
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
elif not self.close_sent and self.initial_response is None:
|
||||
try:
|
||||
if message_type == "websocket.send":
|
||||
message = cast(WebSocketSendEvent, message)
|
||||
if message["type"] == "websocket.send":
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
if bytes_data is not None:
|
||||
@ -371,9 +358,8 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
elif message["type"] == "websocket.close":
|
||||
if not self.transport.is_closing():
|
||||
message = cast(WebSocketCloseEvent, message)
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
@ -383,13 +369,13 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
raise RuntimeError(
|
||||
f"Expected ASGI message 'websocket.send' or 'websocket.close', but got '{message['type']}'."
|
||||
)
|
||||
except InvalidState:
|
||||
raise ClientDisconnected()
|
||||
elif self.initial_response is not None:
|
||||
if message_type == "websocket.http.response.body":
|
||||
message = cast(WebSocketResponseBodyEvent, message)
|
||||
if message["type"] == "websocket.http.response.body":
|
||||
body = self.initial_response[2] + message["body"]
|
||||
self.initial_response = self.initial_response[:2] + (body,)
|
||||
if not message.get("more_body", False):
|
||||
@ -402,12 +388,10 @@ class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
else: # pragma: no cover
|
||||
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
raise RuntimeError(f"Expected ASGI message 'websocket.http.response.body' but got '{message['type']}'.")
|
||||
|
||||
else:
|
||||
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
raise RuntimeError(f"Unexpected ASGI message '{message['type']}', after sending 'websocket.close'.")
|
||||
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
message = await self.queue.get()
|
||||
|
||||
@ -14,13 +14,8 @@ from wsproto.utilities import LocalProtocolError, RemoteProtocolError
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGISendEvent,
|
||||
WebSocketAcceptEvent,
|
||||
WebSocketCloseEvent,
|
||||
WebSocketEvent,
|
||||
WebSocketResponseBodyEvent,
|
||||
WebSocketResponseStartEvent,
|
||||
WebSocketScope,
|
||||
WebSocketSendEvent,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
@ -249,11 +244,8 @@ class WSProtocol(asyncio.Protocol):
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
await self.writable.wait()
|
||||
|
||||
message_type = message["type"]
|
||||
|
||||
if not self.handshake_complete:
|
||||
if message_type == "websocket.accept":
|
||||
message = cast(WebSocketAcceptEvent, message)
|
||||
if message["type"] == "websocket.accept":
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
get_client_addr(self.scope),
|
||||
@ -275,7 +267,7 @@ class WSProtocol(asyncio.Protocol):
|
||||
)
|
||||
self.transport.write(output)
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
elif message["type"] == "websocket.close":
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
@ -289,8 +281,7 @@ class WSProtocol(asyncio.Protocol):
|
||||
self.transport.write(output)
|
||||
self.transport.close()
|
||||
|
||||
elif message_type == "websocket.http.response.start":
|
||||
message = cast(WebSocketResponseStartEvent, message)
|
||||
elif message["type"] == "websocket.http.response.start":
|
||||
# ensure status code is in the valid range
|
||||
if not (100 <= message["status"] < 600):
|
||||
msg = "Invalid HTTP status code '%d' in response."
|
||||
@ -312,17 +303,14 @@ class WSProtocol(asyncio.Protocol):
|
||||
self.response_started = True
|
||||
|
||||
else:
|
||||
msg = (
|
||||
raise RuntimeError(
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close' "
|
||||
"or 'websocket.http.response.start' "
|
||||
"but got '%s'."
|
||||
f"or 'websocket.http.response.start' but got '{message['type']}'."
|
||||
)
|
||||
raise RuntimeError(msg % message_type)
|
||||
|
||||
elif not self.close_sent and not self.response_started:
|
||||
try:
|
||||
if message_type == "websocket.send":
|
||||
message = cast(WebSocketSendEvent, message)
|
||||
if message["type"] == "websocket.send":
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
data = text_data if bytes_data is None else bytes_data
|
||||
@ -330,8 +318,7 @@ class WSProtocol(asyncio.Protocol):
|
||||
if not self.transport.is_closing():
|
||||
self.transport.write(output)
|
||||
|
||||
elif message_type == "websocket.close":
|
||||
message = cast(WebSocketCloseEvent, message)
|
||||
elif message["type"] == "websocket.close":
|
||||
self.close_sent = True
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
@ -342,13 +329,13 @@ class WSProtocol(asyncio.Protocol):
|
||||
self.transport.close()
|
||||
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.send' or 'websocket.close', but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
raise RuntimeError(
|
||||
f"Expected ASGI message 'websocket.send' or 'websocket.close', but got '{message['type']}'."
|
||||
)
|
||||
except LocalProtocolError as exc:
|
||||
raise ClientDisconnected from exc
|
||||
elif self.response_started:
|
||||
if message_type == "websocket.http.response.body":
|
||||
message = cast("WebSocketResponseBodyEvent", message)
|
||||
if message["type"] == "websocket.http.response.body":
|
||||
body_finished = not message.get("more_body", False)
|
||||
reject_data = events.RejectData(data=message["body"], body_finished=body_finished)
|
||||
output = self.conn.send(reject_data)
|
||||
@ -360,12 +347,10 @@ class WSProtocol(asyncio.Protocol):
|
||||
self.transport.close()
|
||||
|
||||
else:
|
||||
msg = "Expected ASGI message 'websocket.http.response.body' but got '%s'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
raise RuntimeError(f"Expected ASGI message 'websocket.http.response.body' but got '{message['type']}'.")
|
||||
|
||||
else:
|
||||
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
|
||||
raise RuntimeError(msg % message_type)
|
||||
raise RuntimeError(f"Unexpected ASGI message '{message['type']}', after sending 'websocket.close'.")
|
||||
|
||||
async def receive(self) -> WebSocketEvent:
|
||||
message = await self.queue.get()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user