Support Response(content=<bytes iterator>) (#1265)

* Support Response(content=<bytes iterator>)

* Update test for merged master
This commit is contained in:
Tom Christie 2020-09-11 10:28:18 +01:00 committed by GitHub
parent 4bd08bed22
commit 5ee6135256
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 112 additions and 83 deletions

View File

@ -8,7 +8,7 @@ from urllib.parse import urlencode
import httpcore
from ._exceptions import StreamConsumed
from ._types import FileContent, FileTypes, RequestData, RequestFiles
from ._types import FileContent, FileTypes, RequestData, RequestFiles, ResponseContent
from ._utils import (
format_form_param,
guess_content_type,
@ -72,11 +72,8 @@ class IteratorStream(ContentStream):
Request content encoded as plain bytes, using an byte iterator.
"""
def __init__(
self, iterator: typing.Iterator[bytes], close_func: typing.Callable = None
) -> None:
def __init__(self, iterator: typing.Iterator[bytes]) -> None:
self.iterator = iterator
self.close_func = close_func
self.is_stream_consumed = False
def can_replay(self) -> bool:
@ -95,21 +92,14 @@ class IteratorStream(ContentStream):
def __aiter__(self) -> typing.AsyncIterator[bytes]:
raise RuntimeError("Attempted to call a async iterator on an sync stream.")
def close(self) -> None:
if self.close_func is not None:
self.close_func()
class AsyncIteratorStream(ContentStream):
"""
Request content encoded as plain bytes, using an async byte iterator.
"""
def __init__(
self, aiterator: typing.AsyncIterator[bytes], close_func: typing.Callable = None
) -> None:
def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None:
self.aiterator = aiterator
self.close_func = close_func
self.is_stream_consumed = False
def can_replay(self) -> bool:
@ -128,10 +118,6 @@ class AsyncIteratorStream(ContentStream):
async for part in self.aiterator:
yield part
async def aclose(self) -> None:
if self.close_func is not None:
await self.close_func()
class JSONStream(ContentStream):
"""
@ -402,3 +388,18 @@ def encode(
return IteratorStream(iterator=data)
raise TypeError(f"Unexpected type for 'data', {type(data)!r}")
def encode_response(content: ResponseContent = None) -> ContentStream:
if content is None:
return ByteStream(b"")
elif isinstance(content, bytes):
return ByteStream(body=content)
elif hasattr(content, "__aiter__"):
content = typing.cast(typing.AsyncIterator[bytes], content)
return AsyncIteratorStream(aiterator=content)
elif hasattr(content, "__iter__"):
content = typing.cast(typing.Iterator[bytes], content)
return IteratorStream(iterator=content)
raise TypeError(f"Unexpected type for 'content', {type(content)!r}")

View File

@ -14,7 +14,7 @@ import chardet
import rfc3986
import rfc3986.exceptions
from ._content_streams import ByteStream, ContentStream, encode
from ._content_streams import ByteStream, ContentStream, encode, encode_response
from ._decoders import (
SUPPORTED_DECODERS,
ContentDecoder,
@ -44,6 +44,7 @@ from ._types import (
QueryParamTypes,
RequestData,
RequestFiles,
ResponseContent,
URLTypes,
)
from ._utils import (
@ -674,7 +675,7 @@ class Response:
http_version: str = None,
headers: HeaderTypes = None,
stream: ContentStream = None,
content: bytes = None,
content: ResponseContent = None,
history: typing.List["Response"] = None,
elapsed_func: typing.Callable = None,
):
@ -694,8 +695,10 @@ class Response:
if stream is not None:
self._raw_stream = stream
else:
self._raw_stream = ByteStream(body=content or b"")
self.read()
self._raw_stream = encode_response(content)
if content is None or isinstance(content, bytes):
# Load the response body, except for streaming content.
self.read()
self._num_bytes_downloaded = 0

View File

@ -63,6 +63,8 @@ AuthTypes = Union[
None,
]
ResponseContent = Union[bytes, Iterator[bytes], AsyncIterator[bytes]]
RequestData = Union[dict, str, bytes, Iterator[bytes], AsyncIterator[bytes]]
FileContent = Union[IO[str], IO[bytes], str, bytes]

View File

@ -5,7 +5,6 @@ import brotli
import pytest
import httpx
from httpx._content_streams import AsyncIteratorStream, IteratorStream
def streaming_body():
@ -215,10 +214,9 @@ async def test_aread():
def test_iter_raw():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(
200,
stream=stream,
content=streaming_body(),
)
raw = b""
@ -228,12 +226,7 @@ def test_iter_raw():
def test_iter_raw_increments_updates_counter():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(
200,
stream=stream,
)
response = httpx.Response(200, content=streaming_body())
num_downloaded = response.num_bytes_downloaded
for part in response.iter_raw():
@ -243,11 +236,7 @@ def test_iter_raw_increments_updates_counter():
@pytest.mark.asyncio
async def test_aiter_raw():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(
200,
stream=stream,
)
response = httpx.Response(200, content=async_streaming_body())
raw = b""
async for part in response.aiter_raw():
@ -257,12 +246,7 @@ async def test_aiter_raw():
@pytest.mark.asyncio
async def test_aiter_raw_increments_updates_counter():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(
200,
stream=stream,
)
response = httpx.Response(200, content=async_streaming_body())
num_downloaded = response.num_bytes_downloaded
async for part in response.aiter_raw():
@ -346,10 +330,9 @@ async def test_aiter_lines():
def test_sync_streaming_response():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(
200,
stream=stream,
content=streaming_body(),
)
assert response.status_code == 200
@ -364,10 +347,9 @@ def test_sync_streaming_response():
@pytest.mark.asyncio
async def test_async_streaming_response():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(
200,
stream=stream,
content=async_streaming_body(),
)
assert response.status_code == 200
@ -381,10 +363,9 @@ async def test_async_streaming_response():
def test_cannot_read_after_stream_consumed():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(
200,
stream=stream,
content=streaming_body(),
)
content = b""
@ -397,10 +378,9 @@ def test_cannot_read_after_stream_consumed():
@pytest.mark.asyncio
async def test_cannot_aread_after_stream_consumed():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(
200,
stream=stream,
content=async_streaming_body(),
)
content = b""
@ -412,54 +392,33 @@ async def test_cannot_aread_after_stream_consumed():
def test_cannot_read_after_response_closed():
is_closed = False
def close_func():
nonlocal is_closed
is_closed = True
stream = IteratorStream(iterator=streaming_body(), close_func=close_func)
response = httpx.Response(
200,
stream=stream,
content=streaming_body(),
)
response.close()
assert is_closed
with pytest.raises(httpx.ResponseClosed):
response.read()
@pytest.mark.asyncio
async def test_cannot_aread_after_response_closed():
is_closed = False
async def close_func():
nonlocal is_closed
is_closed = True
stream = AsyncIteratorStream(
aiterator=async_streaming_body(), close_func=close_func
)
response = httpx.Response(
200,
stream=stream,
content=async_streaming_body(),
)
await response.aclose()
assert is_closed
with pytest.raises(httpx.ResponseClosed):
await response.aread()
@pytest.mark.asyncio
async def test_elapsed_not_available_until_closed():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(
200,
stream=stream,
content=async_streaming_body(),
)
with pytest.raises(RuntimeError):

View File

@ -3,7 +3,7 @@ import io
import pytest
from httpx import StreamConsumed
from httpx._content_streams import ContentStream, encode
from httpx._content_streams import ContentStream, encode, encode_response
@pytest.mark.asyncio
@ -251,3 +251,72 @@ async def test_multipart_multiple_files_single_input_content():
b"--+++--\r\n",
]
)
@pytest.mark.asyncio
async def test_response_empty_content():
stream = encode_response()
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {}
assert sync_content == b""
assert async_content == b""
@pytest.mark.asyncio
async def test_response_bytes_content():
stream = encode_response(content=b"Hello, world!")
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {"Content-Length": "13"}
assert sync_content == b"Hello, world!"
assert async_content == b"Hello, world!"
@pytest.mark.asyncio
async def test_response_iterator_content():
def hello_world():
yield b"Hello, "
yield b"world!"
stream = encode_response(content=hello_world())
content = b"".join([part for part in stream])
assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"
with pytest.raises(RuntimeError):
[part async for part in stream]
with pytest.raises(StreamConsumed):
[part for part in stream]
@pytest.mark.asyncio
async def test_response_aiterator_content():
async def hello_world():
yield b"Hello, "
yield b"world!"
stream = encode_response(content=hello_world())
content = b"".join([part async for part in stream])
assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"
with pytest.raises(RuntimeError):
[part for part in stream]
with pytest.raises(StreamConsumed):
[part async for part in stream]
def test_response_invalid_argument():
with pytest.raises(TypeError):
encode_response(123) # type: ignore

View File

@ -4,7 +4,6 @@ import brotli
import pytest
import httpx
from httpx._content_streams import AsyncIteratorStream
from httpx._decoders import (
BrotliDecoder,
DeflateDecoder,
@ -130,11 +129,10 @@ async def test_streaming():
yield compressor.flush()
headers = [(b"Content-Encoding", b"gzip")]
stream = AsyncIteratorStream(aiterator=compress(body))
response = httpx.Response(
200,
headers=headers,
stream=stream,
content=compress(body),
)
assert not hasattr(response, "body")
assert await response.aread() == body
@ -199,19 +197,17 @@ async def test_text_decoder(data, encoding):
yield chunk
# Accessing `.text` on a read response.
stream = AsyncIteratorStream(aiterator=iterator())
response = httpx.Response(
200,
stream=stream,
content=iterator(),
)
await response.aread()
assert response.text == (b"".join(data)).decode(encoding)
# Streaming `.aiter_text` iteratively.
stream = AsyncIteratorStream(aiterator=iterator())
response = httpx.Response(
200,
stream=stream,
content=iterator(),
)
text = "".join([part async for part in response.aiter_text()])
assert text == (b"".join(data)).decode(encoding)
@ -224,11 +220,10 @@ async def test_text_decoder_known_encoding():
yield b"\x83"
yield b"\x89\x83x\x83\x8b"
stream = AsyncIteratorStream(aiterator=iterator())
response = httpx.Response(
200,
headers=[(b"Content-Type", b"text/html; charset=shift-jis")],
stream=stream,
content=iterator(),
)
await response.aread()