Support for chunk_size (#1277)

* Support iter_raw(chunk_size=...) and aiter_raw(chunk_size=...)

* Unit tests for ByteChunker

* Support iter_bytes(chunk_size=...)

* Add TextChunker

* Support iter_text(chunk_size=...)

* Fix merge with master

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
This commit is contained in:
Tom Christie 2020-11-25 15:28:06 +00:00 committed by GitHub
parent c4d2e6fa28
commit 27df5e49c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 293 additions and 30 deletions

View File

@ -4,6 +4,7 @@ Handlers for Content-Encoding.
See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
"""
import codecs
import io
import typing
import zlib
@ -155,6 +156,84 @@ class MultiDecoder(ContentDecoder):
return data
class ByteChunker:
"""
Handles returning byte content in fixed-size chunks.
"""
def __init__(self, chunk_size: int = None) -> None:
self._buffer = io.BytesIO()
self._chunk_size = chunk_size
def decode(self, content: bytes) -> typing.List[bytes]:
if self._chunk_size is None:
return [content]
self._buffer.write(content)
if self._buffer.tell() >= self._chunk_size:
value = self._buffer.getvalue()
chunks = [
value[i : i + self._chunk_size]
for i in range(0, len(value), self._chunk_size)
]
if len(chunks[-1]) == self._chunk_size:
self._buffer.seek(0)
self._buffer.truncate()
return chunks
else:
self._buffer.seek(0)
self._buffer.write(chunks[-1])
self._buffer.truncate()
return chunks[:-1]
else:
return []
def flush(self) -> typing.List[bytes]:
value = self._buffer.getvalue()
self._buffer.seek(0)
self._buffer.truncate()
return [value] if value else []
class TextChunker:
"""
Handles returning text content in fixed-size chunks.
"""
def __init__(self, chunk_size: int = None) -> None:
self._buffer = io.StringIO()
self._chunk_size = chunk_size
def decode(self, content: str) -> typing.List[str]:
if self._chunk_size is None:
return [content]
self._buffer.write(content)
if self._buffer.tell() >= self._chunk_size:
value = self._buffer.getvalue()
chunks = [
value[i : i + self._chunk_size]
for i in range(0, len(value), self._chunk_size)
]
if len(chunks[-1]) == self._chunk_size:
self._buffer.seek(0)
self._buffer.truncate()
return chunks
else:
self._buffer.seek(0)
self._buffer.write(chunks[-1])
self._buffer.truncate()
return chunks[:-1]
else:
return []
def flush(self) -> typing.List[str]:
value = self._buffer.getvalue()
self._buffer.seek(0)
self._buffer.truncate()
return [value] if value else []
class TextDecoder:
"""
Handles incrementally decoding bytes into text

View File

@ -15,10 +15,12 @@ import rfc3986.exceptions
from ._content import PlainByteStream, encode_request, encode_response
from ._decoders import (
SUPPORTED_DECODERS,
ByteChunker,
ContentDecoder,
IdentityDecoder,
LineDecoder,
MultiDecoder,
TextChunker,
TextDecoder,
)
from ._exceptions import (
@ -1162,31 +1164,47 @@ class Response:
self._content = b"".join(self.iter_bytes())
return self._content
def iter_bytes(self) -> typing.Iterator[bytes]:
def iter_bytes(self, chunk_size: int = None) -> typing.Iterator[bytes]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
"""
if hasattr(self, "_content"):
yield self._content
chunk_size = len(self._content) if chunk_size is None else chunk_size
for i in range(0, len(self._content), chunk_size):
yield self._content[i : i + chunk_size]
else:
decoder = self._get_content_decoder()
chunker = ByteChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
for chunk in self.iter_raw():
yield decoder.decode(chunk)
yield decoder.flush()
for raw_bytes in self.iter_raw():
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
yield chunk
decoded = decoder.flush()
for chunk in chunker.decode(decoded):
yield chunk
for chunk in chunker.flush():
yield chunk
def iter_text(self) -> typing.Iterator[str]:
def iter_text(self, chunk_size: int = None) -> typing.Iterator[str]:
"""
A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
decoder = TextDecoder(encoding=self.encoding)
chunker = TextChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
for chunk in self.iter_bytes():
yield decoder.decode(chunk)
yield decoder.flush()
for byte_content in self.iter_bytes():
text_content = decoder.decode(byte_content)
for chunk in chunker.decode(text_content):
yield chunk
text_content = decoder.flush()
for chunk in chunker.decode(text_content):
yield chunk
for chunk in chunker.flush():
yield chunk
def iter_lines(self) -> typing.Iterator[str]:
decoder = LineDecoder()
@ -1197,7 +1215,7 @@ class Response:
for line in decoder.flush():
yield line
def iter_raw(self) -> typing.Iterator[bytes]:
def iter_raw(self, chunk_size: int = None) -> typing.Iterator[bytes]:
"""
A byte-iterator over the raw response content.
"""
@ -1210,10 +1228,17 @@ class Response:
self.is_stream_consumed = True
self._num_bytes_downloaded = 0
chunker = ByteChunker(chunk_size=chunk_size)
with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
for part in self.stream:
self._num_bytes_downloaded += len(part)
yield part
for raw_stream_bytes in self.stream:
self._num_bytes_downloaded += len(raw_stream_bytes)
for chunk in chunker.decode(raw_stream_bytes):
yield chunk
for chunk in chunker.flush():
yield chunk
self.close()
def close(self) -> None:
@ -1234,31 +1259,47 @@ class Response:
self._content = b"".join([part async for part in self.aiter_bytes()])
return self._content
async def aiter_bytes(self) -> typing.AsyncIterator[bytes]:
async def aiter_bytes(self, chunk_size: int = None) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
"""
if hasattr(self, "_content"):
yield self._content
chunk_size = len(self._content) if chunk_size is None else chunk_size
for i in range(0, len(self._content), chunk_size):
yield self._content[i : i + chunk_size]
else:
decoder = self._get_content_decoder()
chunker = ByteChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
async for chunk in self.aiter_raw():
yield decoder.decode(chunk)
yield decoder.flush()
async for raw_bytes in self.aiter_raw():
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
yield chunk
decoded = decoder.flush()
for chunk in chunker.decode(decoded):
yield chunk
for chunk in chunker.flush():
yield chunk
async def aiter_text(self) -> typing.AsyncIterator[str]:
async def aiter_text(self, chunk_size: int = None) -> typing.AsyncIterator[str]:
"""
A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
decoder = TextDecoder(encoding=self.encoding)
chunker = TextChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
async for chunk in self.aiter_bytes():
yield decoder.decode(chunk)
yield decoder.flush()
async for byte_content in self.aiter_bytes():
text_content = decoder.decode(byte_content)
for chunk in chunker.decode(text_content):
yield chunk
text_content = decoder.flush()
for chunk in chunker.decode(text_content):
yield chunk
for chunk in chunker.flush():
yield chunk
async def aiter_lines(self) -> typing.AsyncIterator[str]:
decoder = LineDecoder()
@ -1269,7 +1310,7 @@ class Response:
for line in decoder.flush():
yield line
async def aiter_raw(self) -> typing.AsyncIterator[bytes]:
async def aiter_raw(self, chunk_size: int = None) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the raw response content.
"""
@ -1282,10 +1323,17 @@ class Response:
self.is_stream_consumed = True
self._num_bytes_downloaded = 0
chunker = ByteChunker(chunk_size=chunk_size)
with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
async for part in self.stream:
self._num_bytes_downloaded += len(part)
yield part
async for raw_stream_bytes in self.stream:
self._num_bytes_downloaded += len(raw_stream_bytes)
for chunk in chunker.decode(raw_stream_bytes):
yield chunk
for chunk in chunker.flush():
yield chunk
await self.aclose()
async def aclose(self) -> None:

View File

@ -343,6 +343,23 @@ def test_iter_raw():
assert raw == b"Hello, world!"
def test_iter_raw_with_chunksize():
response = httpx.Response(200, content=streaming_body())
parts = [part for part in response.iter_raw(chunk_size=5)]
assert parts == [b"Hello", b", wor", b"ld!"]
response = httpx.Response(200, content=streaming_body())
parts = [part for part in response.iter_raw(chunk_size=13)]
assert parts == [b"Hello, world!"]
response = httpx.Response(200, content=streaming_body())
parts = [part for part in response.iter_raw(chunk_size=20)]
assert parts == [b"Hello, world!"]
def test_iter_raw_on_iterable():
response = httpx.Response(
200,
@ -384,6 +401,24 @@ async def test_aiter_raw():
assert raw == b"Hello, world!"
@pytest.mark.asyncio
async def test_aiter_raw_with_chunksize():
response = httpx.Response(200, content=async_streaming_body())
parts = [part async for part in response.aiter_raw(chunk_size=5)]
assert parts == [b"Hello", b", wor", b"ld!"]
response = httpx.Response(200, content=async_streaming_body())
parts = [part async for part in response.aiter_raw(chunk_size=13)]
assert parts == [b"Hello, world!"]
response = httpx.Response(200, content=async_streaming_body())
parts = [part async for part in response.aiter_raw(chunk_size=20)]
assert parts == [b"Hello, world!"]
@pytest.mark.asyncio
async def test_aiter_raw_on_sync():
response = httpx.Response(
@ -406,10 +441,7 @@ async def test_aiter_raw_increments_updates_counter():
def test_iter_bytes():
response = httpx.Response(
200,
content=b"Hello, world!",
)
response = httpx.Response(200, content=b"Hello, world!")
content = b""
for part in response.iter_bytes():
@ -417,6 +449,20 @@ def test_iter_bytes():
assert content == b"Hello, world!"
def test_iter_bytes_with_chunk_size():
response = httpx.Response(200, content=streaming_body())
parts = [part for part in response.iter_bytes(chunk_size=5)]
assert parts == [b"Hello", b", wor", b"ld!"]
response = httpx.Response(200, content=streaming_body())
parts = [part for part in response.iter_bytes(chunk_size=13)]
assert parts == [b"Hello, world!"]
response = httpx.Response(200, content=streaming_body())
parts = [part for part in response.iter_bytes(chunk_size=20)]
assert parts == [b"Hello, world!"]
@pytest.mark.asyncio
async def test_aiter_bytes():
response = httpx.Response(
@ -430,6 +476,21 @@ async def test_aiter_bytes():
assert content == b"Hello, world!"
@pytest.mark.asyncio
async def test_aiter_bytes_with_chunk_size():
response = httpx.Response(200, content=async_streaming_body())
parts = [part async for part in response.aiter_bytes(chunk_size=5)]
assert parts == [b"Hello", b", wor", b"ld!"]
response = httpx.Response(200, content=async_streaming_body())
parts = [part async for part in response.aiter_bytes(chunk_size=13)]
assert parts == [b"Hello, world!"]
response = httpx.Response(200, content=async_streaming_body())
parts = [part async for part in response.aiter_bytes(chunk_size=20)]
assert parts == [b"Hello, world!"]
def test_iter_text():
response = httpx.Response(
200,
@ -442,6 +503,20 @@ def test_iter_text():
assert content == "Hello, world!"
def test_iter_text_with_chunk_size():
response = httpx.Response(200, content=b"Hello, world!")
parts = [part for part in response.iter_text(chunk_size=5)]
assert parts == ["Hello", ", wor", "ld!"]
response = httpx.Response(200, content=b"Hello, world!")
parts = [part for part in response.iter_text(chunk_size=13)]
assert parts == ["Hello, world!"]
response = httpx.Response(200, content=b"Hello, world!")
parts = [part for part in response.iter_text(chunk_size=20)]
assert parts == ["Hello, world!"]
@pytest.mark.asyncio
async def test_aiter_text():
response = httpx.Response(
@ -455,6 +530,21 @@ async def test_aiter_text():
assert content == "Hello, world!"
@pytest.mark.asyncio
async def test_aiter_text_with_chunk_size():
response = httpx.Response(200, content=b"Hello, world!")
parts = [part async for part in response.aiter_text(chunk_size=5)]
assert parts == ["Hello", ", wor", "ld!"]
response = httpx.Response(200, content=b"Hello, world!")
parts = [part async for part in response.aiter_text(chunk_size=13)]
assert parts == ["Hello, world!"]
response = httpx.Response(200, content=b"Hello, world!")
parts = [part async for part in response.aiter_text(chunk_size=20)]
assert parts == ["Hello, world!"]
def test_iter_lines():
response = httpx.Response(
200,

View File

@ -6,10 +6,12 @@ import pytest
import httpx
from httpx._decoders import (
BrotliDecoder,
ByteChunker,
DeflateDecoder,
GZipDecoder,
IdentityDecoder,
LineDecoder,
TextChunker,
TextDecoder,
)
@ -300,6 +302,50 @@ def test_line_decoder_crnl():
assert decoder.flush() == []
def test_byte_chunker():
decoder = ByteChunker()
assert decoder.decode(b"1234567") == [b"1234567"]
assert decoder.decode(b"89") == [b"89"]
assert decoder.flush() == []
decoder = ByteChunker(chunk_size=3)
assert decoder.decode(b"1234567") == [b"123", b"456"]
assert decoder.decode(b"89") == [b"789"]
assert decoder.flush() == []
decoder = ByteChunker(chunk_size=3)
assert decoder.decode(b"123456") == [b"123", b"456"]
assert decoder.decode(b"789") == [b"789"]
assert decoder.flush() == []
decoder = ByteChunker(chunk_size=3)
assert decoder.decode(b"123456") == [b"123", b"456"]
assert decoder.decode(b"78") == []
assert decoder.flush() == [b"78"]
def test_text_chunker():
decoder = TextChunker()
assert decoder.decode("1234567") == ["1234567"]
assert decoder.decode("89") == ["89"]
assert decoder.flush() == []
decoder = TextChunker(chunk_size=3)
assert decoder.decode("1234567") == ["123", "456"]
assert decoder.decode("89") == ["789"]
assert decoder.flush() == []
decoder = TextChunker(chunk_size=3)
assert decoder.decode("123456") == ["123", "456"]
assert decoder.decode("789") == ["789"]
assert decoder.flush() == []
decoder = TextChunker(chunk_size=3)
assert decoder.decode("123456") == ["123", "456"]
assert decoder.decode("78") == []
assert decoder.flush() == ["78"]
def test_invalid_content_encoding_header():
headers = [(b"Content-Encoding", b"invalid-header")]
body = b"test 123"