Drop ContentStream (#1295)

* Drop ContentStream
This commit is contained in:
Tom Christie 2020-09-18 08:41:09 +01:00 committed by GitHub
parent e1f7791e97
commit fbb21fb1ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 416 additions and 311 deletions

View File

@ -74,7 +74,6 @@ except httpx.HTTPStatusError as exc:
* UnsupportedProtocol
* DecodingError
* TooManyRedirects
* RequestBodyUnavailable
* HTTPStatusError
* InvalidURL
* NotRedirectResponse
@ -149,9 +148,6 @@ except httpx.HTTPStatusError as exc:
::: httpx.TooManyRedirects
:docstring:
::: httpx.RequestBodyUnavailable
:docstring:
::: httpx.HTTPStatusError
:docstring:

View File

@ -21,7 +21,6 @@ from ._exceptions import (
ReadError,
ReadTimeout,
RemoteProtocolError,
RequestBodyUnavailable,
RequestError,
RequestNotRead,
ResponseClosed,
@ -84,7 +83,6 @@ __all__ = [
"RemoteProtocolError",
"request",
"Request",
"RequestBodyUnavailable",
"RequestError",
"RequestNotRead",
"Response",

View File

@ -6,7 +6,7 @@ import typing
from base64 import b64encode
from urllib.request import parse_http_list
from ._exceptions import ProtocolError, RequestBodyUnavailable
from ._exceptions import ProtocolError
from ._models import Request, Response
from ._utils import to_bytes, to_str, unquote
@ -157,13 +157,6 @@ class DigestAuth(Auth):
self._password = to_bytes(password)
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
if not request.stream.can_replay():
raise RequestBodyUnavailable(
"Cannot use digest auth with streaming requests that are unable "
"to replay the request body if a second request is required.",
request=request,
)
response = yield request
if response.status_code != 401 or "www-authenticate" not in response.headers:

View File

@ -1,3 +1,4 @@
import datetime
import functools
import typing
import warnings
@ -18,13 +19,11 @@ from ._config import (
UnsetType,
create_ssl_context,
)
from ._content_streams import ContentStream
from ._decoders import SUPPORTED_DECODERS
from ._exceptions import (
HTTPCORE_EXC_MAP,
InvalidURL,
RemoteProtocolError,
RequestBodyUnavailable,
TooManyRedirects,
map_exceptions,
)
@ -34,6 +33,7 @@ from ._transports.asgi import ASGITransport
from ._transports.wsgi import WSGITransport
from ._types import (
AuthTypes,
ByteStream,
CertTypes,
CookieTypes,
HeaderTypes,
@ -480,20 +480,13 @@ class BaseClient:
def _redirect_stream(
self, request: Request, method: str
) -> typing.Optional[ContentStream]:
) -> typing.Optional[ByteStream]:
"""
Return the body that should be used for the redirect request.
"""
if method != request.method and method == "GET":
return None
if not request.stream.can_replay():
raise RequestBodyUnavailable(
"Got a redirect response, but the request body was streaming "
"and is no longer available.",
request=request,
)
return request.stream
@ -864,16 +857,22 @@ class Client(BaseClient):
request.method.encode(),
request.url.raw,
headers=request.headers.raw,
stream=request.stream,
stream=request.stream, # type: ignore
timeout=timeout.as_dict(),
)
def on_close(response: Response) -> None:
response.elapsed = datetime.timedelta(timer.sync_elapsed())
if hasattr(stream, "close"):
stream.close()
response = Response(
status_code,
http_version=http_version.decode("ascii"),
headers=headers,
stream=stream, # type: ignore
request=request,
elapsed_func=timer.sync_elapsed,
on_close=on_close,
)
self.cookies.extract_cookies(response)
@ -1509,16 +1508,22 @@ class AsyncClient(BaseClient):
request.method.encode(),
request.url.raw,
headers=request.headers.raw,
stream=request.stream,
stream=request.stream, # type: ignore
timeout=timeout.as_dict(),
)
async def on_close(response: Response) -> None:
response.elapsed = datetime.timedelta(await timer.async_elapsed())
if hasattr(stream, "close"):
await stream.aclose()
response = Response(
status_code,
http_version=http_version.decode("ascii"),
headers=headers,
stream=stream, # type: ignore
request=request,
elapsed_func=timer.async_elapsed,
on_close=on_close,
)
self.cookies.extract_cookies(response)

View File

@ -1,14 +1,14 @@
import binascii
import inspect
import os
import typing
from json import dumps as json_dumps
from pathlib import Path
from urllib.parse import urlencode
import httpcore
from ._exceptions import StreamConsumed
from ._types import (
ByteStream,
FileContent,
FileTypes,
RequestContent,
@ -24,36 +24,7 @@ from ._utils import (
)
class ContentStream(httpcore.AsyncByteStream, httpcore.SyncByteStream):
def get_headers(self) -> typing.Dict[str, str]:
"""
Return a dictionary of headers that are implied by the encoding.
"""
return {}
def can_replay(self) -> bool:
"""
Return `True` if `__aiter__` can be called multiple times.
We need this in cases such determining if we can re-issue a request
body when we receive a redirect response.
"""
return True
def __iter__(self) -> typing.Iterator[bytes]:
yield b""
def close(self) -> None:
pass
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
yield b""
async def aclose(self) -> None:
pass
class ByteStream(ContentStream):
class PlainByteStream:
"""
Request content encoded as plain bytes.
"""
@ -74,59 +45,41 @@ class ByteStream(ContentStream):
yield self.body
class IteratorStream(ContentStream):
class GeneratorStream:
"""
Request content encoded as plain bytes, using an byte iterator.
Request content encoded as plain bytes, using an byte generator.
"""
def __init__(self, iterator: typing.Iterator[bytes]) -> None:
self.iterator = iterator
self.is_stream_consumed = False
def can_replay(self) -> bool:
return False
def get_headers(self) -> typing.Dict[str, str]:
return {"Transfer-Encoding": "chunked"}
def __init__(self, generator: typing.Iterable[bytes]) -> None:
self._generator = generator
self._is_stream_consumed = False
def __iter__(self) -> typing.Iterator[bytes]:
if self.is_stream_consumed:
if self._is_stream_consumed:
raise StreamConsumed()
self.is_stream_consumed = True
for part in self.iterator:
self._is_stream_consumed = True
for part in self._generator:
yield part
def __aiter__(self) -> typing.AsyncIterator[bytes]:
raise RuntimeError("Attempted to call a async iterator on an sync stream.")
class AsyncIteratorStream(ContentStream):
class AsyncGeneratorStream:
"""
Request content encoded as plain bytes, using an async byte iterator.
"""
def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None:
self.aiterator = aiterator
self.is_stream_consumed = False
def can_replay(self) -> bool:
return False
def get_headers(self) -> typing.Dict[str, str]:
return {"Transfer-Encoding": "chunked"}
def __iter__(self) -> typing.Iterator[bytes]:
raise RuntimeError("Attempted to call a sync iterator on an async stream.")
def __init__(self, agenerator: typing.AsyncIterable[bytes]) -> None:
self._agenerator = agenerator
self._is_stream_consumed = False
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
if self.is_stream_consumed:
if self._is_stream_consumed:
raise StreamConsumed()
self.is_stream_consumed = True
async for part in self.aiterator:
self._is_stream_consumed = True
async for part in self._agenerator:
yield part
class JSONStream(ContentStream):
class JSONStream:
"""
Request content encoded as JSON.
"""
@ -146,7 +99,7 @@ class JSONStream(ContentStream):
yield self.body
class URLEncodedStream(ContentStream):
class URLEncodedStream:
"""
Request content as URL encoded form data.
"""
@ -166,7 +119,7 @@ class URLEncodedStream(ContentStream):
yield self.body
class MultipartStream(ContentStream):
class MultipartStream:
"""
Request content as streaming multipart encoded form data.
"""
@ -208,9 +161,6 @@ class MultipartStream(ContentStream):
data = self.render_data()
return len(headers) + len(data)
def can_replay(self) -> bool:
return True
def render(self) -> typing.Iterator[bytes]:
yield self.render_headers()
yield self.render_data()
@ -239,6 +189,7 @@ class MultipartStream(ContentStream):
self.filename = filename
self.file = fileobj
self.content_type = content_type
self._consumed = False
def get_length(self) -> int:
headers = self.render_headers()
@ -284,17 +235,13 @@ class MultipartStream(ContentStream):
yield self._data
return
if self._consumed: # pragma: nocover
self.file.seek(0)
self._consumed = True
for chunk in self.file:
yield to_bytes(chunk)
# Get ready for the next replay, if possible.
if self.can_replay():
assert self.file.seekable()
self.file.seek(0)
def can_replay(self) -> bool:
return True if isinstance(self.file, (str, bytes)) else self.file.seekable()
def render(self) -> typing.Iterator[bytes]:
yield self.render_headers()
yield from self.render_data()
@ -346,9 +293,6 @@ class MultipartStream(ContentStream):
# Content stream interface.
def can_replay(self) -> bool:
return all(field.can_replay() for field in self.fields)
def get_headers(self) -> typing.Dict[str, str]:
content_length = str(self.get_content_length())
content_type = self.content_type
@ -369,10 +313,10 @@ def encode_request(
files: RequestFiles = None,
json: typing.Any = None,
boundary: bytes = None,
) -> ContentStream:
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
"""
Handles encoding the given `content`, `data`, `files`, and `json`,
returning a `ContentStream` implementation.
returning a two-tuple of (<headers>, <stream>).
"""
if data is not None and not isinstance(data, dict):
# We prefer to seperate `content=<bytes|byte iterator|bytes aiterator>`
@ -387,39 +331,65 @@ def encode_request(
if content is not None:
if isinstance(content, (str, bytes)):
return ByteStream(body=content)
elif hasattr(content, "__aiter__"):
content = typing.cast(typing.AsyncIterator[bytes], content)
return AsyncIteratorStream(aiterator=content)
elif hasattr(content, "__iter__"):
content = typing.cast(typing.Iterator[bytes], content)
return IteratorStream(iterator=content)
byte_stream = PlainByteStream(body=content)
headers = byte_stream.get_headers()
return headers, byte_stream
elif isinstance(content, (typing.Iterable, typing.AsyncIterable)):
if inspect.isgenerator(content):
generator_stream = GeneratorStream(content) # type: ignore
return {"Transfer-Encoding": "chunked"}, generator_stream
if inspect.isasyncgen(content):
agenerator_stream = AsyncGeneratorStream(content) # type: ignore
return {"Transfer-Encoding": "chunked"}, agenerator_stream
return {"Transfer-Encoding": "chunked"}, content # type: ignore
else:
raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
elif data:
if files:
return MultipartStream(data=data, files=files, boundary=boundary)
multipart_stream = MultipartStream(
data=data, files=files, boundary=boundary
)
headers = multipart_stream.get_headers()
return headers, multipart_stream
else:
return URLEncodedStream(data=data)
urlencoded_stream = URLEncodedStream(data=data)
headers = urlencoded_stream.get_headers()
return headers, urlencoded_stream
elif files:
return MultipartStream(data={}, files=files, boundary=boundary)
multipart_stream = MultipartStream(data={}, files=files, boundary=boundary)
headers = multipart_stream.get_headers()
return headers, multipart_stream
elif json is not None:
return JSONStream(json=json)
json_stream = JSONStream(json=json)
headers = json_stream.get_headers()
return headers, json_stream
return ByteStream(body=b"")
byte_stream = PlainByteStream(body=b"")
headers = byte_stream.get_headers()
return headers, byte_stream
def encode_response(content: ResponseContent = None) -> ContentStream:
def encode_response(
content: ResponseContent = None,
) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
if content is None:
return ByteStream(b"")
byte_stream = PlainByteStream(b"")
headers = byte_stream.get_headers()
return headers, byte_stream
elif isinstance(content, bytes):
return ByteStream(body=content)
elif isinstance(content, typing.AsyncIterator):
return AsyncIteratorStream(aiterator=content)
elif isinstance(content, typing.Iterator):
return IteratorStream(iterator=content)
byte_stream = PlainByteStream(body=content)
headers = byte_stream.get_headers()
return headers, byte_stream
elif isinstance(content, (typing.Iterable, typing.AsyncIterable)):
if inspect.isgenerator(content):
generator_stream = GeneratorStream(content) # type: ignore
return {"Transfer-Encoding": "chunked"}, generator_stream
elif inspect.isasyncgen(content):
agenerator_stream = AsyncGeneratorStream(content) # type: ignore
return {"Transfer-Encoding": "chunked"}, agenerator_stream
return {"Transfer-Encoding": "chunked"}, content # type: ignore
raise TypeError(f"Unexpected type for 'content', {type(content)!r}")

View File

@ -207,13 +207,6 @@ class TooManyRedirects(RequestError):
"""
class RequestBodyUnavailable(RequestError):
"""
Had to send the request again, but the request body was streaming, and is
no longer available.
"""
# Client errors
@ -283,14 +276,18 @@ class StreamError(Exception):
class StreamConsumed(StreamError):
"""
Attempted to read or stream response content, but the content has already
Attempted to read or stream content, but the content has already
been streamed.
"""
def __init__(self) -> None:
message = (
"Attempted to read or stream response content, but the content has "
"already been streamed."
"Attempted to read or stream some content, but the content has "
"already been streamed. For requests, this could be due to passing "
"a generator as request content, and then receiving a redirect "
"response or a secondary request as part of an authentication flow."
"For responses, this could be due to attempting to stream the response "
"content more than once."
)
super().__init__(message)

View File

@ -13,7 +13,7 @@ from urllib.parse import parse_qsl, quote, unquote, urlencode
import rfc3986
import rfc3986.exceptions
from ._content_streams import ByteStream, ContentStream, encode_request, encode_response
from ._content_streams import PlainByteStream, encode_request, encode_response
from ._decoders import (
SUPPORTED_DECODERS,
ContentDecoder,
@ -37,6 +37,7 @@ from ._exceptions import (
)
from ._status_codes import codes
from ._types import (
ByteStream,
CookieTypes,
HeaderTypes,
PrimitiveData,
@ -606,7 +607,7 @@ class Request:
data: RequestData = None,
files: RequestFiles = None,
json: typing.Any = None,
stream: ContentStream = None,
stream: ByteStream = None,
):
if isinstance(method, bytes):
self.method = method.decode("ascii").upper()
@ -618,14 +619,28 @@ class Request:
Cookies(cookies).set_cookie_header(self)
if stream is not None:
# There's an important distinction between `Request(content=...)`,
# and `Request(stream=...)`.
#
# Using `content=...` implies automatically populated content headers,
# of either `Content-Length: ...` or `Transfer-Encoding: chunked`.
#
# Using `stream=...` will not automatically include any content headers.
#
# As an end-user you don't really need `stream=...`. It's only
# useful when:
#
# * Preserving the request stream when copying requests, eg for redirects.
# * Creating request instances on the *server-side* of the transport API.
self.stream = stream
self._prepare({})
else:
self.stream = encode_request(content, data, files, json)
headers, stream = encode_request(content, data, files, json)
self._prepare(headers)
self.stream = stream
self._prepare()
def _prepare(self) -> None:
for key, value in self.stream.get_headers().items():
def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
for key, value in default_headers.items():
# Ignore Transfer-Encoding if the Content-Length has been set explicitly.
if key.lower() == "transfer-encoding" and "content-length" in self.headers:
continue
@ -657,11 +672,12 @@ class Request:
Read and return the request content.
"""
if not hasattr(self, "_content"):
assert isinstance(self.stream, typing.Iterable)
self._content = b"".join(self.stream)
# If a streaming request has been read entirely into memory, then
# we can replace the stream with a raw bytes implementation,
# to ensure that any non-replayable streams can still be used.
self.stream = ByteStream(self._content)
self.stream = PlainByteStream(self._content)
return self._content
async def aread(self) -> bytes:
@ -669,11 +685,12 @@ class Request:
Read and return the request content.
"""
if not hasattr(self, "_content"):
assert isinstance(self.stream, typing.AsyncIterable)
self._content = b"".join([part async for part in self.stream])
# If a streaming request has been read entirely into memory, then
# we can replace the stream with a raw bytes implementation,
# to ensure that any non-replayable streams can still be used.
self.stream = ByteStream(self._content)
self.stream = PlainByteStream(self._content)
return self._content
def __repr__(self) -> str:
@ -690,10 +707,10 @@ class Response:
request: Request = None,
http_version: str = None,
headers: HeaderTypes = None,
stream: ContentStream = None,
content: ResponseContent = None,
stream: ByteStream = None,
history: typing.List["Response"] = None,
elapsed_func: typing.Callable = None,
on_close: typing.Callable = None,
):
self.status_code = status_code
self.http_version = http_version
@ -704,20 +721,41 @@ class Response:
self.call_next: typing.Optional[typing.Callable] = None
self.history = [] if history is None else list(history)
self._elapsed_func = elapsed_func
self._on_close = on_close
self.is_closed = False
self.is_stream_consumed = False
if stream is not None:
self._raw_stream = stream
# There's an important distinction between `Response(content=...)`,
# and `Response(stream=...)`.
#
# Using `content=...` implies automatically populated content headers,
# of either `Content-Length: ...` or `Transfer-Encoding: chunked`.
#
# Using `stream=...` will not automatically include any content headers.
#
# As an end-user you don't really need `stream=...`. It's only
# useful when creating response instances having received a stream
# from the transport API.
self.stream = stream
else:
self._raw_stream = encode_response(content)
headers, stream = encode_response(content)
self._prepare(headers)
self.stream = stream
if content is None or isinstance(content, bytes):
# Load the response body, except for streaming content.
self.read()
self._num_bytes_downloaded = 0
def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
for key, value in default_headers.items():
# Ignore Transfer-Encoding if the Content-Length has been set explicitly.
if key.lower() == "transfer-encoding" and "content-length" in self.headers:
continue
self.headers.setdefault(key, value)
@property
def elapsed(self) -> datetime.timedelta:
"""
@ -729,7 +767,11 @@ class Response:
"'.elapsed' may only be accessed after the response "
"has been read or closed."
)
return datetime.timedelta(seconds=self._elapsed)
return self._elapsed
@elapsed.setter
def elapsed(self, elapsed: datetime.timedelta) -> None:
self._elapsed = elapsed
@property
def request(self) -> Request:
@ -963,11 +1005,13 @@ class Response:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
if not isinstance(self.stream, typing.Iterable):
raise RuntimeError("Attempted to call a sync iterator on an async stream.")
self.is_stream_consumed = True
self._num_bytes_downloaded = 0
with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
for part in self._raw_stream:
for part in self.stream:
self._num_bytes_downloaded += len(part)
yield part
self.close()
@ -992,9 +1036,8 @@ class Response:
"""
if not self.is_closed:
self.is_closed = True
if self._elapsed_func is not None:
self._elapsed = self._elapsed_func()
self._raw_stream.close()
if self._on_close is not None:
self._on_close(self)
async def aread(self) -> bytes:
"""
@ -1047,11 +1090,13 @@ class Response:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
if not isinstance(self.stream, typing.AsyncIterable):
raise RuntimeError("Attempted to call a async iterator on a sync stream.")
self.is_stream_consumed = True
self._num_bytes_downloaded = 0
with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
async for part in self._raw_stream:
async for part in self.stream:
self._num_bytes_downloaded += len(part)
yield part
await self.aclose()
@ -1075,9 +1120,8 @@ class Response:
"""
if not self.is_closed:
self.is_closed = True
if self._elapsed_func is not None:
self._elapsed = await self._elapsed_func()
await self._raw_stream.aclose()
if self._on_close is not None:
await self._on_close(self)
class Cookies(MutableMapping):

View File

@ -7,10 +7,10 @@ from http.cookiejar import CookieJar
from typing import (
IO,
TYPE_CHECKING,
AsyncIterator,
AsyncIterable,
Callable,
Dict,
Iterator,
Iterable,
List,
Mapping,
Optional,
@ -66,8 +66,9 @@ AuthTypes = Union[
None,
]
RequestContent = Union[str, bytes, Iterator[bytes], AsyncIterator[bytes]]
ResponseContent = Union[bytes, Iterator[bytes], AsyncIterator[bytes]]
RequestContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
ResponseContent = Union[bytes, Iterable[bytes], AsyncIterable[bytes]]
ByteStream = Union[Iterable[bytes], AsyncIterable[bytes]]
RequestData = dict

View File

@ -13,16 +13,7 @@ import typing
import pytest
import httpx
from httpx import (
URL,
Auth,
BasicAuth,
DigestAuth,
ProtocolError,
Request,
RequestBodyUnavailable,
Response,
)
from httpx import URL, Auth, BasicAuth, DigestAuth, ProtocolError, Request, Response
from tests.utils import AsyncMockTransport, MockTransport
from ..common import FIXTURES_DIR
@ -617,13 +608,13 @@ def test_sync_auth_history() -> None:
async def test_digest_auth_unavailable_streaming_body():
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
app = App()
app = DigestApp()
async def streaming_body():
yield b"Example request body" # pragma: nocover
async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
with pytest.raises(RequestBodyUnavailable):
with pytest.raises(httpx.StreamConsumed):
await client.post(url, data=streaming_body(), auth=auth)

View File

@ -334,7 +334,7 @@ def test_cannot_redirect_streaming_body():
def streaming_body():
yield b"Example request body" # pragma: nocover
with pytest.raises(httpx.RequestBodyUnavailable):
with pytest.raises(httpx.StreamConsumed):
client.post(url, content=streaming_body())

View File

@ -1,3 +1,5 @@
import typing
import pytest
import httpx
@ -18,6 +20,40 @@ def test_content_length_header():
assert request.headers["Content-Length"] == "8"
def test_iterable_content():
class Content:
def __iter__(self):
yield b"test 123" # pragma: nocover
request = httpx.Request("POST", "http://example.org", content=Content())
assert request.headers == httpx.Headers(
{"Host": "example.org", "Transfer-Encoding": "chunked"}
)
def test_generator_with_transfer_encoding_header():
def content():
yield b"test 123" # pragma: nocover
request = httpx.Request("POST", "http://example.org", content=content())
assert request.headers == httpx.Headers(
{"Host": "example.org", "Transfer-Encoding": "chunked"}
)
def test_generator_with_content_length_header():
def content():
yield b"test 123" # pragma: nocover
headers = {"Content-Length": "8"}
request = httpx.Request(
"POST", "http://example.org", content=content(), headers=headers
)
assert request.headers == httpx.Headers(
{"Host": "example.org", "Content-Length": "8"}
)
def test_url_encoded_data():
request = httpx.Request("POST", "http://example.org", data={"test": "123"})
request.read()
@ -51,6 +87,8 @@ def test_read_and_stream_data():
# Needed for cases such as authentication classes that read the request body.
request = httpx.Request("POST", "http://example.org", json={"test": 123})
request.read()
assert request.stream is not None
assert isinstance(request.stream, typing.Iterable)
content = b"".join([part for part in request.stream])
assert content == request.content
@ -61,6 +99,8 @@ async def test_aread_and_stream_data():
# Needed for cases such as authentication classes that read the request body.
request = httpx.Request("POST", "http://example.org", json={"test": 123})
await request.aread()
assert request.stream is not None
assert isinstance(request.stream, typing.AsyncIterable)
content = b"".join([part async for part in request.stream])
assert content == request.content
@ -68,7 +108,7 @@ async def test_aread_and_stream_data():
@pytest.mark.asyncio
async def test_cannot_access_content_without_read():
# Ensure a request may still be streamed if it has been read.
#  Needed for cases such as authentication classes that read the request body.
# Needed for cases such as authentication classes that read the request body.
request = httpx.Request("POST", "http://example.org", json={"test": 123})
with pytest.raises(httpx.RequestNotRead):
request.content

View File

@ -7,6 +7,12 @@ import pytest
import httpx
class StreamingBody:
def __iter__(self):
yield b"Hello, "
yield b"world!"
def streaming_body():
yield b"Hello, "
yield b"world!"
@ -230,6 +236,21 @@ def test_read():
assert response.is_closed
def test_empty_read():
response = httpx.Response(200)
assert response.status_code == 200
assert response.text == ""
assert response.encoding is None
assert response.is_closed
content = response.read()
assert content == b""
assert response.content == b""
assert response.is_closed
@pytest.mark.asyncio
async def test_aread():
response = httpx.Response(
@ -249,6 +270,22 @@ async def test_aread():
assert response.is_closed
@pytest.mark.asyncio
async def test_empty_aread():
response = httpx.Response(200)
assert response.status_code == 200
assert response.text == ""
assert response.encoding is None
assert response.is_closed
content = await response.aread()
assert content == b""
assert response.content == b""
assert response.is_closed
def test_iter_raw():
response = httpx.Response(
200,
@ -261,6 +298,28 @@ def test_iter_raw():
assert raw == b"Hello, world!"
def test_iter_raw_on_iterable():
response = httpx.Response(
200,
content=StreamingBody(),
)
raw = b""
for part in response.iter_raw():
raw += part
assert raw == b"Hello, world!"
def test_iter_raw_on_async():
response = httpx.Response(
200,
content=async_streaming_body(),
)
with pytest.raises(RuntimeError):
[part for part in response.iter_raw()]
def test_iter_raw_increments_updates_counter():
response = httpx.Response(200, content=streaming_body())
@ -280,6 +339,17 @@ async def test_aiter_raw():
assert raw == b"Hello, world!"
@pytest.mark.asyncio
async def test_aiter_raw_on_sync():
response = httpx.Response(
200,
content=streaming_body(),
)
with pytest.raises(RuntimeError):
[part async for part in response.aiter_raw()]
@pytest.mark.asyncio
async def test_aiter_raw_increments_updates_counter():
response = httpx.Response(200, content=async_streaming_body())
@ -610,3 +680,20 @@ def test_cannot_access_unset_request():
with pytest.raises(RuntimeError):
response.request
def test_generator_with_transfer_encoding_header():
def content():
yield b"test 123" # pragma: nocover
response = httpx.Response(200, content=content())
assert response.headers == httpx.Headers({"Transfer-Encoding": "chunked"})
def test_generator_with_content_length_header():
def content():
yield b"test 123" # pragma: nocover
headers = {"Content-Length": "8"}
response = httpx.Response(200, content=content(), headers=headers)
assert response.headers == httpx.Headers({"Content-Length": "8"})

View File

@ -1,53 +1,48 @@
import io
import typing
import pytest
from httpx import StreamConsumed
from httpx._content_streams import ContentStream, encode_request, encode_response
@pytest.mark.asyncio
async def test_base_content():
stream = ContentStream()
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {}
assert sync_content == b""
assert async_content == b""
from httpx._content_streams import encode_request, encode_response
@pytest.mark.asyncio
async def test_empty_content():
stream = encode_request()
headers, stream = encode_request()
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {}
assert headers == {}
assert sync_content == b""
assert async_content == b""
@pytest.mark.asyncio
async def test_bytes_content():
stream = encode_request(content=b"Hello, world!")
headers, stream = encode_request(content=b"Hello, world!")
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {"Content-Length": "13"}
assert headers == {"Content-Length": "13"}
assert sync_content == b"Hello, world!"
assert async_content == b"Hello, world!"
# Support 'data' for compat with requests.
stream = encode_request(data=b"Hello, world!") # type: ignore
headers, stream = encode_request(data=b"Hello, world!") # type: ignore
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {"Content-Length": "13"}
assert headers == {"Content-Length": "13"}
assert sync_content == b"Hello, world!"
assert async_content == b"Hello, world!"
@ -58,25 +53,26 @@ async def test_iterator_content():
yield b"Hello, "
yield b"world!"
stream = encode_request(content=hello_world())
headers, stream = encode_request(content=hello_world())
assert isinstance(stream, typing.Iterable)
assert not isinstance(stream, typing.AsyncIterable)
content = b"".join([part for part in stream])
assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert headers == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"
with pytest.raises(RuntimeError):
[part async for part in stream]
with pytest.raises(StreamConsumed):
[part for part in stream]
# Support 'data' for compat with requests.
stream = encode_request(data=hello_world()) # type: ignore
headers, stream = encode_request(data=hello_world()) # type: ignore
assert isinstance(stream, typing.Iterable)
assert not isinstance(stream, typing.AsyncIterable)
content = b"".join([part for part in stream])
assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert headers == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"
@ -86,36 +82,39 @@ async def test_aiterator_content():
yield b"Hello, "
yield b"world!"
stream = encode_request(content=hello_world())
headers, stream = encode_request(content=hello_world())
assert not isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
content = b"".join([part async for part in stream])
assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert headers == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"
with pytest.raises(RuntimeError):
[part for part in stream]
with pytest.raises(StreamConsumed):
[part async for part in stream]
# Support 'data' for compat with requests.
stream = encode_request(data=hello_world()) # type: ignore
headers, stream = encode_request(data=hello_world()) # type: ignore
assert not isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
content = b"".join([part async for part in stream])
assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert headers == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"
@pytest.mark.asyncio
async def test_json_content():
stream = encode_request(json={"Hello": "world!"})
headers, stream = encode_request(json={"Hello": "world!"})
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {
assert headers == {
"Content-Length": "19",
"Content-Type": "application/json",
}
@ -125,12 +124,14 @@ async def test_json_content():
@pytest.mark.asyncio
async def test_urlencoded_content():
stream = encode_request(data={"Hello": "world!"})
headers, stream = encode_request(data={"Hello": "world!"})
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {
assert headers == {
"Content-Length": "14",
"Content-Type": "application/x-www-form-urlencoded",
}
@ -141,12 +142,14 @@ async def test_urlencoded_content():
@pytest.mark.asyncio
async def test_multipart_files_content():
files = {"file": io.BytesIO(b"<file content>")}
stream = encode_request(files=files, boundary=b"+++")
headers, stream = encode_request(files=files, boundary=b"+++")
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {
assert headers == {
"Content-Length": "138",
"Content-Type": "multipart/form-data; boundary=+++",
}
@ -176,12 +179,14 @@ async def test_multipart_files_content():
async def test_multipart_data_and_files_content():
data = {"message": "Hello, world!"}
files = {"file": io.BytesIO(b"<file content>")}
stream = encode_request(data=data, files=files, boundary=b"+++")
headers, stream = encode_request(data=data, files=files, boundary=b"+++")
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {
assert headers == {
"Content-Length": "210",
"Content-Type": "multipart/form-data; boundary=+++",
}
@ -217,12 +222,14 @@ async def test_multipart_data_and_files_content():
@pytest.mark.asyncio
async def test_empty_request():
stream = encode_request(data={}, files={})
headers, stream = encode_request(data={}, files={})
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {}
assert headers == {}
assert sync_content == b""
assert async_content == b""
@ -238,12 +245,14 @@ async def test_multipart_multiple_files_single_input_content():
("file", io.BytesIO(b"<file content 1>")),
("file", io.BytesIO(b"<file content 2>")),
]
stream = encode_request(files=files, boundary=b"+++")
headers, stream = encode_request(files=files, boundary=b"+++")
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {
assert headers == {
"Content-Length": "271",
"Content-Type": "multipart/form-data; boundary=+++",
}
@ -281,24 +290,28 @@ async def test_multipart_multiple_files_single_input_content():
@pytest.mark.asyncio
async def test_response_empty_content():
stream = encode_response()
headers, stream = encode_response()
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {}
assert headers == {}
assert sync_content == b""
assert async_content == b""
@pytest.mark.asyncio
async def test_response_bytes_content():
stream = encode_response(content=b"Hello, world!")
headers, stream = encode_response(content=b"Hello, world!")
assert isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {"Content-Length": "13"}
assert headers == {"Content-Length": "13"}
assert sync_content == b"Hello, world!"
assert async_content == b"Hello, world!"
@ -309,16 +322,15 @@ async def test_response_iterator_content():
yield b"Hello, "
yield b"world!"
stream = encode_response(content=hello_world())
headers, stream = encode_response(content=hello_world())
assert isinstance(stream, typing.Iterable)
assert not isinstance(stream, typing.AsyncIterable)
content = b"".join([part for part in stream])
assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert headers == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"
with pytest.raises(RuntimeError):
[part async for part in stream]
with pytest.raises(StreamConsumed):
[part for part in stream]
@ -329,16 +341,15 @@ async def test_response_aiterator_content():
yield b"Hello, "
yield b"world!"
stream = encode_response(content=hello_world())
headers, stream = encode_response(content=hello_world())
assert not isinstance(stream, typing.Iterable)
assert isinstance(stream, typing.AsyncIterable)
content = b"".join([part async for part in stream])
assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert headers == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"
with pytest.raises(RuntimeError):
[part for part in stream]
with pytest.raises(StreamConsumed):
[part async for part in stream]

View File

@ -110,9 +110,8 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = os.urandom(16).hex()
stream = encode_request(data=data, files=files)
headers, stream = encode_request(data=data, files=files)
assert isinstance(stream, MultipartStream)
assert stream.can_replay()
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
content = (
@ -128,7 +127,7 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:
"--{0}--\r\n"
"".format(boundary).encode("ascii")
)
assert stream.get_headers()["Content-Length"] == str(len(content))
assert headers["Content-Length"] == str(len(content))
assert b"".join(stream) == content
@ -137,9 +136,8 @@ def test_multipart_encode_files_allows_filenames_as_none() -> None:
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = os.urandom(16).hex()
stream = encode_request(data={}, files=files)
headers, stream = encode_request(data={}, files=files)
assert isinstance(stream, MultipartStream)
assert stream.can_replay()
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
assert b"".join(stream) == (
@ -164,9 +162,8 @@ def test_multipart_encode_files_guesses_correct_content_type(
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = os.urandom(16).hex()
stream = encode_request(data={}, files=files)
headers, stream = encode_request(data={}, files=files)
assert isinstance(stream, MultipartStream)
assert stream.can_replay()
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
assert b"".join(stream) == (
@ -188,9 +185,8 @@ def test_multipart_encode_files_allows_bytes_or_str_content(
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = os.urandom(16).hex()
stream = encode_request(data={}, files=files)
headers, stream = encode_request(data={}, files=files)
assert isinstance(stream, MultipartStream)
assert stream.can_replay()
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
content = (
@ -200,7 +196,7 @@ def test_multipart_encode_files_allows_bytes_or_str_content(
"--{0}--\r\n"
"".format(boundary, output).encode("ascii")
)
assert stream.get_headers()["Content-Length"] == str(len(content))
assert headers["Content-Length"] == str(len(content))
assert b"".join(stream) == content
@ -214,9 +210,6 @@ def test_multipart_encode_non_seekable_filelike() -> None:
def __init__(self, iterator: typing.Iterator[bytes]) -> None:
self._iterator = iterator
def seekable(self) -> bool:
return False
def read(self, *args: typing.Any) -> bytes:
return b"".join(self._iterator)
@ -226,8 +219,8 @@ def test_multipart_encode_non_seekable_filelike() -> None:
fileobj: typing.Any = IteratorIO(data())
files = {"file": fileobj}
stream = encode_request(files=files, boundary=b"+++")
assert not stream.can_replay()
headers, stream = encode_request(files=files, boundary=b"+++")
assert isinstance(stream, typing.Iterable)
content = (
b"--+++\r\n"
@ -237,7 +230,7 @@ def test_multipart_encode_non_seekable_filelike() -> None:
b"HelloWorld\r\n"
b"--+++--\r\n"
)
assert stream.get_headers() == {
assert headers == {
"Content-Type": "multipart/form-data; boundary=+++",
"Content-Length": str(len(content)),
}

View File

@ -36,22 +36,11 @@ class MockTransport(httpcore.SyncHTTPTransport):
stream: httpcore.SyncByteStream = None,
timeout: Mapping[str, Optional[float]] = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.SyncByteStream]:
request_headers = httpx.Headers(headers)
content = (
(item for item in stream)
if stream
and (
"Content-Length" in request_headers
or "Transfer-Encoding" in request_headers
)
else None
)
request = httpx.Request(
method=method,
url=url,
headers=request_headers,
content=content,
headers=headers,
stream=stream,
)
request.read()
response = self.handler(request)
@ -60,13 +49,13 @@ class MockTransport(httpcore.SyncHTTPTransport):
response.status_code,
response.reason_phrase.encode("ascii"),
response.headers.raw,
response._raw_stream,
response.stream,
)
class AsyncMockTransport(httpcore.AsyncHTTPTransport):
def __init__(self, handler: Callable) -> None:
self.impl = MockTransport(handler)
self.handler = handler
async def request(
self,
@ -76,28 +65,18 @@ class AsyncMockTransport(httpcore.AsyncHTTPTransport):
stream: httpcore.AsyncByteStream = None,
timeout: Mapping[str, Optional[float]] = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
content = (
httpcore.PlainByteStream(b"".join([part async for part in stream]))
if stream
else httpcore.PlainByteStream(b"")
request = httpx.Request(
method=method,
url=url,
headers=headers,
stream=stream,
)
(
http_version,
status_code,
reason_phrase,
headers,
response_stream,
) = self.impl.request(
method, url, headers=headers, stream=content, timeout=timeout
)
content = httpcore.PlainByteStream(b"".join([part for part in response_stream]))
await request.aread()
response = self.handler(request)
return (
http_version,
status_code,
reason_phrase,
headers,
content,
(response.http_version or "HTTP/1.1").encode("ascii"),
response.status_code,
response.reason_phrase.encode("ascii"),
response.headers.raw,
response.stream,
)