ASGI: Wait for response to complete before sending disconnect message (#919)

* asgi: Wait for response to complete before sending disconnect message

* Dial back type checking + remove concurrency module

* Remove somewhat redundant comment
This commit is contained in:
Jamie Hewland 2020-05-12 11:06:53 +02:00 committed by GitHub
parent 560b119d32
commit d568ecda53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 31 additions and 8 deletions

View File

@ -1,9 +1,28 @@
import typing
from typing import Callable, Dict, List, Optional, Tuple
import httpcore
import sniffio
from .._content_streams import ByteStream
if typing.TYPE_CHECKING: # pragma: no cover
import asyncio
import trio
Event = typing.Union[asyncio.Event, trio.Event]
def create_event() -> "Event":
if sniffio.current_async_library() == "trio":
import trio
return trio.Event()
else:
import asyncio
return asyncio.Event()
class ASGIDispatch(httpcore.AsyncHTTPTransport):
"""
@ -76,8 +95,9 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
status_code = None
response_headers = None
body_parts = []
request_complete = False
response_started = False
response_complete = False
response_complete = create_event()
headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream
@ -85,14 +105,16 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
request_body_chunks = stream.__aiter__()
async def receive() -> dict:
nonlocal response_complete
nonlocal request_complete, response_complete
if response_complete:
if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}
try:
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}
@ -108,7 +130,7 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
response_started = True
elif message["type"] == "http.response.body":
assert not response_complete
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)
@ -116,7 +138,7 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
body_parts.append(body)
if not more_body:
response_complete = True
response_complete.set()
try:
await self.app(scope, receive, send)
@ -124,7 +146,7 @@ class ASGIDispatch(httpcore.AsyncHTTPTransport):
if self.raise_app_exceptions or not response_complete:
raise
assert response_complete
assert response_complete.is_set()
assert status_code is not None
assert response_headers is not None

View File

@ -26,6 +26,7 @@ pytest-asyncio
pytest-trio
pytest-cov
trio
trio-typing
trustme
uvicorn
seed-isort-config

View File

@ -61,6 +61,7 @@ setup(
"idna==2.*",
"rfc3986>=1.3,<2",
"httpcore>=0.8.3",
"sniffio",
],
classifiers=[
"Development Status :: 4 - Beta",

View File

@ -69,8 +69,7 @@ async def test_asgi_exc_after_response():
await client.get("http://www.example.org/")
@pytest.mark.asyncio
async def test_asgi_disconnect_after_response_complete():
async def test_asgi_disconnect_after_response_complete(async_environment):
disconnect = False
async def read_body(scope, receive, send):