Compare commits

...

1 Commits

Author SHA1 Message Date
Marcelo Trylesinski
1217967a3d Send close frame after ASGI application returned 2024-05-12 18:01:47 +02:00
3 changed files with 36 additions and 23 deletions

View File

@ -555,9 +555,9 @@ async def test_duplicate_handshake(ws_protocol_cls: WSProtocol, http_protocol_cl
@pytest.mark.anyio
async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
"""
The ASGI callable should return 'None'. If it doesn't, make sure that
the connection is closed with an error condition.
"""The ASGI callable should return 'None'.
If it doesn't, make sure that the connection is closed with an error condition.
"""
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
@ -581,6 +581,29 @@ async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls:
assert exc_info.value.code == 1006
@pytest.mark.anyio
async def test_close_transport_on_asgi_return(
ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
):
"""The ASGI callable should call the `websocket.close` event.
If it doesn't, the server should still send a close frame to the client.
"""
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "websocket.accept"})
async def connect(url: str):
async with websockets.client.connect(url) as websocket:
_ = await websocket.recv()
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info:
await connect(f"ws://127.0.0.1:{unused_tcp_port}")
assert exc_info.value.code == 1006
@pytest.mark.anyio
@pytest.mark.parametrize("code", [None, 1000, 1001])
@pytest.mark.parametrize(

View File

@ -240,28 +240,22 @@ class WebSocketProtocol(WebSocketServerProtocol):
result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
except ClientDisconnected:
self.closed_event.set()
self.transport.close()
except BaseException as exc:
except BaseException:
self.closed_event.set()
msg = "Exception in ASGI application\n"
self.logger.error(msg, exc_info=exc)
self.logger.exception("Exception in ASGI application\n")
if not self.handshake_started_event.is_set():
self.send_500_response()
else:
await self.handshake_completed_event.wait()
self.transport.close()
else:
self.closed_event.set()
if not self.handshake_started_event.is_set():
msg = "ASGI callable returned without sending handshake."
self.logger.error(msg)
self.logger.error("ASGI callable returned without sending handshake.")
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
await self.handshake_completed_event.wait()
self.transport.close()
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
await self.handshake_completed_event.wait()
self.transport.close()
async def asgi_send(self, message: ASGISendEvent) -> None:
message_type = message["type"]

View File

@ -232,21 +232,17 @@ class WSProtocol(asyncio.Protocol):
try:
result = await self.app(self.scope, self.receive, self.send)
except ClientDisconnected:
self.transport.close()
pass
except BaseException:
self.logger.exception("Exception in ASGI application\n")
self.send_500_response()
self.transport.close()
else:
if not self.handshake_complete:
msg = "ASGI callable returned without completing handshake."
self.logger.error(msg)
self.logger.error("ASGI callable returned without completing handshake.")
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
self.transport.close()
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
self.transport.close()
async def send(self, message: ASGISendEvent) -> None:
await self.writable.wait()