Support for chunk_size (#1277)
* Support iter_raw(chunk_size=...) and aiter_raw(chunk_size=...)
* Unit tests for ByteChunker
* Support iter_bytes(chunk_size=...)
* Add TextChunker
* Support iter_text(chunk_size=...)
* Fix merge with master

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
This commit is contained in:
parent
c4d2e6fa28
commit
27df5e49c7
@ -4,6 +4,7 @@ Handlers for Content-Encoding.
|
||||
See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
|
||||
"""
|
||||
import codecs
|
||||
import io
|
||||
import typing
|
||||
import zlib
|
||||
|
||||
@ -155,6 +156,84 @@ class MultiDecoder(ContentDecoder):
|
||||
return data
|
||||
|
||||
|
||||
class ByteChunker:
    """
    Handles returning byte content in fixed-size chunks.

    When `chunk_size` is None the chunker is a pass-through: every non-empty
    `decode()` input is returned unchanged as a single chunk.

    Otherwise input is accumulated in an internal buffer, `decode()` returns
    only complete chunks of exactly `chunk_size` bytes, and `flush()` returns
    whatever shorter remainder is still buffered.
    """

    def __init__(self, chunk_size: typing.Optional[int] = None) -> None:
        """
        Create a chunker.

        Raises ValueError if `chunk_size` is given but is not positive.
        """
        # Fail early with a clear message rather than part-way through
        # decode(), where a zero or negative size breaks the slicing below.
        if chunk_size is not None and chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer or None")
        self._buffer = io.BytesIO()
        self._chunk_size = chunk_size

    def decode(self, content: bytes) -> typing.List[bytes]:
        """
        Add `content` and return the list of chunks that are now complete.

        The list may be empty; empty chunks are never returned.
        """
        if self._chunk_size is None:
            # Pass-through mode, but never hand back an empty chunk.
            return [content] if content else []

        self._buffer.write(content)
        if self._buffer.tell() < self._chunk_size:
            # Not enough data buffered yet for a single full chunk.
            return []

        value = self._buffer.getvalue()
        chunks = [
            value[i : i + self._chunk_size]
            for i in range(0, len(value), self._chunk_size)
        ]

        # Reset the buffer, then keep a trailing partial chunk (if any)
        # for the next decode() or flush() call.
        self._buffer.seek(0)
        self._buffer.truncate()
        if len(chunks[-1]) < self._chunk_size:
            self._buffer.write(chunks.pop())
        return chunks

    def flush(self) -> typing.List[bytes]:
        """
        Return any buffered remainder as a single chunk and clear the buffer.

        Returns an empty list when nothing is buffered.
        """
        value = self._buffer.getvalue()
        self._buffer.seek(0)
        self._buffer.truncate()
        return [value] if value else []
|
||||
|
||||
|
||||
class TextChunker:
    """
    Handles returning text content in fixed-size chunks.

    When `chunk_size` is None the chunker is a pass-through: every non-empty
    `decode()` input is returned unchanged as a single chunk.

    Otherwise input is accumulated in an internal buffer, `decode()` returns
    only complete chunks of exactly `chunk_size` characters, and `flush()`
    returns whatever shorter remainder is still buffered.
    """

    def __init__(self, chunk_size: typing.Optional[int] = None) -> None:
        """
        Create a chunker.

        Raises ValueError if `chunk_size` is given but is not positive.
        """
        # Fail early with a clear message rather than part-way through
        # decode(), where a zero or negative size breaks the slicing below.
        if chunk_size is not None and chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer or None")
        self._buffer = io.StringIO()
        self._chunk_size = chunk_size

    def decode(self, content: str) -> typing.List[str]:
        """
        Add `content` and return the list of chunks that are now complete.

        The list may be empty; empty chunks are never returned.
        """
        if self._chunk_size is None:
            # Pass-through mode, but never hand back an empty chunk.
            return [content] if content else []

        self._buffer.write(content)
        if self._buffer.tell() < self._chunk_size:
            # Not enough data buffered yet for a single full chunk.
            return []

        value = self._buffer.getvalue()
        chunks = [
            value[i : i + self._chunk_size]
            for i in range(0, len(value), self._chunk_size)
        ]

        # Reset the buffer, then keep a trailing partial chunk (if any)
        # for the next decode() or flush() call.
        self._buffer.seek(0)
        self._buffer.truncate()
        if len(chunks[-1]) < self._chunk_size:
            self._buffer.write(chunks.pop())
        return chunks

    def flush(self) -> typing.List[str]:
        """
        Return any buffered remainder as a single chunk and clear the buffer.

        Returns an empty list when nothing is buffered.
        """
        value = self._buffer.getvalue()
        self._buffer.seek(0)
        self._buffer.truncate()
        return [value] if value else []
|
||||
|
||||
|
||||
class TextDecoder:
|
||||
"""
|
||||
Handles incrementally decoding bytes into text
|
||||
|
||||
100
httpx/_models.py
100
httpx/_models.py
@ -15,10 +15,12 @@ import rfc3986.exceptions
|
||||
from ._content import PlainByteStream, encode_request, encode_response
|
||||
from ._decoders import (
|
||||
SUPPORTED_DECODERS,
|
||||
ByteChunker,
|
||||
ContentDecoder,
|
||||
IdentityDecoder,
|
||||
LineDecoder,
|
||||
MultiDecoder,
|
||||
TextChunker,
|
||||
TextDecoder,
|
||||
)
|
||||
from ._exceptions import (
|
||||
@ -1162,31 +1164,47 @@ class Response:
|
||||
self._content = b"".join(self.iter_bytes())
|
||||
return self._content
|
||||
|
||||
def iter_bytes(
    self, chunk_size: typing.Optional[int] = None
) -> typing.Iterator[bytes]:
    """
    A byte-iterator over the decoded response content.
    This allows us to handle gzip, deflate, and brotli encoded responses.

    If `chunk_size` is given, the content is yielded in pieces of that many
    bytes, with a shorter final piece when the total is not an exact multiple.
    """
    if hasattr(self, "_content"):
        # Content has already been loaded: slice it directly.
        if chunk_size is None:
            chunk_size = len(self._content)
        # An empty body with no explicit chunk_size gives a size of 0, and
        # range() rejects a zero step, so clamp the step to at least 1.
        step = max(chunk_size, 1)
        for i in range(0, len(self._content), step):
            yield self._content[i : i + step]
    else:
        decoder = self._get_content_decoder()
        chunker = ByteChunker(chunk_size=chunk_size)
        with self._wrap_decoder_errors():
            for raw_bytes in self.iter_raw():
                decoded = decoder.decode(raw_bytes)
                yield from chunker.decode(decoded)
            # Drain the content decoder first, then the chunker's remainder.
            decoded = decoder.flush()
            yield from chunker.decode(decoded)
            yield from chunker.flush()
|
||||
|
||||
def iter_text(
    self, chunk_size: typing.Optional[int] = None
) -> typing.Iterator[str]:
    """
    A str-iterator over the decoded response content
    that handles both gzip, deflate, etc but also detects the content's
    string encoding.

    Bytes from `iter_bytes()` are decoded to text with a `TextDecoder`
    built from `self.encoding`, then re-split by a `TextChunker` configured
    with `chunk_size`.
    """
    decoder = TextDecoder(encoding=self.encoding)
    chunker = TextChunker(chunk_size=chunk_size)
    with self._wrap_decoder_errors():
        for byte_content in self.iter_bytes():
            text_content = decoder.decode(byte_content)
            yield from chunker.decode(text_content)
        # Drain the text decoder first, then the chunker's remainder.
        text_content = decoder.flush()
        yield from chunker.decode(text_content)
        yield from chunker.flush()
|
||||
|
||||
def iter_lines(self) -> typing.Iterator[str]:
|
||||
decoder = LineDecoder()
|
||||
@ -1197,7 +1215,7 @@ class Response:
|
||||
for line in decoder.flush():
|
||||
yield line
|
||||
|
||||
def iter_raw(self) -> typing.Iterator[bytes]:
|
||||
def iter_raw(self, chunk_size: int = None) -> typing.Iterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the raw response content.
|
||||
"""
|
||||
@ -1210,10 +1228,17 @@ class Response:
|
||||
|
||||
self.is_stream_consumed = True
|
||||
self._num_bytes_downloaded = 0
|
||||
chunker = ByteChunker(chunk_size=chunk_size)
|
||||
|
||||
with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
|
||||
for part in self.stream:
|
||||
self._num_bytes_downloaded += len(part)
|
||||
yield part
|
||||
for raw_stream_bytes in self.stream:
|
||||
self._num_bytes_downloaded += len(raw_stream_bytes)
|
||||
for chunk in chunker.decode(raw_stream_bytes):
|
||||
yield chunk
|
||||
|
||||
for chunk in chunker.flush():
|
||||
yield chunk
|
||||
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
@ -1234,31 +1259,47 @@ class Response:
|
||||
self._content = b"".join([part async for part in self.aiter_bytes()])
|
||||
return self._content
|
||||
|
||||
async def aiter_bytes(
    self, chunk_size: typing.Optional[int] = None
) -> typing.AsyncIterator[bytes]:
    """
    A byte-iterator over the decoded response content.
    This allows us to handle gzip, deflate, and brotli encoded responses.

    If `chunk_size` is given, the content is yielded in pieces of that many
    bytes, with a shorter final piece when the total is not an exact multiple.
    """
    if hasattr(self, "_content"):
        # Content has already been loaded: slice it directly.
        if chunk_size is None:
            chunk_size = len(self._content)
        # An empty body with no explicit chunk_size gives a size of 0, and
        # range() rejects a zero step, so clamp the step to at least 1.
        step = max(chunk_size, 1)
        for i in range(0, len(self._content), step):
            yield self._content[i : i + step]
    else:
        decoder = self._get_content_decoder()
        chunker = ByteChunker(chunk_size=chunk_size)
        with self._wrap_decoder_errors():
            async for raw_bytes in self.aiter_raw():
                decoded = decoder.decode(raw_bytes)
                for chunk in chunker.decode(decoded):
                    yield chunk
            # Drain the content decoder first, then the chunker's remainder.
            decoded = decoder.flush()
            for chunk in chunker.decode(decoded):
                yield chunk
            for chunk in chunker.flush():
                yield chunk
|
||||
|
||||
async def aiter_text(
    self, chunk_size: typing.Optional[int] = None
) -> typing.AsyncIterator[str]:
    """
    A str-iterator over the decoded response content
    that handles both gzip, deflate, etc but also detects the content's
    string encoding.

    Bytes from `aiter_bytes()` are decoded to text with a `TextDecoder`
    built from `self.encoding`, then re-split by a `TextChunker` configured
    with `chunk_size`.
    """
    decoder = TextDecoder(encoding=self.encoding)
    chunker = TextChunker(chunk_size=chunk_size)
    with self._wrap_decoder_errors():
        async for byte_content in self.aiter_bytes():
            text_content = decoder.decode(byte_content)
            for chunk in chunker.decode(text_content):
                yield chunk
        # Drain the text decoder first, then the chunker's remainder.
        text_content = decoder.flush()
        for chunk in chunker.decode(text_content):
            yield chunk
        for chunk in chunker.flush():
            yield chunk
|
||||
|
||||
async def aiter_lines(self) -> typing.AsyncIterator[str]:
|
||||
decoder = LineDecoder()
|
||||
@ -1269,7 +1310,7 @@ class Response:
|
||||
for line in decoder.flush():
|
||||
yield line
|
||||
|
||||
async def aiter_raw(self) -> typing.AsyncIterator[bytes]:
|
||||
async def aiter_raw(self, chunk_size: int = None) -> typing.AsyncIterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the raw response content.
|
||||
"""
|
||||
@ -1282,10 +1323,17 @@ class Response:
|
||||
|
||||
self.is_stream_consumed = True
|
||||
self._num_bytes_downloaded = 0
|
||||
chunker = ByteChunker(chunk_size=chunk_size)
|
||||
|
||||
with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
|
||||
async for part in self.stream:
|
||||
self._num_bytes_downloaded += len(part)
|
||||
yield part
|
||||
async for raw_stream_bytes in self.stream:
|
||||
self._num_bytes_downloaded += len(raw_stream_bytes)
|
||||
for chunk in chunker.decode(raw_stream_bytes):
|
||||
yield chunk
|
||||
|
||||
for chunk in chunker.flush():
|
||||
yield chunk
|
||||
|
||||
await self.aclose()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
|
||||
@ -343,6 +343,23 @@ def test_iter_raw():
|
||||
assert raw == b"Hello, world!"
|
||||
|
||||
|
||||
def test_iter_raw_with_chunksize():
    # (requested chunk size, chunks expected from iter_raw)
    cases = [
        (5, [b"Hello", b", wor", b"ld!"]),
        (13, [b"Hello, world!"]),
        (20, [b"Hello, world!"]),
    ]
    for size, expected in cases:
        # A fresh response per case, since iterating consumes the stream.
        response = httpx.Response(200, content=streaming_body())
        assert list(response.iter_raw(chunk_size=size)) == expected
|
||||
|
||||
|
||||
def test_iter_raw_on_iterable():
|
||||
response = httpx.Response(
|
||||
200,
|
||||
@ -384,6 +401,24 @@ async def test_aiter_raw():
|
||||
assert raw == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_aiter_raw_with_chunksize():
    # (requested chunk size, chunks expected from aiter_raw)
    cases = [
        (5, [b"Hello", b", wor", b"ld!"]),
        (13, [b"Hello, world!"]),
        (20, [b"Hello, world!"]),
    ]
    for size, expected in cases:
        # A fresh response per case, since iterating consumes the stream.
        response = httpx.Response(200, content=async_streaming_body())
        collected = []
        async for part in response.aiter_raw(chunk_size=size):
            collected.append(part)
        assert collected == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiter_raw_on_sync():
|
||||
response = httpx.Response(
|
||||
@ -406,10 +441,7 @@ async def test_aiter_raw_increments_updates_counter():
|
||||
|
||||
|
||||
def test_iter_bytes():
|
||||
response = httpx.Response(
|
||||
200,
|
||||
content=b"Hello, world!",
|
||||
)
|
||||
response = httpx.Response(200, content=b"Hello, world!")
|
||||
|
||||
content = b""
|
||||
for part in response.iter_bytes():
|
||||
@ -417,6 +449,20 @@ def test_iter_bytes():
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
|
||||
def test_iter_bytes_with_chunk_size():
    # (requested chunk size, chunks expected from iter_bytes)
    cases = [
        (5, [b"Hello", b", wor", b"ld!"]),
        (13, [b"Hello, world!"]),
        (20, [b"Hello, world!"]),
    ]
    for size, expected in cases:
        # A fresh response per case, since iterating consumes the stream.
        response = httpx.Response(200, content=streaming_body())
        assert list(response.iter_bytes(chunk_size=size)) == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiter_bytes():
|
||||
response = httpx.Response(
|
||||
@ -430,6 +476,21 @@ async def test_aiter_bytes():
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_aiter_bytes_with_chunk_size():
    # (requested chunk size, chunks expected from aiter_bytes)
    cases = [
        (5, [b"Hello", b", wor", b"ld!"]),
        (13, [b"Hello, world!"]),
        (20, [b"Hello, world!"]),
    ]
    for size, expected in cases:
        # A fresh response per case, since iterating consumes the stream.
        response = httpx.Response(200, content=async_streaming_body())
        collected = []
        async for part in response.aiter_bytes(chunk_size=size):
            collected.append(part)
        assert collected == expected
|
||||
|
||||
|
||||
def test_iter_text():
|
||||
response = httpx.Response(
|
||||
200,
|
||||
@ -442,6 +503,20 @@ def test_iter_text():
|
||||
assert content == "Hello, world!"
|
||||
|
||||
|
||||
def test_iter_text_with_chunk_size():
    # (requested chunk size, chunks expected from iter_text)
    cases = [
        (5, ["Hello", ", wor", "ld!"]),
        (13, ["Hello, world!"]),
        (20, ["Hello, world!"]),
    ]
    for size, expected in cases:
        response = httpx.Response(200, content=b"Hello, world!")
        assert list(response.iter_text(chunk_size=size)) == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiter_text():
|
||||
response = httpx.Response(
|
||||
@ -455,6 +530,21 @@ async def test_aiter_text():
|
||||
assert content == "Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_aiter_text_with_chunk_size():
    # (requested chunk size, chunks expected from aiter_text)
    cases = [
        (5, ["Hello", ", wor", "ld!"]),
        (13, ["Hello, world!"]),
        (20, ["Hello, world!"]),
    ]
    for size, expected in cases:
        response = httpx.Response(200, content=b"Hello, world!")
        collected = []
        async for part in response.aiter_text(chunk_size=size):
            collected.append(part)
        assert collected == expected
|
||||
|
||||
|
||||
def test_iter_lines():
|
||||
response = httpx.Response(
|
||||
200,
|
||||
|
||||
@ -6,10 +6,12 @@ import pytest
|
||||
import httpx
|
||||
from httpx._decoders import (
|
||||
BrotliDecoder,
|
||||
ByteChunker,
|
||||
DeflateDecoder,
|
||||
GZipDecoder,
|
||||
IdentityDecoder,
|
||||
LineDecoder,
|
||||
TextChunker,
|
||||
TextDecoder,
|
||||
)
|
||||
|
||||
@ -300,6 +302,50 @@ def test_line_decoder_crnl():
|
||||
assert decoder.flush() == []
|
||||
|
||||
|
||||
def test_byte_chunker():
    # Each scenario: (chunk_size, [(decode input, expected chunks), ...],
    # expected flush() result after all inputs).
    scenarios = [
        # No chunk size: inputs pass straight through, nothing is buffered.
        (None, [(b"1234567", [b"1234567"]), (b"89", [b"89"])], []),
        # A partial tail is held back and completed by the next input.
        (3, [(b"1234567", [b"123", b"456"]), (b"89", [b"789"])], []),
        # Inputs that are exact multiples leave nothing buffered.
        (3, [(b"123456", [b"123", b"456"]), (b"789", [b"789"])], []),
        # A short final input is only released by flush().
        (3, [(b"123456", [b"123", b"456"]), (b"78", [])], [b"78"]),
    ]
    for size, steps, leftover in scenarios:
        chunker = ByteChunker(chunk_size=size)
        for content, expected in steps:
            assert chunker.decode(content) == expected
        assert chunker.flush() == leftover
|
||||
|
||||
|
||||
def test_text_chunker():
    # Each scenario: (chunk_size, [(decode input, expected chunks), ...],
    # expected flush() result after all inputs).
    scenarios = [
        # No chunk size: inputs pass straight through, nothing is buffered.
        (None, [("1234567", ["1234567"]), ("89", ["89"])], []),
        # A partial tail is held back and completed by the next input.
        (3, [("1234567", ["123", "456"]), ("89", ["789"])], []),
        # Inputs that are exact multiples leave nothing buffered.
        (3, [("123456", ["123", "456"]), ("789", ["789"])], []),
        # A short final input is only released by flush().
        (3, [("123456", ["123", "456"]), ("78", [])], ["78"]),
    ]
    for size, steps, leftover in scenarios:
        chunker = TextChunker(chunk_size=size)
        for content, expected in steps:
            assert chunker.decode(content) == expected
        assert chunker.flush() == leftover
|
||||
|
||||
|
||||
def test_invalid_content_encoding_header():
|
||||
headers = [(b"Content-Encoding", b"invalid-header")]
|
||||
body = b"test 123"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user