Include host header directly (#109)
* Improve HTTP protocol detection * Include host header when request is instantiated * Add raise_app_exceptions * Tweaks to ASGI dispatching * Linting * Don't quote multipart values * Tweak decoder ordering in header * Allow str data in request bodys
This commit is contained in:
parent
04cb3a6d98
commit
5539b69ac2
@ -54,6 +54,7 @@ class BaseClient:
|
||||
base_url: URLTypes = None,
|
||||
dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None,
|
||||
app: typing.Callable = None,
|
||||
raise_app_exceptions: bool = True,
|
||||
backend: ConcurrencyBackend = None,
|
||||
):
|
||||
if backend is None:
|
||||
@ -63,9 +64,13 @@ class BaseClient:
|
||||
param_count = len(inspect.signature(app).parameters)
|
||||
assert param_count in (2, 3)
|
||||
if param_count == 2:
|
||||
dispatch = WSGIDispatch(app=app)
|
||||
dispatch = WSGIDispatch(
|
||||
app=app, raise_app_exceptions=raise_app_exceptions
|
||||
)
|
||||
else:
|
||||
dispatch = ASGIDispatch(app=app)
|
||||
dispatch = ASGIDispatch(
|
||||
app=app, raise_app_exceptions=raise_app_exceptions
|
||||
)
|
||||
|
||||
if dispatch is None:
|
||||
async_dispatch = ConnectionPool(
|
||||
@ -558,7 +563,7 @@ class Client(BaseClient):
|
||||
If the request data is an bytes iterator then return an async bytes
|
||||
iterator onto the request data.
|
||||
"""
|
||||
if data is None or isinstance(data, (bytes, dict)):
|
||||
if data is None or isinstance(data, (str, bytes, dict)):
|
||||
return data
|
||||
|
||||
# Coerce an iterator into an async iterator, with each item in the
|
||||
|
||||
@ -134,8 +134,8 @@ class MultiDecoder(Decoder):
|
||||
|
||||
SUPPORTED_DECODERS = {
|
||||
"identity": IdentityDecoder,
|
||||
"deflate": DeflateDecoder,
|
||||
"gzip": GZipDecoder,
|
||||
"deflate": DeflateDecoder,
|
||||
"br": BrotliDecoder,
|
||||
}
|
||||
|
||||
|
||||
@ -35,10 +35,12 @@ class ASGIDispatch(AsyncDispatcher):
|
||||
def __init__(
|
||||
self,
|
||||
app: typing.Callable,
|
||||
raise_app_exceptions: bool = True,
|
||||
root_path: str = "",
|
||||
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.raise_app_exceptions = raise_app_exceptions
|
||||
self.root_path = root_path
|
||||
self.client = client
|
||||
|
||||
@ -53,11 +55,12 @@ class ASGIDispatch(AsyncDispatcher):
|
||||
scope = {
|
||||
"type": "http",
|
||||
"asgi": {"version": "3.0"},
|
||||
"http_version": "1.1",
|
||||
"method": request.method,
|
||||
"headers": request.headers.raw,
|
||||
"scheme": request.url.scheme,
|
||||
"path": request.url.path,
|
||||
"query": request.url.query.encode("ascii"),
|
||||
"query_string": request.url.query.encode("ascii"),
|
||||
"server": request.url.host,
|
||||
"client": self.client,
|
||||
"root_path": self.root_path,
|
||||
@ -80,7 +83,7 @@ class ASGIDispatch(AsyncDispatcher):
|
||||
return {"type": "http.request", "body": body, "more_body": True}
|
||||
|
||||
async def send(message: dict) -> None:
|
||||
nonlocal status_code, headers, response_started, response_body
|
||||
nonlocal status_code, headers, response_started, response_body, request
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
status_code = message["status"]
|
||||
@ -89,14 +92,13 @@ class ASGIDispatch(AsyncDispatcher):
|
||||
elif message["type"] == "http.response.body":
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if body:
|
||||
if body and request.method != "HEAD":
|
||||
await response_body.put(body)
|
||||
if not more_body:
|
||||
await response_body.done()
|
||||
|
||||
async def run_app() -> None:
|
||||
nonlocal app, scope, receive, send, app_exc, response_body
|
||||
|
||||
try:
|
||||
await app(scope, receive, send)
|
||||
except Exception as exc:
|
||||
@ -117,17 +119,18 @@ class ASGIDispatch(AsyncDispatcher):
|
||||
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if app_exc is not None:
|
||||
if app_exc is not None and self.raise_app_exceptions:
|
||||
raise app_exc
|
||||
|
||||
assert response_started.is_set, "application did not return a response."
|
||||
assert response_started.is_set(), "application did not return a response."
|
||||
assert status_code is not None
|
||||
assert headers is not None
|
||||
|
||||
async def on_close() -> None:
|
||||
nonlocal app_task
|
||||
nonlocal app_task, response_body
|
||||
await response_body.drain()
|
||||
await app_task
|
||||
if app_exc is not None:
|
||||
if app_exc is not None and self.raise_app_exceptions:
|
||||
raise app_exc
|
||||
|
||||
return AsyncResponse(
|
||||
@ -163,6 +166,14 @@ class BodyIterator:
|
||||
assert isinstance(data, bytes)
|
||||
yield data
|
||||
|
||||
async def drain(self) -> None:
|
||||
"""
|
||||
Drain any remaining body, in order to allow any blocked `put()` calls
|
||||
to complete.
|
||||
"""
|
||||
async for chunk in self.iterate():
|
||||
pass # pragma: no cover
|
||||
|
||||
async def put(self, data: bytes) -> None:
|
||||
"""
|
||||
Used by the server to add data to the response body.
|
||||
|
||||
@ -79,9 +79,6 @@ class HTTP11Connection:
|
||||
method = request.method.encode("ascii")
|
||||
target = request.url.full_path.encode("ascii")
|
||||
headers = request.headers.raw
|
||||
if "Host" not in request.headers:
|
||||
host = request.url.authority.encode("ascii")
|
||||
headers = [(b"host", host)] + headers
|
||||
event = h11.Request(method=method, target=target, headers=headers)
|
||||
await self._send_event(event, timeout)
|
||||
|
||||
|
||||
@ -76,7 +76,7 @@ class HTTP2Connection:
|
||||
(b":authority", request.url.authority.encode("ascii")),
|
||||
(b":scheme", request.url.scheme.encode("ascii")),
|
||||
(b":path", request.url.full_path.encode("ascii")),
|
||||
] + request.headers.raw
|
||||
] + [(k, v) for k, v in request.headers.raw if k != b"host"]
|
||||
self.h2_state.send_headers(stream_id, headers)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.writer.write(data_to_send, timeout)
|
||||
|
||||
@ -35,10 +35,12 @@ class WSGIDispatch(Dispatcher):
|
||||
def __init__(
|
||||
self,
|
||||
app: typing.Callable,
|
||||
raise_app_exceptions: bool = True,
|
||||
script_name: str = "",
|
||||
remote_addr: str = "127.0.0.1",
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.raise_app_exceptions = raise_app_exceptions
|
||||
self.script_name = script_name
|
||||
self.remote_addr = remote_addr
|
||||
|
||||
@ -87,7 +89,7 @@ class WSGIDispatch(Dispatcher):
|
||||
|
||||
assert seen_status is not None
|
||||
assert seen_response_headers is not None
|
||||
if seen_exc_info:
|
||||
if seen_exc_info and self.raise_app_exceptions:
|
||||
raise seen_exc_info[1]
|
||||
|
||||
return Response(
|
||||
|
||||
@ -51,9 +51,9 @@ AuthTypes = typing.Union[
|
||||
typing.Callable[["AsyncRequest"], "AsyncRequest"],
|
||||
]
|
||||
|
||||
AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
|
||||
AsyncRequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
|
||||
|
||||
RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]]
|
||||
RequestData = typing.Union[dict, str, bytes, typing.Iterator[bytes]]
|
||||
|
||||
RequestFiles = typing.Dict[
|
||||
str,
|
||||
@ -527,6 +527,7 @@ class BaseRequest:
|
||||
|
||||
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
|
||||
has_host = "host" in self.headers
|
||||
has_user_agent = "user-agent" in self.headers
|
||||
has_accept = "accept" in self.headers
|
||||
has_content_length = (
|
||||
@ -534,6 +535,8 @@ class BaseRequest:
|
||||
)
|
||||
has_accept_encoding = "accept-encoding" in self.headers
|
||||
|
||||
if not has_host:
|
||||
auto_headers.append((b"host", self.url.authority.encode("ascii")))
|
||||
if not has_user_agent:
|
||||
auto_headers.append((b"user-agent", b"http3"))
|
||||
if not has_accept:
|
||||
@ -585,7 +588,8 @@ class AsyncRequest(BaseRequest):
|
||||
self.content = content
|
||||
if content_type:
|
||||
self.headers["Content-Type"] = content_type
|
||||
elif isinstance(data, bytes):
|
||||
elif isinstance(data, (str, bytes)):
|
||||
data = data.encode("utf-8") if isinstance(data, str) else data
|
||||
self.is_streaming = False
|
||||
self.content = data
|
||||
else:
|
||||
@ -634,7 +638,8 @@ class Request(BaseRequest):
|
||||
self.content = content
|
||||
if content_type:
|
||||
self.headers["Content-Type"] = content_type
|
||||
elif isinstance(data, bytes):
|
||||
elif isinstance(data, (str, bytes)):
|
||||
data = data.encode("utf-8") if isinstance(data, str) else data
|
||||
self.is_streaming = False
|
||||
self.content = data
|
||||
else:
|
||||
|
||||
@ -3,7 +3,7 @@ import mimetypes
|
||||
import os
|
||||
import typing
|
||||
from io import BytesIO
|
||||
from urllib.parse import quote_plus
|
||||
from urllib.parse import quote
|
||||
|
||||
|
||||
class Field:
|
||||
@ -20,13 +20,13 @@ class DataField(Field):
|
||||
self.value = value
|
||||
|
||||
def render_headers(self) -> bytes:
|
||||
name = quote_plus(self.name, encoding="utf-8").encode("ascii")
|
||||
name = quote(self.name, encoding="utf-8").encode("ascii")
|
||||
return b"".join(
|
||||
[b'Content-Disposition: form-data; name="', name, b'"\r\n' b"\r\n"]
|
||||
)
|
||||
|
||||
def render_data(self) -> bytes:
|
||||
return quote_plus(self.value, encoding="utf-8").encode("ascii")
|
||||
return self.value.encode("utf-8")
|
||||
|
||||
|
||||
class FileField(Field):
|
||||
@ -49,8 +49,8 @@ class FileField(Field):
|
||||
return mimetypes.guess_type(self.filename)[0] or "application/octet-stream"
|
||||
|
||||
def render_headers(self) -> bytes:
|
||||
name = quote_plus(self.name, encoding="utf-8").encode("ascii")
|
||||
filename = quote_plus(self.filename, encoding="utf-8").encode("ascii")
|
||||
name = quote(self.name, encoding="utf-8").encode("ascii")
|
||||
filename = quote(self.filename, encoding="utf-8").encode("ascii")
|
||||
content_type = self.content_type.encode("ascii")
|
||||
return b"".join(
|
||||
[
|
||||
|
||||
Loading…
Reference in New Issue
Block a user