Use Streams API for both requests and responses. (#648)

* Internal ContentStreams API
This commit is contained in:
Tom Christie 2019-12-20 16:05:04 +00:00 committed by GitHub
parent 36af9d9597
commit cee1fccaca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 431 additions and 394 deletions

View File

@ -19,7 +19,7 @@ from .config import (
UnsetType,
VerifyTypes,
)
from .content import RequestContent
from .content_streams import ContentStream
from .dispatch.asgi import ASGIDispatch
from .dispatch.base import Dispatcher
from .dispatch.connection_pool import ConnectionPool
@ -495,11 +495,11 @@ class Client:
method = self.redirect_method(request, response)
url = self.redirect_url(request, response)
headers = self.redirect_headers(request, url, method)
content = self.redirect_content(request, method)
stream = self.redirect_stream(request, method)
cookies = Cookies(self.cookies)
request = Request(method=method, url=url, headers=headers, cookies=cookies)
request.content = content
return request
return Request(
method=method, url=url, headers=headers, cookies=cookies, stream=stream
)
def redirect_method(self, request: Request, response: Response) -> str:
"""
@ -567,15 +567,17 @@ class Client:
return headers
def redirect_content(self, request: Request, method: str) -> RequestContent:
def redirect_stream(
self, request: Request, method: str
) -> typing.Optional[ContentStream]:
"""
Return the body that should be used for the redirect request.
"""
if method != request.method and method == "GET":
return RequestContent()
if not request.content.can_replay():
return None
if not request.stream.can_replay():
raise RedirectBodyUnavailable()
return request.content
return request.stream
async def send_handling_auth(
self,

View File

@ -1,171 +0,0 @@
import typing
from json import dumps as json_dumps
from urllib.parse import urlencode
from .multipart import multipart_encode
# Types accepted for the `data` argument of `encode()`: a dict, a str,
# raw bytes, or an async iterator that yields bytes.
RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
# Types accepted for the `files` argument of `encode()`: a dict keyed by str
# whose values take one of the three forms annotated inline below.
RequestFiles = typing.Dict[
str,
typing.Union[
# file (or str)
typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
# (filename, file (or str))
typing.Tuple[
typing.Optional[str], typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
],
# (filename, file (or str), content_type)
typing.Tuple[
typing.Optional[str],
typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
typing.Optional[str],
],
],
]
class RequestContent:
    """
    Root of the request-content hierarchy.

    Used directly, it models a request that carries no body.
    """

    def get_headers(self) -> typing.Dict[str, str]:
        """
        Report the request headers that this encoding implies.
        """
        implied: typing.Dict[str, str] = {}
        return implied

    def can_replay(self) -> bool:
        """
        Tell whether `__aiter__` may be invoked more than once.

        This is what lets a caller decide if the request body can be
        re-issued after a redirect response has been received.
        """
        return True

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The "no body" case still yields a single empty chunk.
        yield b""

    async def aread(self) -> bytes:
        """
        Iterate over the content and return every chunk joined together.
        """
        chunks = []
        async for chunk in self:
            chunks.append(chunk)
        return b"".join(chunks)
class BytesRequestContent(RequestContent):
    """
    Request content held in memory as a single bytes payload.
    """

    def __init__(self, body: typing.Union[str, bytes]) -> None:
        # Text input is stored as its UTF-8 encoding; bytes are kept as given.
        if isinstance(body, str):
            payload = body.encode("utf-8")
        else:
            payload = body
        self.body = payload

    def get_headers(self) -> typing.Dict[str, str]:
        """
        Return a `Content-Length` header sized to the stored body.
        """
        return {"Content-Length": str(len(self.body))}

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The whole body is delivered as one chunk.
        yield self.body
class StreamingRequestContent(RequestContent):
    """
    Request content produced lazily by an async iterator of bytes.
    """

    def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None:
        self.aiterator = aiterator

    def can_replay(self) -> bool:
        """
        Always `False` for this content type.
        """
        return False

    def get_headers(self) -> typing.Dict[str, str]:
        """
        Return a `Transfer-Encoding: chunked` header.
        """
        return {"Transfer-Encoding": "chunked"}

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # Relay each chunk from the wrapped iterator unchanged.
        async for chunk in self.aiterator:
            yield chunk
class JSONRequestContent(RequestContent):
    """
    Request content serialized as a JSON document.
    """

    def __init__(self, json: typing.Any) -> None:
        # Serialize once up front and keep the UTF-8 encoded result.
        serialized = json_dumps(json)
        self.body = serialized.encode("utf-8")

    def get_headers(self) -> typing.Dict[str, str]:
        """
        Return the `Content-Length` and `Content-Type` headers for the body.
        """
        return {
            "Content-Length": str(len(self.body)),
            "Content-Type": "application/json",
        }

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The whole body is delivered as one chunk.
        yield self.body
class URLEncodedRequestContent(RequestContent):
    """
    Request content serialized as URL encoded form data.
    """

    def __init__(self, data: dict) -> None:
        # `doseq=True` expands sequence values into repeated keys.
        encoded = urlencode(data, doseq=True)
        self.body = encoded.encode("utf-8")

    def get_headers(self) -> typing.Dict[str, str]:
        """
        Return the `Content-Length` and `Content-Type` headers for the body.
        """
        return {
            "Content-Length": str(len(self.body)),
            "Content-Type": "application/x-www-form-urlencoded",
        }

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The whole body is delivered as one chunk.
        yield self.body
class MultipartRequestContent(RequestContent):
    """
    Request content serialized as multipart encoded form data.
    """

    def __init__(self, data: dict, files: dict, boundary: bytes = None) -> None:
        # `multipart_encode` returns the body together with its content type.
        encoded = multipart_encode(data, files, boundary)
        self.body, self.content_type = encoded

    def get_headers(self) -> typing.Dict[str, str]:
        """
        Return the `Content-Length` and `Content-Type` headers for the body.
        """
        return {
            "Content-Length": str(len(self.body)),
            "Content-Type": self.content_type,
        }

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The whole body is delivered as one chunk.
        yield self.body
def encode(
    data: RequestData = None,
    files: RequestFiles = None,
    json: typing.Any = None,
    boundary: bytes = None,
) -> RequestContent:
    """
    Pick the `RequestContent` implementation matching the given `data`,
    `files`, and `json` arguments.

    The returned object provides a byte iterator onto the content, as well
    as the `.can_replay()` and `.get_headers()` interfaces.

    `boundary` is forwarded to the multipart encoder so that test cases
    working with multipart data can be made reproducible.
    """
    if data is None:
        # No `data`: JSON takes precedence, then files, then "no body".
        if json is not None:
            return JSONRequestContent(json)
        if files:
            return MultipartRequestContent({}, files, boundary=boundary)
        return RequestContent()

    if isinstance(data, dict):
        # Form data: multipart as soon as `files` is supplied at all.
        if files is None:
            return URLEncodedRequestContent(data)
        return MultipartRequestContent(data, files, boundary=boundary)

    if isinstance(data, (str, bytes)):
        return BytesRequestContent(data)

    # Anything else is handed to the streaming implementation.
    return StreamingRequestContent(data)

279
httpx/content_streams.py Normal file
View File

@ -0,0 +1,279 @@
import binascii
import mimetypes
import os
import typing
from io import BytesIO
from json import dumps as json_dumps
from pathlib import Path
from urllib.parse import urlencode
from .utils import format_form_param
# Types accepted for the `data` argument of `encode()`: a dict, a str,
# raw bytes, or an async iterator that yields bytes.
RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
# Types accepted for the `files` argument of `encode()`: a dict keyed by str
# whose values take one of the three forms annotated inline below.
RequestFiles = typing.Dict[
str,
typing.Union[
# file (or str)
typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
# (filename, file (or str))
typing.Tuple[
typing.Optional[str], typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
],
# (filename, file (or str), content_type)
typing.Tuple[
typing.Optional[str],
typing.Union[typing.IO[typing.AnyStr], typing.AnyStr],
typing.Optional[str],
],
],
]
class ContentStream:
    """
    Base class for content streams.

    Used directly, it yields a single empty chunk and implies no headers.
    """

    def get_headers(self) -> typing.Dict[str, str]:
        """
        Report the headers that this encoding implies.
        """
        implied: typing.Dict[str, str] = {}
        return implied

    def can_replay(self) -> bool:
        """
        Tell whether `__aiter__` may be invoked more than once.

        This is needed in cases such as determining if a request body can
        be re-issued when a redirect response is received.
        """
        return True

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The default stream consists of one empty chunk.
        yield b""

    async def aclose(self) -> None:
        """
        Close the stream. The base implementation does nothing.
        """
        return None
class ByteStream(ContentStream):
    """
    Content held in memory as a single bytes payload.
    """

    def __init__(self, body: typing.Union[str, bytes]) -> None:
        # Text input is stored as its UTF-8 encoding; bytes are kept as given.
        if isinstance(body, str):
            payload = body.encode("utf-8")
        else:
            payload = body
        self.body = payload

    def get_headers(self) -> typing.Dict[str, str]:
        """
        Return a `Content-Length` header, or no headers for an empty body.
        """
        if self.body:
            return {"Content-Length": str(len(self.body))}
        return {}

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The whole body is delivered as one chunk.
        yield self.body
class AsyncIteratorStream(ContentStream):
    """
    Content produced lazily by an async iterator of bytes, with an
    optional awaitable callback that is invoked on close.
    """

    def __init__(
        self, aiterator: typing.AsyncIterator[bytes], close_func: typing.Callable = None
    ) -> None:
        self.aiterator = aiterator
        self.close_func = close_func

    def can_replay(self) -> bool:
        """
        Always `False` for this stream type.
        """
        return False

    def get_headers(self) -> typing.Dict[str, str]:
        """
        Return a `Transfer-Encoding: chunked` header.
        """
        return {"Transfer-Encoding": "chunked"}

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # Relay each chunk from the wrapped iterator unchanged.
        async for chunk in self.aiterator:
            yield chunk

    async def aclose(self) -> None:
        """
        Await the close callback, if one was supplied.
        """
        closer = self.close_func
        if closer is None:
            return
        await closer()
class JSONStream(ContentStream):
    """
    Content serialized as a JSON document.
    """

    def __init__(self, json: typing.Any) -> None:
        # Serialize once up front and keep the UTF-8 encoded result.
        serialized = json_dumps(json)
        self.body = serialized.encode("utf-8")

    def get_headers(self) -> typing.Dict[str, str]:
        """
        Return the `Content-Length` and `Content-Type` headers for the body.
        """
        return {
            "Content-Length": str(len(self.body)),
            "Content-Type": "application/json",
        }

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The whole body is delivered as one chunk.
        yield self.body
class URLEncodedStream(ContentStream):
    """
    Content serialized as URL encoded form data.
    """

    def __init__(self, data: dict) -> None:
        # `doseq=True` expands sequence values into repeated keys.
        encoded = urlencode(data, doseq=True)
        self.body = encoded.encode("utf-8")

    def get_headers(self) -> typing.Dict[str, str]:
        """
        Return the `Content-Length` and `Content-Type` headers for the body.
        """
        return {
            "Content-Length": str(len(self.body)),
            "Content-Type": "application/x-www-form-urlencoded",
        }

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The whole body is delivered as one chunk.
        yield self.body
class MultipartStream(ContentStream):
    """
    Content serialized as multipart encoded form data.

    The complete body is rendered when the stream is constructed and is
    then yielded as a single chunk.
    """

    class DataField:
        """
        One plain (non-file) form field inside a multipart form.
        """

        def __init__(self, name: str, value: typing.Union[str, bytes]) -> None:
            # Reject unsupported types before anything is stored.
            if not isinstance(name, str):
                raise TypeError("Invalid type for name. Expected str.")
            if not isinstance(value, (str, bytes)):
                raise TypeError("Invalid type for value. Expected str or bytes.")
            self.name = name
            self.value = value

        def render_headers(self) -> bytes:
            """
            Return the part headers, including the terminating blank line.
            """
            disposition = b"Content-Disposition: form-data; "
            return disposition + format_form_param("name", self.name) + b"\r\n\r\n"

        def render_data(self) -> bytes:
            """
            Return the field value as bytes, UTF-8 encoding str values.
            """
            if isinstance(self.value, bytes):
                return self.value
            return self.value.encode("utf-8")

    class FileField:
        """
        One file upload field inside a multipart form.
        """

        def __init__(
            self, name: str, value: typing.Union[typing.IO[typing.AnyStr], tuple]
        ) -> None:
            self.name = name
            if isinstance(value, tuple):
                # Tuple form: (filename, file[, content_type]).
                self.filename = value[0]
                self.file = value[
                    1
                ]  # type: typing.Union[typing.IO[str], typing.IO[bytes]]
                if len(value) > 2:
                    self.content_type = value[2]
                else:
                    self.content_type = self.guess_content_type()
            else:
                # Bare file object: take the base name of its `name`
                # attribute, falling back to "upload".
                self.filename = Path(str(getattr(value, "name", "upload"))).name
                self.file = value
                self.content_type = self.guess_content_type()

        def guess_content_type(self) -> typing.Optional[str]:
            """
            Guess a content type from the filename.

            Returns `None` when there is no filename, and falls back to
            "application/octet-stream" when the type cannot be guessed.
            """
            if not self.filename:
                return None
            guessed = mimetypes.guess_type(self.filename)[0]
            return guessed or "application/octet-stream"

        def render_headers(self) -> bytes:
            """
            Return the part headers, including the terminating blank line.
            """
            header = b"Content-Disposition: form-data; "
            header += format_form_param("name", self.name)
            if self.filename:
                header += b"; " + format_form_param("filename", self.filename)
            if self.content_type is not None:
                header += b"\r\nContent-Type: " + self.content_type.encode()
            return header + b"\r\n\r\n"

        def render_data(self) -> bytes:
            """
            Return the file content as bytes, UTF-8 encoding str content.
            """
            # A plain str is used as the content itself; anything else is
            # read like a file object.
            content = self.file if isinstance(self.file, str) else self.file.read()
            if isinstance(content, str):
                return content.encode("utf-8")
            return content

    def __init__(self, data: dict, files: dict, boundary: bytes = None) -> None:
        if boundary is None:
            # No boundary given: use 16 random bytes, hex encoded.
            boundary = binascii.hexlify(os.urandom(16))

        pieces = []
        for field in self.iter_fields(data, files):
            pieces += [
                b"--%s\r\n" % boundary,
                field.render_headers(),
                field.render_data(),
                b"\r\n",
            ]
        pieces.append(b"--%s--\r\n" % boundary)

        self.content_type = "multipart/form-data; boundary=%s" % boundary.decode(
            "ascii"
        )
        self.body = b"".join(pieces)

    def iter_fields(
        self, data: dict, files: dict
    ) -> typing.Iterator[typing.Union["FileField", "DataField"]]:
        """
        Yield a `DataField` per data value, followed by a `FileField` per
        entry in `files`.
        """
        for name, value in data.items():
            # A list or dict value is iterated, one field per item.
            items = value if isinstance(value, (list, dict)) else [value]
            for item in items:
                yield self.DataField(name=name, value=item)

        for name, value in files.items():
            yield self.FileField(name=name, value=value)

    def get_headers(self) -> typing.Dict[str, str]:
        """
        Return the `Content-Length` and `Content-Type` headers for the body.
        """
        return {
            "Content-Length": str(len(self.body)),
            "Content-Type": self.content_type,
        }

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The whole pre-rendered body is delivered as one chunk.
        yield self.body
def encode(
    data: RequestData = None,
    files: RequestFiles = None,
    json: typing.Any = None,
    boundary: bytes = None,
) -> ContentStream:
    """
    Pick the `ContentStream` implementation matching the given `data`,
    `files`, and `json` arguments.

    `boundary` is forwarded to `MultipartStream` when multipart encoding
    is selected.
    """
    if data is None:
        # No `data`: JSON takes precedence, then files, then an empty body.
        if json is not None:
            return JSONStream(json=json)
        if files:
            return MultipartStream(data={}, files=files, boundary=boundary)
        return ByteStream(body=b"")

    if isinstance(data, dict):
        # Form data: multipart as soon as `files` is supplied at all.
        if files is None:
            return URLEncodedStream(data=data)
        return MultipartStream(data=data, files=files, boundary=boundary)

    if isinstance(data, (str, bytes)):
        return ByteStream(body=data)

    # Anything else is wrapped as an async iterator stream.
    return AsyncIteratorStream(aiterator=data)

View File

@ -1,6 +1,7 @@
import typing
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..content_streams import ByteStream
from ..models import Request, Response
from .base import Dispatcher
@ -77,13 +78,14 @@ class ASGIDispatch(Dispatcher):
status_code = None
headers = None
body_parts = []
request_stream = request.stream()
response_started = False
response_complete = False
request_body_chunks = request.stream.__aiter__()
async def receive() -> dict:
try:
body = await request_stream.__anext__()
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}
@ -120,10 +122,12 @@ class ASGIDispatch(Dispatcher):
assert status_code is not None
assert headers is not None
stream = ByteStream(b"".join(body_parts))
return Response(
status_code=status_code,
http_version="HTTP/1.1",
headers=headers,
content=b"".join(body_parts),
stream=stream,
request=request,
)

View File

@ -4,6 +4,7 @@ import h11
from ..concurrency.base import BaseSocketStream
from ..config import Timeout
from ..content_streams import AsyncIteratorStream
from ..exceptions import ConnectionClosed, ProtocolError
from ..models import Request, Response
from ..utils import get_logger
@ -50,14 +51,16 @@ class HTTP11Connection(OpenConnection):
await self._send_request(request, timeout)
await self._send_request_body(request, timeout)
http_version, status_code, headers = await self._receive_response(timeout)
content = self._receive_response_data(timeout)
stream = AsyncIteratorStream(
aiterator=self._receive_response_data(timeout),
close_func=self.response_closed,
)
return Response(
status_code=status_code,
http_version=http_version,
headers=headers,
content=content,
on_close=self.response_closed,
stream=stream,
request=request,
)
@ -93,7 +96,7 @@ class HTTP11Connection(OpenConnection):
"""
try:
# Send the request body.
async for chunk in request.stream():
async for chunk in request.stream:
logger.trace(f"send_data data=Data(<{len(chunk)} bytes>)")
event = h11.Data(data=chunk)
await self._send_event(event, timeout)

View File

@ -12,6 +12,7 @@ from ..concurrency.base import (
lookup_backend,
)
from ..config import Timeout
from ..content_streams import AsyncIteratorStream
from ..exceptions import ProtocolError
from ..models import Request, Response
from ..utils import get_logger
@ -209,13 +210,15 @@ class HTTP2Stream:
# Receive the response.
status_code, headers = await self.receive_response(timeout)
content = self.body_iter(timeout)
stream = AsyncIteratorStream(
aiterator=self.body_iter(timeout), close_func=self.close
)
return Response(
status_code=status_code,
http_version="HTTP/2",
headers=headers,
content=content,
on_close=self.close,
stream=stream,
request=request,
)
@ -238,7 +241,7 @@ class HTTP2Stream:
async def send_body(self, request: Request, timeout: Timeout) -> None:
logger.trace(f"send_body stream_id={self.stream_id}")
async for data in request.stream():
async for data in request.stream:
while data:
max_flow = await self.connection.wait_for_outgoing_flow(
self.stream_id, timeout

View File

@ -167,8 +167,9 @@ class HTTPProxy(ConnectionPool):
response=proxy_response,
)
else:
proxy_response.on_close = None
await proxy_response.read()
# Hack to ingest the response, without closing it.
async for chunk in proxy_response._raw_stream:
pass
return connection

View File

@ -13,7 +13,7 @@ import chardet
import rfc3986
from .config import USER_AGENT
from .content import RequestData, RequestFiles, encode
from .content_streams import ContentStream, RequestData, RequestFiles, encode
from .decoders import (
ACCEPT_ENCODING,
SUPPORTED_DECODERS,
@ -71,8 +71,6 @@ ProxiesTypes = typing.Union[
URLTypes, "Dispatcher", typing.Dict[URLTypes, typing.Union[URLTypes, "Dispatcher"]]
]
ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
class URL:
def __init__(
@ -595,6 +593,7 @@ class Request:
data: RequestData = None,
files: RequestFiles = None,
json: typing.Any = None,
stream: ContentStream = None,
):
self.method = method.upper()
self.url = URL(url, params=params)
@ -602,11 +601,16 @@ class Request:
if cookies:
self._cookies = Cookies(cookies)
self._cookies.set_cookie_header(self)
self.content = encode(data, files, json)
if stream is not None:
self.stream = stream
else:
self.stream = encode(data, files, json)
self.prepare()
def prepare(self) -> None:
for key, value in self.content.get_headers().items():
for key, value in self.stream.get_headers().items():
self.headers.setdefault(key, value)
auto_headers: typing.List[typing.Tuple[bytes, bytes]] = []
@ -649,11 +653,7 @@ class Request:
"""
Read and return the request content.
"""
return await self.content.aread()
async def stream(self) -> typing.AsyncIterator[bytes]:
async for part in self.content:
yield part
return b"".join([part async for part in self.stream])
class Response:
@ -663,8 +663,8 @@ class Response:
*,
http_version: str = None,
headers: HeaderTypes = None,
content: ResponseContent = None,
on_close: typing.Callable = None,
stream: ContentStream = None,
content: bytes = None,
request: Request = None,
history: typing.List["Response"] = None,
elapsed: datetime.timedelta = None,
@ -674,20 +674,19 @@ class Response:
self.headers = Headers(headers)
self.request = request
self.on_close = on_close
self.elapsed = datetime.timedelta(0) if elapsed is None else elapsed
self.call_next: typing.Optional[typing.Callable] = None
self.history = [] if history is None else list(history)
if content is None or isinstance(content, bytes):
if stream is None:
self.is_closed = True
self.is_stream_consumed = True
self._raw_content = content or b""
else:
self.is_closed = False
self.is_stream_consumed = False
self._raw_stream = content
self._raw_stream = stream
@property
def reason_phrase(self) -> str:
@ -942,8 +941,8 @@ class Response:
"""
if not self.is_closed:
self.is_closed = True
if self.on_close is not None:
await self.on_close()
if hasattr(self, "_raw_stream"):
await self._raw_stream.aclose()
class Cookies(MutableMapping):

View File

@ -1,126 +0,0 @@
import binascii
import mimetypes
import os
import re
import typing
from io import BytesIO
from pathlib import Path
# Replacement table used when formatting multipart header parameters:
# the double quote and backslash get fixed replacements, and every control
# character from 0x00 to 0x1F except 0x1B is percent-encoded.
_HTML5_FORM_ENCODING_REPLACEMENTS = {
    '"': "%22",
    "\\": "\\\\",
    **{chr(code): f"%{code:02X}" for code in range(0x20) if code != 0x1B},
}
# One alternation that matches any single character present in the table.
_HTML5_FORM_ENCODING_RE = re.compile(
    "|".join(map(re.escape, _HTML5_FORM_ENCODING_REPLACEMENTS))
)
class Field:
    """
    Interface for a single part of a multipart form body.

    Subclasses supply the two rendering methods.
    """

    def render_headers(self) -> bytes:
        """
        Return the header section of this part as bytes.
        """
        raise NotImplementedError()  # pragma: nocover

    def render_data(self) -> bytes:
        """
        Return the payload of this part as bytes.
        """
        raise NotImplementedError()  # pragma: nocover
class DataField(Field):
    """
    A plain (non-file) form field within a multipart body.
    """

    def __init__(self, name: str, value: typing.Union[str, bytes]) -> None:
        # Reject unsupported types before anything is stored.
        if not isinstance(name, str):
            raise TypeError("Invalid type for name. Expected str.")
        if not isinstance(value, (str, bytes)):
            raise TypeError("Invalid type for value. Expected str or bytes.")
        self.name = name
        self.value = value

    def render_headers(self) -> bytes:
        """
        Return the part headers, including the terminating blank line.
        """
        disposition = b"Content-Disposition: form-data; "
        return disposition + _format_param("name", self.name) + b"\r\n\r\n"

    def render_data(self) -> bytes:
        """
        Return the field value as bytes, UTF-8 encoding str values.
        """
        if isinstance(self.value, bytes):
            return self.value
        return self.value.encode("utf-8")
class FileField(Field):
    """
    A file upload field within a multipart body.
    """

    def __init__(
        self, name: str, value: typing.Union[typing.IO[typing.AnyStr], tuple]
    ) -> None:
        self.name = name
        if isinstance(value, tuple):
            # Tuple form: (filename, file[, content_type]).
            self.filename = value[0]
            self.file = value[
                1
            ]  # type: typing.Union[typing.IO[str], typing.IO[bytes]]
            if len(value) > 2:
                self.content_type = value[2]
            else:
                self.content_type = self.guess_content_type()
        else:
            # Bare file object: take the base name of its `name` attribute,
            # falling back to "upload".
            self.filename = Path(str(getattr(value, "name", "upload"))).name
            self.file = value
            self.content_type = self.guess_content_type()

    def guess_content_type(self) -> typing.Optional[str]:
        """
        Guess a content type from the filename.

        Returns `None` when there is no filename, and falls back to
        "application/octet-stream" when the type cannot be guessed.
        """
        if not self.filename:
            return None
        guessed = mimetypes.guess_type(self.filename)[0]
        return guessed or "application/octet-stream"

    def render_headers(self) -> bytes:
        """
        Return the part headers, including the terminating blank line.
        """
        header = b"Content-Disposition: form-data; "
        header += _format_param("name", self.name)
        if self.filename:
            header += b"; " + _format_param("filename", self.filename)
        if self.content_type is not None:
            header += b"\r\nContent-Type: " + self.content_type.encode()
        return header + b"\r\n\r\n"

    def render_data(self) -> bytes:
        """
        Return the file content as bytes, UTF-8 encoding str content.
        """
        # A plain str is used as the content itself; anything else is read
        # like a file object.
        content = self.file if isinstance(self.file, str) else self.file.read()
        if isinstance(content, str):
            return content.encode("utf-8")
        return content
def iter_fields(data: dict, files: dict) -> typing.Iterator[Field]:
    """
    Yield a `DataField` per data value, followed by a `FileField` per entry
    in `files`.
    """
    for name, value in data.items():
        # A list or dict value is iterated, one field per item.
        items = value if isinstance(value, (list, dict)) else [value]
        for item in items:
            yield DataField(name=name, value=item)

    for name, value in files.items():
        yield FileField(name=name, value=value)
def multipart_encode(
    data: dict, files: dict, boundary: bytes = None
) -> typing.Tuple[bytes, str]:
    """
    Render `data` and `files` as a multipart form body.

    Returns a `(body, content_type)` pair, where the content type carries
    the boundary that was used. When `boundary` is not given, 16 random
    bytes, hex encoded, are used instead.
    """
    if boundary is None:
        boundary = binascii.hexlify(os.urandom(16))

    pieces = []
    for field in iter_fields(data, files):
        pieces += [
            b"--%s\r\n" % boundary,
            field.render_headers(),
            field.render_data(),
            b"\r\n",
        ]
    # Closing delimiter.
    pieces.append(b"--%s--\r\n" % boundary)

    content_type = "multipart/form-data; boundary=%s" % boundary.decode("ascii")
    return b"".join(pieces), content_type
def _format_param(name: str, value: typing.Union[str, bytes]) -> bytes:
    """
    Render `name="value"` as bytes, replacing every character of the value
    that appears in the HTML5 form-encoding replacement table.
    """
    text = value.decode() if isinstance(value, bytes) else value
    escaped = _HTML5_FORM_ENCODING_RE.sub(
        lambda match: _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)], text
    )
    return f'{name}="{escaped}"'.encode()

View File

@ -17,6 +17,15 @@ if typing.TYPE_CHECKING: # pragma: no cover
from .models import URL
# Replacement table used by the multipart form parameter formatter: the
# double quote and backslash get fixed replacements, and every control
# character from 0x00 to 0x1F except 0x1B is percent-encoded.
_HTML5_FORM_ENCODING_REPLACEMENTS = {
    '"': "%22",
    "\\": "\\\\",
    **{chr(code): "%%%02X" % code for code in range(0x20) if code != 0x1B},
}
# One alternation that matches any single character present in the table.
_HTML5_FORM_ENCODING_RE = re.compile(
    "|".join(re.escape(char) for char in _HTML5_FORM_ENCODING_REPLACEMENTS)
)
def normalize_header_key(value: typing.AnyStr, encoding: str = None) -> bytes:
"""
Coerce str/bytes into a strictly byte-wise HTTP header key.
@ -61,6 +70,20 @@ def is_known_encoding(encoding: str) -> bool:
return True
def format_form_param(name: str, value: typing.Union[str, bytes]) -> bytes:
    """
    Encode a name/value pair within a multipart form.

    The result is `name="value"` as bytes, with every character of the
    value that appears in the HTML5 form-encoding replacement table
    substituted by its replacement.
    """
    text = value.decode() if isinstance(value, bytes) else value
    escaped = _HTML5_FORM_ENCODING_RE.sub(
        lambda match: _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)], text
    )
    return f'{name}="{escaped}"'.encode()
# Null bytes; no need to recreate these on each call to guess_json_utf
_null = "\x00".encode("ascii") # encoding to ASCII for Python 3
_null2 = _null * 2

View File

@ -81,12 +81,12 @@ class MockDispatch(Dispatcher):
return Response(codes.PERMANENT_REDIRECT, headers=headers, request=request)
elif request.url.path == "/redirect_no_body":
await request.read()
content = b"".join([part async for part in request.stream])
headers = {"location": "/redirect_body_target"}
return Response(codes.SEE_OTHER, headers=headers, request=request)
elif request.url.path == "/redirect_body_target":
content = await request.read()
content = b"".join([part async for part in request.stream])
headers = dict(request.headers.items())
body = json.dumps({"body": content.decode(), "headers": headers}).encode()
return Response(codes.OK, content=body, request=request)

View File

@ -14,7 +14,7 @@ from .utils import MockHTTP2Backend
async def app(request):
method = request.method
path = request.url.path
body = await request.read()
body = b"".join([part async for part in request.stream])
content = json.dumps(
{"method": method, "path": path, "body": body.decode()}
).encode()

View File

@ -21,15 +21,19 @@ def test_content_length_header():
@pytest.mark.asyncio
async def test_url_encoded_data():
request = httpx.Request("POST", "http://example.org", data={"test": "123"})
content = b"".join([part async for part in request.stream])
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
assert await request.content.aread() == b"test=123"
assert content == b"test=123"
@pytest.mark.asyncio
async def test_json_encoded_data():
request = httpx.Request("POST", "http://example.org", json={"test": 123})
content = b"".join([part async for part in request.stream])
assert request.headers["Content-Type"] == "application/json"
assert await request.content.aread() == b'{"test": 123}'
assert content == b'{"test": 123}'
def test_transfer_encoding_header():

View File

@ -5,6 +5,7 @@ from unittest import mock
import pytest
import httpx
from httpx.content_streams import AsyncIteratorStream
def streaming_body():
@ -190,7 +191,8 @@ async def test_stream_interface_after_read():
@pytest.mark.asyncio
async def test_streaming_response():
response = httpx.Response(200, content=async_streaming_body())
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(200, stream=stream)
assert response.status_code == 200
assert not response.is_closed
@ -204,7 +206,8 @@ async def test_streaming_response():
@pytest.mark.asyncio
async def test_cannot_read_after_stream_consumed():
response = httpx.Response(200, content=async_streaming_body())
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(200, stream=stream)
content = b""
async for part in response.aiter_bytes():
@ -216,7 +219,8 @@ async def test_cannot_read_after_stream_consumed():
@pytest.mark.asyncio
async def test_cannot_read_after_response_closed():
response = httpx.Response(200, content=async_streaming_body())
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(200, stream=stream)
await response.close()

View File

@ -2,25 +2,27 @@ import io
import pytest
from httpx.content import encode
from httpx.content_streams import encode
@pytest.mark.asyncio
async def test_empty_content():
content = encode()
stream = encode()
content = b"".join([part async for part in stream])
assert content.can_replay()
assert content.get_headers() == {}
assert await content.aread() == b""
assert stream.can_replay()
assert stream.get_headers() == {}
assert content == b""
@pytest.mark.asyncio
async def test_bytes_content():
content = encode(data=b"Hello, world!")
stream = encode(data=b"Hello, world!")
content = b"".join([part async for part in stream])
assert content.can_replay()
assert content.get_headers() == {"Content-Length": "13"}
assert await content.aread() == b"Hello, world!"
assert stream.can_replay()
assert stream.get_headers() == {"Content-Length": "13"}
assert content == b"Hello, world!"
@pytest.mark.asyncio
@ -29,48 +31,52 @@ async def test_aiterator_content():
yield b"Hello, "
yield b"world!"
content = encode(data=hello_world())
stream = encode(data=hello_world())
content = b"".join([part async for part in stream])
assert not content.can_replay()
assert content.get_headers() == {"Transfer-Encoding": "chunked"}
assert await content.aread() == b"Hello, world!"
assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"
@pytest.mark.asyncio
async def test_json_content():
content = encode(json={"Hello": "world!"})
stream = encode(json={"Hello": "world!"})
content = b"".join([part async for part in stream])
assert content.can_replay()
assert content.get_headers() == {
assert stream.can_replay()
assert stream.get_headers() == {
"Content-Length": "19",
"Content-Type": "application/json",
}
assert await content.aread() == b'{"Hello": "world!"}'
assert content == b'{"Hello": "world!"}'
@pytest.mark.asyncio
async def test_urlencoded_content():
content = encode(data={"Hello": "world!"})
stream = encode(data={"Hello": "world!"})
content = b"".join([part async for part in stream])
assert content.can_replay()
assert content.get_headers() == {
assert stream.can_replay()
assert stream.get_headers() == {
"Content-Length": "14",
"Content-Type": "application/x-www-form-urlencoded",
}
assert await content.aread() == b"Hello=world%21"
assert content == b"Hello=world%21"
@pytest.mark.asyncio
async def test_multipart_files_content():
files = {"file": io.BytesIO(b"<file content>")}
content = encode(files=files, boundary=b"+++")
stream = encode(files=files, boundary=b"+++")
content = b"".join([part async for part in stream])
assert content.can_replay()
assert content.get_headers() == {
assert stream.can_replay()
assert stream.get_headers() == {
"Content-Length": "138",
"Content-Type": "multipart/form-data; boundary=+++",
}
assert await content.aread() == b"".join(
assert content == b"".join(
[
b"--+++\r\n",
b'Content-Disposition: form-data; name="file"; filename="upload"\r\n',
@ -86,14 +92,15 @@ async def test_multipart_files_content():
async def test_multipart_data_and_files_content():
data = {"message": "Hello, world!"}
files = {"file": io.BytesIO(b"<file content>")}
content = encode(data=data, files=files, boundary=b"+++")
stream = encode(data=data, files=files, boundary=b"+++")
content = b"".join([part async for part in stream])
assert content.can_replay()
assert content.get_headers() == {
assert stream.can_replay()
assert stream.get_headers() == {
"Content-Length": "210",
"Content-Type": "multipart/form-data; boundary=+++",
}
assert await content.aread() == b"".join(
assert content == b"".join(
[
b"--+++\r\n",
b'Content-Disposition: form-data; name="message"\r\n',

View File

@ -4,6 +4,7 @@ import brotli
import pytest
import httpx
from httpx.content_streams import AsyncIteratorStream
from httpx.decoders import (
BrotliDecoder,
DeflateDecoder,
@ -82,7 +83,8 @@ async def test_streaming():
yield compressor.flush()
headers = [(b"Content-Encoding", b"gzip")]
response = httpx.Response(200, headers=headers, content=compress(body))
stream = AsyncIteratorStream(aiterator=compress(body))
response = httpx.Response(200, headers=headers, stream=stream)
assert not hasattr(response, "body")
assert await response.read() == body
@ -137,7 +139,8 @@ async def test_text_decoder(data, encoding):
for chunk in data:
yield chunk
response = httpx.Response(200, content=iterator())
stream = AsyncIteratorStream(aiterator=iterator())
response = httpx.Response(200, stream=stream)
await response.read()
assert response.text == (b"".join(data)).decode(encoding)
@ -149,10 +152,11 @@ async def test_text_decoder_known_encoding():
yield b"\x83"
yield b"\x89\x83x\x83\x8b"
stream = AsyncIteratorStream(aiterator=iterator())
response = httpx.Response(
200,
headers=[(b"Content-Type", b"text/html; charset=shift-jis")],
content=iterator(),
stream=stream,
)
await response.read()

View File

@ -8,8 +8,9 @@ import pytest
import httpx
from httpx.config import CertTypes, TimeoutTypes, VerifyTypes
from httpx.content_streams import encode
from httpx.dispatch.base import Dispatcher
from httpx.multipart import _format_param, multipart_encode
from httpx.utils import format_form_param
class MockDispatch(Dispatcher):
@ -20,7 +21,7 @@ class MockDispatch(Dispatcher):
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> httpx.Response:
content = await request.read()
content = b"".join([part async for part in request.stream])
return httpx.Response(200, content=content)
@ -105,9 +106,9 @@ def test_multipart_encode():
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
body, content_type = multipart_encode(data=data, files=files)
assert content_type == f"multipart/form-data; boundary={boundary}"
assert body == (
stream = encode(data=data, files=files)
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
assert stream.body == (
'--{0}\r\nContent-Disposition: form-data; name="a"\r\n\r\n1\r\n'
'--{0}\r\nContent-Disposition: form-data; name="b"\r\n\r\nC\r\n'
'--{0}\r\nContent-Disposition: form-data; name="c"\r\n\r\n11\r\n'
@ -129,10 +130,10 @@ def test_multipart_encode_files_allows_filenames_as_none():
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
body, content_type = multipart_encode(data={}, files=files)
stream = encode(data={}, files=files)
assert content_type == f"multipart/form-data; boundary={boundary}"
assert body == (
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
assert stream.body == (
'--{0}\r\nContent-Disposition: form-data; name="file"\r\n\r\n'
"<file content>\r\n--{0}--\r\n"
"".format(boundary).encode("ascii")
@ -154,10 +155,10 @@ def test_multipart_encode_files_guesses_correct_content_type(
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
body, content_type = multipart_encode(data={}, files=files)
stream = encode(data={}, files=files)
assert content_type == f"multipart/form-data; boundary={boundary}"
assert body == (
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
assert stream.body == (
f'--{boundary}\r\nContent-Disposition: form-data; name="file"; '
f'filename="{file_name}"\r\nContent-Type: '
f"{expected_content_type}\r\n\r\n<file content>\r\n--{boundary}--\r\n"
@ -170,10 +171,10 @@ def test_multipart_encode_files_allows_str_content():
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
body, content_type = multipart_encode(data={}, files=files)
stream = encode(data={}, files=files)
assert content_type == f"multipart/form-data; boundary={boundary}"
assert body == (
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
assert stream.body == (
'--{0}\r\nContent-Disposition: form-data; name="file"; '
'filename="test.txt"\r\n'
"Content-Type: text/plain\r\n\r\n<string content>\r\n"
@ -184,17 +185,17 @@ def test_multipart_encode_files_allows_str_content():
class TestHeaderParamHTML5Formatting:
def test_unicode(self):
param = _format_param("filename", "n\u00e4me")
param = format_form_param("filename", "n\u00e4me")
assert param == b'filename="n\xc3\xa4me"'
def test_ascii(self):
param = _format_param("filename", b"name")
param = format_form_param("filename", b"name")
assert param == b'filename="name"'
def test_unicode_escape(self):
param = _format_param("filename", "hello\\world\u0022")
param = format_form_param("filename", "hello\\world\u0022")
assert param == b'filename="hello\\\\world%22"'
def test_unicode_with_control_character(self):
param = _format_param("filename", "hello\x1A\x1B\x1C")
param = format_form_param("filename", "hello\x1A\x1B\x1C")
assert param == b'filename="hello%1A\x1B%1C"'