Refactor ASGITransport.request() (#1021)

This commit is contained in:
Florimond Manca 2020-06-13 19:59:09 +02:00 committed by GitHub
parent 838f417ce0
commit b481166481
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,4 @@
import typing
from typing import Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import httpcore
import sniffio
@ -7,11 +6,11 @@ import sniffio
from .._content_streams import ByteStream
from .._utils import warn_deprecated
if typing.TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING: # pragma: no cover
import asyncio
import trio
Event = typing.Union[asyncio.Event, trio.Event]
Event = Union[asyncio.Event, trio.Event]
def create_event() -> "Event":
@ -78,6 +77,10 @@ class ASGITransport(httpcore.AsyncHTTPTransport):
stream: httpcore.AsyncByteStream = None,
timeout: Dict[str, Optional[float]] = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream
# ASGI scope.
scheme, host, port, full_path = url
path, _, query = full_path.partition(b"?")
scope = {
@ -93,20 +96,22 @@ class ASGITransport(httpcore.AsyncHTTPTransport):
"client": self.client,
"root_path": self.root_path,
}
# Request.
request_body_chunks = stream.__aiter__()
request_complete = False
# Response.
status_code = None
response_headers = None
body_parts = []
request_complete = False
response_started = False
response_complete = create_event()
headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream
request_body_chunks = stream.__aiter__()
# ASGI callables.
async def receive() -> dict:
nonlocal request_complete, response_complete
nonlocal request_complete
if request_complete:
await response_complete.wait()
@ -120,8 +125,7 @@ class ASGITransport(httpcore.AsyncHTTPTransport):
return {"type": "http.request", "body": body, "more_body": True}
async def send(message: dict) -> None:
nonlocal status_code, response_headers, body_parts
nonlocal response_started, response_complete
nonlocal status_code, response_headers, response_started
if message["type"] == "http.response.start":
assert not response_started
@ -144,7 +148,7 @@ class ASGITransport(httpcore.AsyncHTTPTransport):
try:
await self.app(scope, receive, send)
except Exception:
if self.raise_app_exceptions or not response_complete:
if self.raise_app_exceptions or not response_complete.is_set():
raise
assert response_complete.is_set()