Include host header directly (#109)

* Improve HTTP protocol detection

* Include host header when request is instantiated

* Add raise_app_exceptions

* Tweaks to ASGI dispatching

* Linting

* Don't quote multipart values

* Tweak decoder ordering in header

* Allow str data in request bodys
This commit is contained in:
Tom Christie 2019-06-27 16:46:13 +01:00 committed by GitHub
parent 04cb3a6d98
commit 5539b69ac2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 46 additions and 26 deletions

View File

@ -54,6 +54,7 @@ class BaseClient:
base_url: URLTypes = None,
dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None,
app: typing.Callable = None,
raise_app_exceptions: bool = True,
backend: ConcurrencyBackend = None,
):
if backend is None:
@ -63,9 +64,13 @@ class BaseClient:
param_count = len(inspect.signature(app).parameters)
assert param_count in (2, 3)
if param_count == 2:
dispatch = WSGIDispatch(app=app)
dispatch = WSGIDispatch(
app=app, raise_app_exceptions=raise_app_exceptions
)
else:
dispatch = ASGIDispatch(app=app)
dispatch = ASGIDispatch(
app=app, raise_app_exceptions=raise_app_exceptions
)
if dispatch is None:
async_dispatch = ConnectionPool(
@ -558,7 +563,7 @@ class Client(BaseClient):
If the request data is an bytes iterator then return an async bytes
iterator onto the request data.
"""
if data is None or isinstance(data, (bytes, dict)):
if data is None or isinstance(data, (str, bytes, dict)):
return data
# Coerce an iterator into an async iterator, with each item in the

View File

@ -134,8 +134,8 @@ class MultiDecoder(Decoder):
SUPPORTED_DECODERS = {
"identity": IdentityDecoder,
"deflate": DeflateDecoder,
"gzip": GZipDecoder,
"deflate": DeflateDecoder,
"br": BrotliDecoder,
}

View File

@ -35,10 +35,12 @@ class ASGIDispatch(AsyncDispatcher):
def __init__(
self,
app: typing.Callable,
raise_app_exceptions: bool = True,
root_path: str = "",
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
self.client = client
@ -53,11 +55,12 @@ class ASGIDispatch(AsyncDispatcher):
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": request.method,
"headers": request.headers.raw,
"scheme": request.url.scheme,
"path": request.url.path,
"query": request.url.query.encode("ascii"),
"query_string": request.url.query.encode("ascii"),
"server": request.url.host,
"client": self.client,
"root_path": self.root_path,
@ -80,7 +83,7 @@ class ASGIDispatch(AsyncDispatcher):
return {"type": "http.request", "body": body, "more_body": True}
async def send(message: dict) -> None:
nonlocal status_code, headers, response_started, response_body
nonlocal status_code, headers, response_started, response_body, request
if message["type"] == "http.response.start":
status_code = message["status"]
@ -89,14 +92,13 @@ class ASGIDispatch(AsyncDispatcher):
elif message["type"] == "http.response.body":
body = message.get("body", b"")
more_body = message.get("more_body", False)
if body:
if body and request.method != "HEAD":
await response_body.put(body)
if not more_body:
await response_body.done()
async def run_app() -> None:
nonlocal app, scope, receive, send, app_exc, response_body
try:
await app(scope, receive, send)
except Exception as exc:
@ -117,17 +119,18 @@ class ASGIDispatch(AsyncDispatcher):
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
if app_exc is not None:
if app_exc is not None and self.raise_app_exceptions:
raise app_exc
assert response_started.is_set, "application did not return a response."
assert response_started.is_set(), "application did not return a response."
assert status_code is not None
assert headers is not None
async def on_close() -> None:
nonlocal app_task
nonlocal app_task, response_body
await response_body.drain()
await app_task
if app_exc is not None:
if app_exc is not None and self.raise_app_exceptions:
raise app_exc
return AsyncResponse(
@ -163,6 +166,14 @@ class BodyIterator:
assert isinstance(data, bytes)
yield data
async def drain(self) -> None:
"""
Drain any remaining body, in order to allow any blocked `put()` calls
to complete.
"""
async for chunk in self.iterate():
pass # pragma: no cover
async def put(self, data: bytes) -> None:
"""
Used by the server to add data to the response body.

View File

@ -79,9 +79,6 @@ class HTTP11Connection:
method = request.method.encode("ascii")
target = request.url.full_path.encode("ascii")
headers = request.headers.raw
if "Host" not in request.headers:
host = request.url.authority.encode("ascii")
headers = [(b"host", host)] + headers
event = h11.Request(method=method, target=target, headers=headers)
await self._send_event(event, timeout)

View File

@ -76,7 +76,7 @@ class HTTP2Connection:
(b":authority", request.url.authority.encode("ascii")),
(b":scheme", request.url.scheme.encode("ascii")),
(b":path", request.url.full_path.encode("ascii")),
] + request.headers.raw
] + [(k, v) for k, v in request.headers.raw if k != b"host"]
self.h2_state.send_headers(stream_id, headers)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)

View File

@ -35,10 +35,12 @@ class WSGIDispatch(Dispatcher):
def __init__(
self,
app: typing.Callable,
raise_app_exceptions: bool = True,
script_name: str = "",
remote_addr: str = "127.0.0.1",
) -> None:
self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.script_name = script_name
self.remote_addr = remote_addr
@ -87,7 +89,7 @@ class WSGIDispatch(Dispatcher):
assert seen_status is not None
assert seen_response_headers is not None
if seen_exc_info:
if seen_exc_info and self.raise_app_exceptions:
raise seen_exc_info[1]
return Response(

View File

@ -51,9 +51,9 @@ AuthTypes = typing.Union[
typing.Callable[["AsyncRequest"], "AsyncRequest"],
]
AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
AsyncRequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]]
RequestData = typing.Union[dict, str, bytes, typing.Iterator[bytes]]
RequestFiles = typing.Dict[
str,
@ -527,6 +527,7 @@ class BaseRequest:
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
has_host = "host" in self.headers
has_user_agent = "user-agent" in self.headers
has_accept = "accept" in self.headers
has_content_length = (
@ -534,6 +535,8 @@ class BaseRequest:
)
has_accept_encoding = "accept-encoding" in self.headers
if not has_host:
auto_headers.append((b"host", self.url.authority.encode("ascii")))
if not has_user_agent:
auto_headers.append((b"user-agent", b"http3"))
if not has_accept:
@ -585,7 +588,8 @@ class AsyncRequest(BaseRequest):
self.content = content
if content_type:
self.headers["Content-Type"] = content_type
elif isinstance(data, bytes):
elif isinstance(data, (str, bytes)):
data = data.encode("utf-8") if isinstance(data, str) else data
self.is_streaming = False
self.content = data
else:
@ -634,7 +638,8 @@ class Request(BaseRequest):
self.content = content
if content_type:
self.headers["Content-Type"] = content_type
elif isinstance(data, bytes):
elif isinstance(data, (str, bytes)):
data = data.encode("utf-8") if isinstance(data, str) else data
self.is_streaming = False
self.content = data
else:

View File

@ -3,7 +3,7 @@ import mimetypes
import os
import typing
from io import BytesIO
from urllib.parse import quote_plus
from urllib.parse import quote
class Field:
@ -20,13 +20,13 @@ class DataField(Field):
self.value = value
def render_headers(self) -> bytes:
name = quote_plus(self.name, encoding="utf-8").encode("ascii")
name = quote(self.name, encoding="utf-8").encode("ascii")
return b"".join(
[b'Content-Disposition: form-data; name="', name, b'"\r\n' b"\r\n"]
)
def render_data(self) -> bytes:
return quote_plus(self.value, encoding="utf-8").encode("ascii")
return self.value.encode("utf-8")
class FileField(Field):
@ -49,8 +49,8 @@ class FileField(Field):
return mimetypes.guess_type(self.filename)[0] or "application/octet-stream"
def render_headers(self) -> bytes:
name = quote_plus(self.name, encoding="utf-8").encode("ascii")
filename = quote_plus(self.filename, encoding="utf-8").encode("ascii")
name = quote(self.name, encoding="utf-8").encode("ascii")
filename = quote(self.filename, encoding="utf-8").encode("ascii")
content_type = self.content_type.encode("ascii")
return b"".join(
[