Adds check to enforce single consumption of AsyncIteratorStream. (#697)
This commit is contained in:
parent
35b7516674
commit
de8b95533d
@ -7,6 +7,7 @@ from json import dumps as json_dumps
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from .exceptions import StreamConsumed
|
||||
from .utils import format_form_param
|
||||
|
||||
RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
|
||||
@ -81,6 +82,7 @@ class AsyncIteratorStream(ContentStream):
|
||||
) -> None:
|
||||
self.aiterator = aiterator
|
||||
self.close_func = close_func
|
||||
self.is_stream_consumed = False
|
||||
|
||||
def can_replay(self) -> bool:
|
||||
return False
|
||||
@ -89,6 +91,9 @@ class AsyncIteratorStream(ContentStream):
|
||||
return {"Transfer-Encoding": "chunked"}
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
if self.is_stream_consumed:
|
||||
raise StreamConsumed()
|
||||
self.is_stream_consumed = True
|
||||
async for part in self.aiterator:
|
||||
yield part
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ import io
|
||||
import pytest
|
||||
|
||||
from httpx.content_streams import encode
|
||||
from httpx.exceptions import StreamConsumed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -39,6 +40,21 @@ async def test_aiterator_content():
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiterator_is_stream_consumed():
|
||||
async def hello_world():
|
||||
yield b"Hello, "
|
||||
yield b"world!"
|
||||
|
||||
stream = encode(data=hello_world())
|
||||
b"".join([part async for part in stream])
|
||||
|
||||
assert stream.is_stream_consumed
|
||||
|
||||
with pytest.raises(StreamConsumed) as _:
|
||||
b"".join([part async for part in stream])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_content():
|
||||
stream = encode(json={"Hello": "world!"})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user