Sync streaming interface on responses (#695)

* Sync streaming interface on responses

* Fix test case

* Test coverage for sync response APIs

* Address review comments
This commit is contained in:
Tom Christie 2020-01-02 12:56:11 +00:00 committed by GitHub
parent b0bf2a7513
commit 11e7604d1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 356 additions and 84 deletions

View File

@ -10,7 +10,9 @@ from urllib.parse import urlencode
from .exceptions import StreamConsumed
from .utils import format_form_param
RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
RequestData = typing.Union[
dict, str, bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]
]
RequestFiles = typing.Dict[
str,
@ -47,6 +49,12 @@ class ContentStream:
"""
return True
def __iter__(self) -> typing.Iterator[bytes]:
yield b""
def close(self) -> None:
pass
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
yield b""
@ -68,10 +76,46 @@ class ByteStream(ContentStream):
content_length = str(len(self.body))
return {"Content-Length": content_length}
def __iter__(self) -> typing.Iterator[bytes]:
yield self.body
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
yield self.body
class IteratorStream(ContentStream):
"""
Request content encoded as plain bytes, using an byte iterator.
"""
def __init__(
self, iterator: typing.Iterator[bytes], close_func: typing.Callable = None
) -> None:
self.iterator = iterator
self.close_func = close_func
self.is_stream_consumed = False
def can_replay(self) -> bool:
return False
def get_headers(self) -> typing.Dict[str, str]:
return {"Transfer-Encoding": "chunked"}
def __iter__(self) -> typing.Iterator[bytes]:
if self.is_stream_consumed:
raise StreamConsumed()
self.is_stream_consumed = True
for part in self.iterator:
yield part
def __aiter__(self) -> typing.AsyncIterator[bytes]:
raise RuntimeError("Attempted to call a async iterator on an sync stream.")
def close(self) -> None:
if self.close_func is not None:
self.close_func()
class AsyncIteratorStream(ContentStream):
"""
Request content encoded as plain bytes, using an async byte iterator.
@ -90,6 +134,9 @@ class AsyncIteratorStream(ContentStream):
def get_headers(self) -> typing.Dict[str, str]:
return {"Transfer-Encoding": "chunked"}
def __iter__(self) -> typing.Iterator[bytes]:
raise RuntimeError("Attempted to call a sync iterator on an async stream.")
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
if self.is_stream_consumed:
raise StreamConsumed()
@ -115,6 +162,9 @@ class JSONStream(ContentStream):
content_type = "application/json"
return {"Content-Length": content_length, "Content-Type": content_type}
def __iter__(self) -> typing.Iterator[bytes]:
yield self.body
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
yield self.body
@ -132,6 +182,9 @@ class URLEncodedStream(ContentStream):
content_type = "application/x-www-form-urlencoded"
return {"Content-Length": content_length, "Content-Type": content_type}
def __iter__(self) -> typing.Iterator[bytes]:
yield self.body
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
yield self.body
@ -252,6 +305,9 @@ class MultipartStream(ContentStream):
content_type = self.content_type
return {"Content-Length": content_length, "Content-Type": content_type}
def __iter__(self) -> typing.Iterator[bytes]:
yield self.body
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
yield self.body
@ -280,5 +336,11 @@ def encode(
return URLEncodedStream(data=data)
elif isinstance(data, (str, bytes)):
return ByteStream(body=data)
else:
elif hasattr(data, "__aiter__"):
data = typing.cast(typing.AsyncIterator[bytes], data)
return AsyncIteratorStream(aiterator=data)
elif hasattr(data, "__iter__"):
data = typing.cast(typing.Iterator[bytes], data)
return IteratorStream(iterator=data)
raise TypeError(f"Unexpected type for 'data', {type(data)!r}")

View File

@ -13,7 +13,13 @@ import chardet
import rfc3986
from .config import USER_AGENT
from .content_streams import ContentStream, RequestData, RequestFiles, encode
from .content_streams import (
ByteStream,
ContentStream,
RequestData,
RequestFiles,
encode,
)
from .decoders import (
ACCEPT_ENCODING,
SUPPORTED_DECODERS,
@ -665,15 +671,13 @@ class Response:
self.history = [] if history is None else list(history)
if stream is None:
self.is_closed = True
self.is_stream_consumed = True
self._raw_content = content or b""
self._elapsed = request.timer.elapsed
else:
self.is_closed = False
self.is_stream_consumed = False
self.is_closed = False
self.is_stream_consumed = False
if stream is not None:
self._raw_stream = stream
else:
self._raw_stream = ByteStream(body=content or b"")
self.read()
@property
def elapsed(self) -> datetime.timedelta:
@ -702,13 +706,7 @@ class Response:
@property
def content(self) -> bytes:
if not hasattr(self, "_content"):
if hasattr(self, "_raw_content"):
raw_content = self._raw_content # type: ignore
content = self.decoder.decode(raw_content)
content += self.decoder.flush()
self._content = content
else:
raise ResponseNotRead()
raise ResponseNotRead()
return self._content
@property
@ -850,14 +848,6 @@ class Response:
def __repr__(self) -> str:
return f"<Response [{self.status_code} {self.reason_phrase}]>"
async def aread(self) -> bytes:
"""
Read and return the response content.
"""
if not hasattr(self, "_content"):
self._content = b"".join([part async for part in self.aiter_bytes()])
return self._content
@property
def stream(self): # type: ignore
warnings.warn( # pragma: nocover
@ -874,6 +864,78 @@ class Response:
)
return self.aiter_raw # pragma: nocover
def read(self) -> bytes:
"""
Read and return the response content.
"""
if not hasattr(self, "_content"):
self._content = b"".join([part for part in self.iter_bytes()])
return self._content
def iter_bytes(self) -> typing.Iterator[bytes]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
"""
if hasattr(self, "_content"):
yield self._content
else:
for chunk in self.iter_raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()
def iter_text(self) -> typing.Iterator[str]:
"""
A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
decoder = TextDecoder(encoding=self.charset_encoding)
for chunk in self.iter_bytes():
yield decoder.decode(chunk)
yield decoder.flush()
def iter_lines(self) -> typing.Iterator[str]:
decoder = LineDecoder()
for text in self.iter_text():
for line in decoder.decode(text):
yield line
for line in decoder.flush():
yield line
def iter_raw(self) -> typing.Iterator[bytes]:
"""
A byte-iterator over the raw response content.
"""
if self.is_stream_consumed:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
self.is_stream_consumed = True
for part in self._raw_stream:
yield part
self.close()
def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
if not self.is_closed:
self.is_closed = True
self._elapsed = self.request.timer.elapsed
if hasattr(self, "_raw_stream"):
self._raw_stream.close()
async def aread(self) -> bytes:
"""
Read and return the response content.
"""
if not hasattr(self, "_content"):
self._content = b"".join([part async for part in self.aiter_bytes()])
return self._content
async def aiter_bytes(self) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the decoded response content.
@ -909,18 +971,15 @@ class Response:
"""
A byte-iterator over the raw response content.
"""
if hasattr(self, "_raw_content"):
yield self._raw_content
else:
if self.is_stream_consumed:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
if self.is_stream_consumed:
raise StreamConsumed()
if self.is_closed:
raise ResponseClosed()
self.is_stream_consumed = True
async for part in self._raw_stream:
yield part
await self.aclose()
self.is_stream_consumed = True
async for part in self._raw_stream:
yield part
await self.aclose()
async def anext(self) -> "Response":
"""

View File

@ -5,7 +5,7 @@ from unittest import mock
import pytest
import httpx
from httpx.content_streams import AsyncIteratorStream
from httpx.content_streams import AsyncIteratorStream, IteratorStream
REQUEST = httpx.Request("GET", "https://example.org")
@ -124,8 +124,23 @@ def test_response_force_encoding():
assert response.encoding == "iso-8859-1"
def test_read():
response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
assert response.status_code == 200
assert response.text == "Hello, world!"
assert response.encoding == "ascii"
assert response.is_closed
content = response.read()
assert content == b"Hello, world!"
assert response.content == b"Hello, world!"
assert response.is_closed
@pytest.mark.asyncio
async def test_read_response():
async def test_aread():
response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
assert response.status_code == 200
@ -140,9 +155,20 @@ async def test_read_response():
assert response.is_closed
def test_iter_raw():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(200, stream=stream, request=REQUEST)
raw = b""
for part in response.iter_raw():
raw += part
assert raw == b"Hello, world!"
@pytest.mark.asyncio
async def test_raw_interface():
response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
async def test_aiter_raw():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(200, stream=stream, request=REQUEST)
raw = b""
async for part in response.aiter_raw():
@ -150,8 +176,17 @@ async def test_raw_interface():
assert raw == b"Hello, world!"
def test_iter_bytes():
response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
content = b""
for part in response.iter_bytes():
content += part
assert content == b"Hello, world!"
@pytest.mark.asyncio
async def test_bytes_interface():
async def test_aiter_bytes():
response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
content = b""
@ -160,11 +195,18 @@ async def test_bytes_interface():
assert content == b"Hello, world!"
@pytest.mark.asyncio
async def test_text_interface():
def test_iter_text():
response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
await response.aread()
content = ""
for part in response.iter_text():
content += part
assert content == "Hello, world!"
@pytest.mark.asyncio
async def test_aiter_text():
response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
content = ""
async for part in response.aiter_text():
@ -172,11 +214,18 @@ async def test_text_interface():
assert content == "Hello, world!"
@pytest.mark.asyncio
async def test_lines_interface():
def test_iter_lines():
response = httpx.Response(200, content=b"Hello,\nworld!", request=REQUEST)
await response.aread()
content = []
for line in response.iter_lines():
content.append(line)
assert content == ["Hello,\n", "world!"]
@pytest.mark.asyncio
async def test_aiter_lines():
response = httpx.Response(200, content=b"Hello,\nworld!", request=REQUEST)
content = []
async for line in response.aiter_lines():
@ -184,20 +233,22 @@ async def test_lines_interface():
assert content == ["Hello,\n", "world!"]
@pytest.mark.asyncio
async def test_stream_interface_after_read():
response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
def test_sync_streaming_response():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(200, stream=stream, request=REQUEST)
await response.aread()
assert response.status_code == 200
assert not response.is_closed
content = response.read()
content = b""
async for part in response.aiter_bytes():
content += part
assert content == b"Hello, world!"
assert response.content == b"Hello, world!"
assert response.is_closed
@pytest.mark.asyncio
async def test_streaming_response():
async def test_async_streaming_response():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(200, stream=stream, request=REQUEST)
@ -211,8 +262,20 @@ async def test_streaming_response():
assert response.is_closed
def test_cannot_read_after_stream_consumed():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(200, stream=stream, request=REQUEST)
content = b""
for part in response.iter_bytes():
content += part
with pytest.raises(httpx.StreamConsumed):
response.read()
@pytest.mark.asyncio
async def test_cannot_read_after_stream_consumed():
async def test_cannot_aread_after_stream_consumed():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(200, stream=stream, request=REQUEST)
@ -224,12 +287,38 @@ async def test_cannot_read_after_stream_consumed():
await response.aread()
def test_cannot_read_after_response_closed():
is_closed = False
def close_func():
nonlocal is_closed
is_closed = True
stream = IteratorStream(iterator=streaming_body(), close_func=close_func)
response = httpx.Response(200, stream=stream, request=REQUEST)
response.close()
assert is_closed
with pytest.raises(httpx.ResponseClosed):
response.read()
@pytest.mark.asyncio
async def test_cannot_read_after_response_closed():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
async def test_cannot_aread_after_response_closed():
is_closed = False
async def close_func():
nonlocal is_closed
is_closed = True
stream = AsyncIteratorStream(
aiterator=async_streaming_body(), close_func=close_func
)
response = httpx.Response(200, stream=stream, request=REQUEST)
await response.aclose()
assert is_closed
with pytest.raises(httpx.ResponseClosed):
await response.aread()

View File

@ -2,29 +2,65 @@ import io
import pytest
from httpx.content_streams import encode
from httpx.content_streams import ContentStream, encode
from httpx.exceptions import StreamConsumed
@pytest.mark.asyncio
async def test_base_content():
stream = ContentStream()
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {}
assert sync_content == b""
assert async_content == b""
@pytest.mark.asyncio
async def test_empty_content():
stream = encode()
content = b"".join([part async for part in stream])
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {}
assert content == b""
assert sync_content == b""
assert async_content == b""
@pytest.mark.asyncio
async def test_bytes_content():
stream = encode(data=b"Hello, world!")
content = b"".join([part async for part in stream])
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {"Content-Length": "13"}
assert sync_content == b"Hello, world!"
assert async_content == b"Hello, world!"
@pytest.mark.asyncio
async def test_iterator_content():
def hello_world():
yield b"Hello, "
yield b"world!"
stream = encode(data=hello_world())
content = b"".join([part for part in stream])
assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"
with pytest.raises(RuntimeError):
[part async for part in stream]
with pytest.raises(StreamConsumed):
[part for part in stream]
@pytest.mark.asyncio
async def test_aiterator_content():
@ -39,60 +75,66 @@ async def test_aiterator_content():
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"
with pytest.raises(RuntimeError):
[part for part in stream]
@pytest.mark.asyncio
async def test_aiterator_is_stream_consumed():
async def hello_world():
yield b"Hello, "
yield b"world!"
stream = encode(data=hello_world())
b"".join([part async for part in stream])
assert stream.is_stream_consumed
with pytest.raises(StreamConsumed) as _:
b"".join([part async for part in stream])
with pytest.raises(StreamConsumed):
[part async for part in stream]
@pytest.mark.asyncio
async def test_json_content():
stream = encode(json={"Hello": "world!"})
content = b"".join([part async for part in stream])
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {
"Content-Length": "19",
"Content-Type": "application/json",
}
assert content == b'{"Hello": "world!"}'
assert sync_content == b'{"Hello": "world!"}'
assert async_content == b'{"Hello": "world!"}'
@pytest.mark.asyncio
async def test_urlencoded_content():
stream = encode(data={"Hello": "world!"})
content = b"".join([part async for part in stream])
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {
"Content-Length": "14",
"Content-Type": "application/x-www-form-urlencoded",
}
assert content == b"Hello=world%21"
assert sync_content == b"Hello=world%21"
assert async_content == b"Hello=world%21"
@pytest.mark.asyncio
async def test_multipart_files_content():
files = {"file": io.BytesIO(b"<file content>")}
stream = encode(files=files, boundary=b"+++")
content = b"".join([part async for part in stream])
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {
"Content-Length": "138",
"Content-Type": "multipart/form-data; boundary=+++",
}
assert content == b"".join(
assert sync_content == b"".join(
[
b"--+++\r\n",
b'Content-Disposition: form-data; name="file"; filename="upload"\r\n',
b"Content-Type: application/octet-stream\r\n",
b"\r\n",
b"<file content>\r\n",
b"--+++--\r\n",
]
)
assert async_content == b"".join(
[
b"--+++\r\n",
b'Content-Disposition: form-data; name="file"; filename="upload"\r\n',
@ -109,14 +151,15 @@ async def test_multipart_data_and_files_content():
data = {"message": "Hello, world!"}
files = {"file": io.BytesIO(b"<file content>")}
stream = encode(data=data, files=files, boundary=b"+++")
content = b"".join([part async for part in stream])
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])
assert stream.can_replay()
assert stream.get_headers() == {
"Content-Length": "210",
"Content-Type": "multipart/form-data; boundary=+++",
}
assert content == b"".join(
assert sync_content == b"".join(
[
b"--+++\r\n",
b'Content-Disposition: form-data; name="message"\r\n',
@ -130,3 +173,22 @@ async def test_multipart_data_and_files_content():
b"--+++--\r\n",
]
)
assert async_content == b"".join(
[
b"--+++\r\n",
b'Content-Disposition: form-data; name="message"\r\n',
b"\r\n",
b"Hello, world!\r\n",
b"--+++\r\n",
b'Content-Disposition: form-data; name="file"; filename="upload"\r\n',
b"Content-Type: application/octet-stream\r\n",
b"\r\n",
b"<file content>\r\n",
b"--+++--\r\n",
]
)
def test_invalid_argument():
with pytest.raises(TypeError):
encode(123)