Support Response(content=<bytes iterator>) (#1265)
* Support Response(content=<bytes iterator>) * Update test for merged master
This commit is contained in:
parent
4bd08bed22
commit
5ee6135256
@ -8,7 +8,7 @@ from urllib.parse import urlencode
|
||||
import httpcore
|
||||
|
||||
from ._exceptions import StreamConsumed
|
||||
from ._types import FileContent, FileTypes, RequestData, RequestFiles
|
||||
from ._types import FileContent, FileTypes, RequestData, RequestFiles, ResponseContent
|
||||
from ._utils import (
|
||||
format_form_param,
|
||||
guess_content_type,
|
||||
@ -72,11 +72,8 @@ class IteratorStream(ContentStream):
|
||||
Request content encoded as plain bytes, using an byte iterator.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, iterator: typing.Iterator[bytes], close_func: typing.Callable = None
|
||||
) -> None:
|
||||
def __init__(self, iterator: typing.Iterator[bytes]) -> None:
|
||||
self.iterator = iterator
|
||||
self.close_func = close_func
|
||||
self.is_stream_consumed = False
|
||||
|
||||
def can_replay(self) -> bool:
|
||||
@ -95,21 +92,14 @@ class IteratorStream(ContentStream):
|
||||
def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
raise RuntimeError("Attempted to call a async iterator on an sync stream.")
|
||||
|
||||
def close(self) -> None:
|
||||
if self.close_func is not None:
|
||||
self.close_func()
|
||||
|
||||
|
||||
class AsyncIteratorStream(ContentStream):
|
||||
"""
|
||||
Request content encoded as plain bytes, using an async byte iterator.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, aiterator: typing.AsyncIterator[bytes], close_func: typing.Callable = None
|
||||
) -> None:
|
||||
def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None:
|
||||
self.aiterator = aiterator
|
||||
self.close_func = close_func
|
||||
self.is_stream_consumed = False
|
||||
|
||||
def can_replay(self) -> bool:
|
||||
@ -128,10 +118,6 @@ class AsyncIteratorStream(ContentStream):
|
||||
async for part in self.aiterator:
|
||||
yield part
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self.close_func is not None:
|
||||
await self.close_func()
|
||||
|
||||
|
||||
class JSONStream(ContentStream):
|
||||
"""
|
||||
@ -402,3 +388,18 @@ def encode(
|
||||
return IteratorStream(iterator=data)
|
||||
|
||||
raise TypeError(f"Unexpected type for 'data', {type(data)!r}")
|
||||
|
||||
|
||||
def encode_response(content: ResponseContent = None) -> ContentStream:
|
||||
if content is None:
|
||||
return ByteStream(b"")
|
||||
elif isinstance(content, bytes):
|
||||
return ByteStream(body=content)
|
||||
elif hasattr(content, "__aiter__"):
|
||||
content = typing.cast(typing.AsyncIterator[bytes], content)
|
||||
return AsyncIteratorStream(aiterator=content)
|
||||
elif hasattr(content, "__iter__"):
|
||||
content = typing.cast(typing.Iterator[bytes], content)
|
||||
return IteratorStream(iterator=content)
|
||||
|
||||
raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
|
||||
|
||||
@ -14,7 +14,7 @@ import chardet
|
||||
import rfc3986
|
||||
import rfc3986.exceptions
|
||||
|
||||
from ._content_streams import ByteStream, ContentStream, encode
|
||||
from ._content_streams import ByteStream, ContentStream, encode, encode_response
|
||||
from ._decoders import (
|
||||
SUPPORTED_DECODERS,
|
||||
ContentDecoder,
|
||||
@ -44,6 +44,7 @@ from ._types import (
|
||||
QueryParamTypes,
|
||||
RequestData,
|
||||
RequestFiles,
|
||||
ResponseContent,
|
||||
URLTypes,
|
||||
)
|
||||
from ._utils import (
|
||||
@ -674,7 +675,7 @@ class Response:
|
||||
http_version: str = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: ContentStream = None,
|
||||
content: bytes = None,
|
||||
content: ResponseContent = None,
|
||||
history: typing.List["Response"] = None,
|
||||
elapsed_func: typing.Callable = None,
|
||||
):
|
||||
@ -694,8 +695,10 @@ class Response:
|
||||
if stream is not None:
|
||||
self._raw_stream = stream
|
||||
else:
|
||||
self._raw_stream = ByteStream(body=content or b"")
|
||||
self.read()
|
||||
self._raw_stream = encode_response(content)
|
||||
if content is None or isinstance(content, bytes):
|
||||
# Load the response body, except for streaming content.
|
||||
self.read()
|
||||
|
||||
self._num_bytes_downloaded = 0
|
||||
|
||||
|
||||
@ -63,6 +63,8 @@ AuthTypes = Union[
|
||||
None,
|
||||
]
|
||||
|
||||
ResponseContent = Union[bytes, Iterator[bytes], AsyncIterator[bytes]]
|
||||
|
||||
RequestData = Union[dict, str, bytes, Iterator[bytes], AsyncIterator[bytes]]
|
||||
|
||||
FileContent = Union[IO[str], IO[bytes], str, bytes]
|
||||
|
||||
@ -5,7 +5,6 @@ import brotli
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
from httpx._content_streams import AsyncIteratorStream, IteratorStream
|
||||
|
||||
|
||||
def streaming_body():
|
||||
@ -215,10 +214,9 @@ async def test_aread():
|
||||
|
||||
|
||||
def test_iter_raw():
|
||||
stream = IteratorStream(iterator=streaming_body())
|
||||
response = httpx.Response(
|
||||
200,
|
||||
stream=stream,
|
||||
content=streaming_body(),
|
||||
)
|
||||
|
||||
raw = b""
|
||||
@ -228,12 +226,7 @@ def test_iter_raw():
|
||||
|
||||
|
||||
def test_iter_raw_increments_updates_counter():
|
||||
stream = IteratorStream(iterator=streaming_body())
|
||||
|
||||
response = httpx.Response(
|
||||
200,
|
||||
stream=stream,
|
||||
)
|
||||
response = httpx.Response(200, content=streaming_body())
|
||||
|
||||
num_downloaded = response.num_bytes_downloaded
|
||||
for part in response.iter_raw():
|
||||
@ -243,11 +236,7 @@ def test_iter_raw_increments_updates_counter():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiter_raw():
|
||||
stream = AsyncIteratorStream(aiterator=async_streaming_body())
|
||||
response = httpx.Response(
|
||||
200,
|
||||
stream=stream,
|
||||
)
|
||||
response = httpx.Response(200, content=async_streaming_body())
|
||||
|
||||
raw = b""
|
||||
async for part in response.aiter_raw():
|
||||
@ -257,12 +246,7 @@ async def test_aiter_raw():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiter_raw_increments_updates_counter():
|
||||
stream = AsyncIteratorStream(aiterator=async_streaming_body())
|
||||
|
||||
response = httpx.Response(
|
||||
200,
|
||||
stream=stream,
|
||||
)
|
||||
response = httpx.Response(200, content=async_streaming_body())
|
||||
|
||||
num_downloaded = response.num_bytes_downloaded
|
||||
async for part in response.aiter_raw():
|
||||
@ -346,10 +330,9 @@ async def test_aiter_lines():
|
||||
|
||||
|
||||
def test_sync_streaming_response():
|
||||
stream = IteratorStream(iterator=streaming_body())
|
||||
response = httpx.Response(
|
||||
200,
|
||||
stream=stream,
|
||||
content=streaming_body(),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@ -364,10 +347,9 @@ def test_sync_streaming_response():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_streaming_response():
|
||||
stream = AsyncIteratorStream(aiterator=async_streaming_body())
|
||||
response = httpx.Response(
|
||||
200,
|
||||
stream=stream,
|
||||
content=async_streaming_body(),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@ -381,10 +363,9 @@ async def test_async_streaming_response():
|
||||
|
||||
|
||||
def test_cannot_read_after_stream_consumed():
|
||||
stream = IteratorStream(iterator=streaming_body())
|
||||
response = httpx.Response(
|
||||
200,
|
||||
stream=stream,
|
||||
content=streaming_body(),
|
||||
)
|
||||
|
||||
content = b""
|
||||
@ -397,10 +378,9 @@ def test_cannot_read_after_stream_consumed():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_aread_after_stream_consumed():
|
||||
stream = AsyncIteratorStream(aiterator=async_streaming_body())
|
||||
response = httpx.Response(
|
||||
200,
|
||||
stream=stream,
|
||||
content=async_streaming_body(),
|
||||
)
|
||||
|
||||
content = b""
|
||||
@ -412,54 +392,33 @@ async def test_cannot_aread_after_stream_consumed():
|
||||
|
||||
|
||||
def test_cannot_read_after_response_closed():
|
||||
is_closed = False
|
||||
|
||||
def close_func():
|
||||
nonlocal is_closed
|
||||
is_closed = True
|
||||
|
||||
stream = IteratorStream(iterator=streaming_body(), close_func=close_func)
|
||||
response = httpx.Response(
|
||||
200,
|
||||
stream=stream,
|
||||
content=streaming_body(),
|
||||
)
|
||||
|
||||
response.close()
|
||||
assert is_closed
|
||||
|
||||
with pytest.raises(httpx.ResponseClosed):
|
||||
response.read()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_aread_after_response_closed():
|
||||
is_closed = False
|
||||
|
||||
async def close_func():
|
||||
nonlocal is_closed
|
||||
is_closed = True
|
||||
|
||||
stream = AsyncIteratorStream(
|
||||
aiterator=async_streaming_body(), close_func=close_func
|
||||
)
|
||||
response = httpx.Response(
|
||||
200,
|
||||
stream=stream,
|
||||
content=async_streaming_body(),
|
||||
)
|
||||
|
||||
await response.aclose()
|
||||
assert is_closed
|
||||
|
||||
with pytest.raises(httpx.ResponseClosed):
|
||||
await response.aread()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_elapsed_not_available_until_closed():
|
||||
stream = AsyncIteratorStream(aiterator=async_streaming_body())
|
||||
response = httpx.Response(
|
||||
200,
|
||||
stream=stream,
|
||||
content=async_streaming_body(),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
|
||||
@ -3,7 +3,7 @@ import io
|
||||
import pytest
|
||||
|
||||
from httpx import StreamConsumed
|
||||
from httpx._content_streams import ContentStream, encode
|
||||
from httpx._content_streams import ContentStream, encode, encode_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -251,3 +251,72 @@ async def test_multipart_multiple_files_single_input_content():
|
||||
b"--+++--\r\n",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_empty_content():
|
||||
stream = encode_response()
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {}
|
||||
assert sync_content == b""
|
||||
assert async_content == b""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_bytes_content():
|
||||
stream = encode_response(content=b"Hello, world!")
|
||||
sync_content = b"".join([part for part in stream])
|
||||
async_content = b"".join([part async for part in stream])
|
||||
|
||||
assert stream.can_replay()
|
||||
assert stream.get_headers() == {"Content-Length": "13"}
|
||||
assert sync_content == b"Hello, world!"
|
||||
assert async_content == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_iterator_content():
|
||||
def hello_world():
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
|
||||
stream = encode_response(content=hello_world())
|
||||
content = b"".join([part for part in stream])
|
||||
|
||||
assert not stream.can_replay()
|
||||
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
[part async for part in stream]
|
||||
|
||||
with pytest.raises(StreamConsumed):
|
||||
[part for part in stream]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_aiterator_content():
|
||||
async def hello_world():
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
|
||||
stream = encode_response(content=hello_world())
|
||||
content = b"".join([part async for part in stream])
|
||||
|
||||
assert not stream.can_replay()
|
||||
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
[part for part in stream]
|
||||
|
||||
with pytest.raises(StreamConsumed):
|
||||
[part async for part in stream]
|
||||
|
||||
|
||||
def test_response_invalid_argument():
|
||||
with pytest.raises(TypeError):
|
||||
encode_response(123) # type: ignore
|
||||
|
||||
@ -4,7 +4,6 @@ import brotli
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
from httpx._content_streams import AsyncIteratorStream
|
||||
from httpx._decoders import (
|
||||
BrotliDecoder,
|
||||
DeflateDecoder,
|
||||
@ -130,11 +129,10 @@ async def test_streaming():
|
||||
yield compressor.flush()
|
||||
|
||||
headers = [(b"Content-Encoding", b"gzip")]
|
||||
stream = AsyncIteratorStream(aiterator=compress(body))
|
||||
response = httpx.Response(
|
||||
200,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
content=compress(body),
|
||||
)
|
||||
assert not hasattr(response, "body")
|
||||
assert await response.aread() == body
|
||||
@ -199,19 +197,17 @@ async def test_text_decoder(data, encoding):
|
||||
yield chunk
|
||||
|
||||
# Accessing `.text` on a read response.
|
||||
stream = AsyncIteratorStream(aiterator=iterator())
|
||||
response = httpx.Response(
|
||||
200,
|
||||
stream=stream,
|
||||
content=iterator(),
|
||||
)
|
||||
await response.aread()
|
||||
assert response.text == (b"".join(data)).decode(encoding)
|
||||
|
||||
# Streaming `.aiter_text` iteratively.
|
||||
stream = AsyncIteratorStream(aiterator=iterator())
|
||||
response = httpx.Response(
|
||||
200,
|
||||
stream=stream,
|
||||
content=iterator(),
|
||||
)
|
||||
text = "".join([part async for part in response.aiter_text()])
|
||||
assert text == (b"".join(data)).decode(encoding)
|
||||
@ -224,11 +220,10 @@ async def test_text_decoder_known_encoding():
|
||||
yield b"\x83"
|
||||
yield b"\x89\x83x\x83\x8b"
|
||||
|
||||
stream = AsyncIteratorStream(aiterator=iterator())
|
||||
response = httpx.Response(
|
||||
200,
|
||||
headers=[(b"Content-Type", b"text/html; charset=shift-jis")],
|
||||
stream=stream,
|
||||
content=iterator(),
|
||||
)
|
||||
|
||||
await response.aread()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user