Adds check to enforce single consumption of AsyncIteratorStream. (#697)

This commit is contained in:
Gabriel Strauss 2019-12-31 07:01:43 -05:00 committed by Tom Christie
parent 35b7516674
commit de8b95533d
2 changed files with 21 additions and 0 deletions

View File

@ -7,6 +7,7 @@ from json import dumps as json_dumps
from pathlib import Path
from urllib.parse import urlencode
from .exceptions import StreamConsumed
from .utils import format_form_param
RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
@ -81,6 +82,7 @@ class AsyncIteratorStream(ContentStream):
) -> None:
self.aiterator = aiterator
self.close_func = close_func
self.is_stream_consumed = False
def can_replay(self) -> bool:
return False
@ -89,6 +91,9 @@ class AsyncIteratorStream(ContentStream):
return {"Transfer-Encoding": "chunked"}
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
if self.is_stream_consumed:
raise StreamConsumed()
self.is_stream_consumed = True
async for part in self.aiterator:
yield part

View File

@ -3,6 +3,7 @@ import io
import pytest
from httpx.content_streams import encode
from httpx.exceptions import StreamConsumed
@pytest.mark.asyncio
@ -39,6 +40,21 @@ async def test_aiterator_content():
assert content == b"Hello, world!"
@pytest.mark.asyncio
async def test_aiterator_is_stream_consumed():
async def hello_world():
yield b"Hello, "
yield b"world!"
stream = encode(data=hello_world())
b"".join([part async for part in stream])
assert stream.is_stream_consumed
with pytest.raises(StreamConsumed) as _:
b"".join([part async for part in stream])
@pytest.mark.asyncio
async def test_json_content():
stream = encode(json={"Hello": "world!"})