Compare commits

...

2 Commits

Author SHA1 Message Date
florimondmanca
8afd29a8ce
Use httpcore PR 2023-02-13 01:15:59 +01:00
florimondmanca
5a3603726a
Fix unclosed generator on trio 2023-02-13 00:25:23 +01:00
6 changed files with 77 additions and 30 deletions

View File

@ -142,9 +142,8 @@ class BoundAsyncStream(AsyncByteStream):
self._response = response
self._timer = timer
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async for chunk in self._stream:
yield chunk
def __aiter__(self) -> typing.AsyncIterator[bytes]:
return self._stream.__aiter__()
async def aclose(self) -> None:
seconds = await self._timer.async_elapsed()

View File

@ -4,6 +4,7 @@ import json as jsonlib
import typing
import urllib.request
from collections.abc import Mapping
from contextlib import aclosing
from http.cookiejar import Cookie, CookieJar
from ._content import ByteStream, UnattachedStream, encode_request, encode_response
@ -911,7 +912,7 @@ class Response:
async def aiter_bytes(
self, chunk_size: typing.Optional[int] = None
) -> typing.AsyncIterator[bytes]:
) -> typing.AsyncGenerator[bytes, None]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
@ -924,19 +925,20 @@ class Response:
decoder = self._get_content_decoder()
chunker = ByteChunker(chunk_size=chunk_size)
with request_context(request=self._request):
async for raw_bytes in self.aiter_raw():
decoded = decoder.decode(raw_bytes)
async with aclosing(self.aiter_raw()) as stream:
async for raw_bytes in stream:
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
yield chunk
decoded = decoder.flush()
for chunk in chunker.decode(decoded):
yield chunk # pragma: no cover
for chunk in chunker.flush():
yield chunk
decoded = decoder.flush()
for chunk in chunker.decode(decoded):
yield chunk # pragma: no cover
for chunk in chunker.flush():
yield chunk
async def aiter_text(
self, chunk_size: typing.Optional[int] = None
) -> typing.AsyncIterator[str]:
) -> typing.AsyncGenerator[str, None]:
"""
A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
@ -945,28 +947,30 @@ class Response:
decoder = TextDecoder(encoding=self.encoding or "utf-8")
chunker = TextChunker(chunk_size=chunk_size)
with request_context(request=self._request):
async for byte_content in self.aiter_bytes():
text_content = decoder.decode(byte_content)
async with aclosing(self.aiter_bytes()) as stream:
async for byte_content in stream:
text_content = decoder.decode(byte_content)
for chunk in chunker.decode(text_content):
yield chunk
text_content = decoder.flush()
for chunk in chunker.decode(text_content):
yield chunk
text_content = decoder.flush()
for chunk in chunker.decode(text_content):
yield chunk
for chunk in chunker.flush():
yield chunk
for chunk in chunker.flush():
yield chunk
async def aiter_lines(self) -> typing.AsyncIterator[str]:
async def aiter_lines(self) -> typing.AsyncGenerator[str, None]:
decoder = LineDecoder()
with request_context(request=self._request):
async for text in self.aiter_text():
for line in decoder.decode(text):
async with aclosing(self.aiter_text()) as stream:
async for text in stream:
for line in decoder.decode(text):
yield line
for line in decoder.flush():
yield line
for line in decoder.flush():
yield line
async def aiter_raw(
self, chunk_size: typing.Optional[int] = None
) -> typing.AsyncIterator[bytes]:
) -> typing.AsyncGenerator[bytes, None]:
"""
A byte-iterator over the raw response content.
"""

View File

@ -232,12 +232,14 @@ class HTTPTransport(BaseTransport):
class AsyncResponseStream(AsyncByteStream):
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]):
self._httpcore_stream = httpcore_stream
self._httpcore_stream = httpcore_stream.__aiter__()
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
def __aiter__(self) -> typing.AsyncIterator[bytes]:
return self
async def __anext__(self) -> bytes:
with map_httpcore_exceptions():
async for part in self._httpcore_stream:
yield part
return await self._httpcore_stream.__anext__()
async def aclose(self) -> None:
if hasattr(self._httpcore_stream, "aclose"):

View File

@ -29,7 +29,7 @@ classifiers = [
]
dependencies = [
"certifi",
"httpcore>=0.15.0,<0.17.0",
"httpcore==git+https://github.com/encode/httpcore.git@bug/async-early-stream-break",
"idna",
"sniffio",
]

View File

@ -1,4 +1,5 @@
import typing
from contextlib import aclosing
from datetime import timedelta
import pytest
@ -76,6 +77,34 @@ async def test_stream_response(server):
assert response.content == b"Hello, world!"
@pytest.mark.anyio
async def test_stream_iterator(server):
body = b""
async with httpx.AsyncClient() as client:
async with client.stream("GET", server.url) as response:
async for chunk in response.aiter_bytes():
body += chunk
assert response.status_code == 200
assert body == b"Hello, world!"
@pytest.mark.anyio
async def test_stream_iterator_partial(server):
body = ""
async with httpx.AsyncClient() as client:
async with client.stream("GET", server.url) as response:
async with aclosing(response.aiter_text(5)) as stream:
async for chunk in stream:
body += chunk
break
assert response.status_code == 200
assert body == "Hello"
@pytest.mark.anyio
async def test_access_content_stream_response(server):
async with httpx.AsyncClient() as client:

View File

@ -107,6 +107,19 @@ def test_stream_iterator(server):
assert body == b"Hello, world!"
def test_stream_iterator_partial(server):
body = ""
with httpx.Client() as client:
with client.stream("GET", server.url) as response:
for chunk in response.iter_text(5):
body += chunk
break
assert response.status_code == 200
assert body == "Hello"
def test_raw_iterator(server):
body = b""