ASGI: Wait for response to complete before sending disconnect message (#919)
* asgi: Wait for response to complete before sending disconnect message * Dial back type checking + remove concurrency module * Remove somewhat redundant comment
This commit is contained in:
parent
560b119d32
commit
d568ecda53
@ -1,9 +1,28 @@
|
||||
import typing
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import httpcore
|
||||
import sniffio
|
||||
|
||||
from .._content_streams import ByteStream
|
||||
|
||||
if typing.TYPE_CHECKING: # pragma: no cover
|
||||
import asyncio
|
||||
import trio
|
||||
|
||||
Event = typing.Union[asyncio.Event, trio.Event]
|
||||
|
||||
|
||||
def create_event() -> "Event":
|
||||
if sniffio.current_async_library() == "trio":
|
||||
import trio
|
||||
|
||||
return trio.Event()
|
||||
else:
|
||||
import asyncio
|
||||
|
||||
return asyncio.Event()
|
||||
|
||||
|
||||
class ASGIDispatch(httpcore.AsyncHTTPTransport):
|
||||
"""
|
||||
@ -76,8 +95,9 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
|
||||
status_code = None
|
||||
response_headers = None
|
||||
body_parts = []
|
||||
request_complete = False
|
||||
response_started = False
|
||||
response_complete = False
|
||||
response_complete = create_event()
|
||||
|
||||
headers = [] if headers is None else headers
|
||||
stream = ByteStream(b"") if stream is None else stream
|
||||
@ -85,14 +105,16 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
|
||||
request_body_chunks = stream.__aiter__()
|
||||
|
||||
async def receive() -> dict:
|
||||
nonlocal response_complete
|
||||
nonlocal request_complete, response_complete
|
||||
|
||||
if response_complete:
|
||||
if request_complete:
|
||||
await response_complete.wait()
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
try:
|
||||
body = await request_body_chunks.__anext__()
|
||||
except StopAsyncIteration:
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": b"", "more_body": False}
|
||||
return {"type": "http.request", "body": body, "more_body": True}
|
||||
|
||||
@ -108,7 +130,7 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
|
||||
response_started = True
|
||||
|
||||
elif message["type"] == "http.response.body":
|
||||
assert not response_complete
|
||||
assert not response_complete.is_set()
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
@ -116,7 +138,7 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
|
||||
body_parts.append(body)
|
||||
|
||||
if not more_body:
|
||||
response_complete = True
|
||||
response_complete.set()
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
@ -124,7 +146,7 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
|
||||
if self.raise_app_exceptions or not response_complete:
|
||||
raise
|
||||
|
||||
assert response_complete
|
||||
assert response_complete.is_set()
|
||||
assert status_code is not None
|
||||
assert response_headers is not None
|
||||
|
||||
|
||||
@ -26,6 +26,7 @@ pytest-asyncio
|
||||
pytest-trio
|
||||
pytest-cov
|
||||
trio
|
||||
trio-typing
|
||||
trustme
|
||||
uvicorn
|
||||
seed-isort-config
|
||||
|
||||
1
setup.py
1
setup.py
@ -61,6 +61,7 @@ setup(
|
||||
"idna==2.*",
|
||||
"rfc3986>=1.3,<2",
|
||||
"httpcore>=0.8.3",
|
||||
"sniffio",
|
||||
],
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
|
||||
@ -69,8 +69,7 @@ async def test_asgi_exc_after_response():
|
||||
await client.get("http://www.example.org/")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asgi_disconnect_after_response_complete():
|
||||
async def test_asgi_disconnect_after_response_complete(async_environment):
|
||||
disconnect = False
|
||||
|
||||
async def read_body(scope, receive, send):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user