Compare commits
1 Commits
main
...
send-close
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1217967a3d |
@ -555,9 +555,9 @@ async def test_duplicate_handshake(ws_protocol_cls: WSProtocol, http_protocol_cl
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
||||
"""
|
||||
The ASGI callable should return 'None'. If it doesn't, make sure that
|
||||
the connection is closed with an error condition.
|
||||
"""The ASGI callable should return 'None'.
|
||||
|
||||
If it doesn't, make sure that the connection is closed with an error condition.
|
||||
"""
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
@ -581,6 +581,29 @@ async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls:
|
||||
assert exc_info.value.code == 1006
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_close_transport_on_asgi_return(
|
||||
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
|
||||
):
|
||||
"""The ASGI callable should call the `websocket.close` event.
|
||||
|
||||
If it doesn't, the server should still send a close frame to the client.
|
||||
"""
|
||||
|
||||
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
||||
await send({"type": "websocket.accept"})
|
||||
|
||||
async def connect(url: str):
|
||||
async with websockets.client.connect(url) as websocket:
|
||||
_ = await websocket.recv()
|
||||
|
||||
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
||||
async with run_server(config):
|
||||
with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info:
|
||||
await connect(f"ws://127.0.0.1:{unused_tcp_port}")
|
||||
assert exc_info.value.code == 1006
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize("code", [None, 1000, 1001])
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@ -240,28 +240,22 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||
result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
|
||||
except ClientDisconnected:
|
||||
self.closed_event.set()
|
||||
self.transport.close()
|
||||
except BaseException as exc:
|
||||
except BaseException:
|
||||
self.closed_event.set()
|
||||
msg = "Exception in ASGI application\n"
|
||||
self.logger.error(msg, exc_info=exc)
|
||||
self.logger.exception("Exception in ASGI application\n")
|
||||
if not self.handshake_started_event.is_set():
|
||||
self.send_500_response()
|
||||
else:
|
||||
await self.handshake_completed_event.wait()
|
||||
self.transport.close()
|
||||
else:
|
||||
self.closed_event.set()
|
||||
if not self.handshake_started_event.is_set():
|
||||
msg = "ASGI callable returned without sending handshake."
|
||||
self.logger.error(msg)
|
||||
self.logger.error("ASGI callable returned without sending handshake.")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
elif result is not None:
|
||||
msg = "ASGI callable should return None, but returned '%s'."
|
||||
self.logger.error(msg, result)
|
||||
await self.handshake_completed_event.wait()
|
||||
self.transport.close()
|
||||
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
|
||||
await self.handshake_completed_event.wait()
|
||||
self.transport.close()
|
||||
|
||||
async def asgi_send(self, message: ASGISendEvent) -> None:
|
||||
message_type = message["type"]
|
||||
|
||||
@ -232,21 +232,17 @@ class WSProtocol(asyncio.Protocol):
|
||||
try:
|
||||
result = await self.app(self.scope, self.receive, self.send)
|
||||
except ClientDisconnected:
|
||||
self.transport.close()
|
||||
pass
|
||||
except BaseException:
|
||||
self.logger.exception("Exception in ASGI application\n")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
else:
|
||||
if not self.handshake_complete:
|
||||
msg = "ASGI callable returned without completing handshake."
|
||||
self.logger.error(msg)
|
||||
self.logger.error("ASGI callable returned without completing handshake.")
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
elif result is not None:
|
||||
msg = "ASGI callable should return None, but returned '%s'."
|
||||
self.logger.error(msg, result)
|
||||
self.transport.close()
|
||||
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
|
||||
self.transport.close()
|
||||
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
await self.writable.wait()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user