Implement Response.stream_text() (#183)
This commit is contained in:
parent
9536b76f5c
commit
af907e85e1
@ -3,9 +3,12 @@ Handlers for Content-Encoding.
|
||||
|
||||
See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
|
||||
"""
|
||||
import codecs
|
||||
import typing
|
||||
import zlib
|
||||
|
||||
import chardet
|
||||
|
||||
from .exceptions import DecodingError
|
||||
|
||||
try:
|
||||
@ -138,6 +141,70 @@ class MultiDecoder(Decoder):
|
||||
return data
|
||||
|
||||
|
||||
class TextDecoder:
    """
    Incrementally decodes a stream of bytes into text.

    When an encoding is supplied, every chunk is passed straight to the
    matching incremental codec. Without one, incoming bytes are buffered
    and fed to chardet until an encoding can be chosen.
    """

    def __init__(self, encoding: typing.Optional[str] = None):
        self.decoder: typing.Optional[codecs.IncrementalDecoder] = None
        if encoding is not None:
            self.decoder = codecs.getincrementaldecoder(encoding)()
        self.detector = chardet.universaldetector.UniversalDetector()

        # The buffer is only used while no decoder is known. It is left as
        # 'None' otherwise so that any accidental attempt to append to it
        # once a decoder exists fails loudly instead of silently.
        self.buffer: typing.Optional[bytearray] = (
            None if self.decoder else bytearray()
        )

    def decode(self, data: bytes) -> str:
        """
        Decode one chunk, returning whatever text is available so far.

        Returns an empty string while data is still being buffered for
        encoding detection. Raises DecodingError if the bytes are not valid
        in the selected encoding.
        """
        try:
            if self.decoder is not None:
                return self.decoder.decode(data)

            assert self.buffer is not None
            self.detector.feed(data)
            self.buffer += data

            # 4096 bytes should be more than enough to process. We don't
            # want to buffer for too long, since chardet waits until
            # detector.close() is called before giving back common
            # encodings like 'utf-8'.
            if len(self.buffer) < 4096:
                return ""

            make_decoder = codecs.getincrementaldecoder(self._detector_result())
            self.decoder = make_decoder()
            decoded = self.decoder.decode(bytes(self.buffer), False)
            self.buffer = None
            return decoded
        except UnicodeDecodeError:  # pragma: nocover
            raise DecodingError() from None

    def flush(self) -> str:
        """
        Finish decoding and return any remaining text.

        If no decoder was ever selected, the whole buffer is decoded in one
        go using the detected encoding. Raises DecodingError if the bytes
        are not valid in the selected encoding.
        """
        try:
            if self.decoder is not None:
                return self.decoder.decode(b"", True)

            assert self.buffer is not None
            # Empty input: chardet is guaranteed to not have a guess.
            if not self.buffer:
                return ""
            return bytes(self.buffer).decode(self._detector_result())
        except UnicodeDecodeError:  # pragma: nocover
            raise DecodingError() from None

    def _detector_result(self) -> str:
        """
        Close the detector and return the encoding name it settled on.

        Raises DecodingError when the detector has no encoding to report.
        """
        self.detector.close()
        encoding = self.detector.result["encoding"]
        if encoding:
            return encoding
        raise DecodingError(  # pragma: nocover
            "Unable to determine encoding of content"
        )
|
||||
|
||||
|
||||
SUPPORTED_DECODERS = {
|
||||
"identity": IdentityDecoder,
|
||||
"gzip": GZipDecoder,
|
||||
|
||||
@ -17,6 +17,7 @@ from .decoders import (
|
||||
Decoder,
|
||||
IdentityDecoder,
|
||||
MultiDecoder,
|
||||
TextDecoder,
|
||||
)
|
||||
from .exceptions import (
|
||||
CookieConflict,
|
||||
@ -890,6 +891,17 @@ class AsyncResponse(BaseResponse):
|
||||
yield self.decoder.decode(chunk)
|
||||
yield self.decoder.flush()
|
||||
|
||||
async def stream_text(self) -> typing.AsyncIterator[str]:
    """
    An async str-iterator over the decoded response content.

    Consumes the chunks produced by `self.stream()` and decodes them to
    text with a `TextDecoder` built from `self.charset_encoding`. After
    the last chunk, whatever the decoder still holds is flushed and
    yielded as the final item.
    """
    text_decoder = TextDecoder(encoding=self.charset_encoding)
    async for raw_chunk in self.stream():
        yield text_decoder.decode(raw_chunk)
    yield text_decoder.flush()
|
||||
|
||||
async def raw(self) -> typing.AsyncIterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the raw response content.
|
||||
@ -969,6 +981,17 @@ class Response(BaseResponse):
|
||||
yield self.decoder.decode(chunk)
|
||||
yield self.decoder.flush()
|
||||
|
||||
def stream_text(self) -> typing.Iterator[str]:
    """
    A str-iterator over the decoded response content.

    Consumes the chunks produced by `self.stream()` and decodes them to
    text with a `TextDecoder` built from `self.charset_encoding`. After
    the last chunk, whatever the decoder still holds is flushed and
    yielded as the final item.
    """
    text_decoder = TextDecoder(encoding=self.charset_encoding)
    for raw_chunk in self.stream():
        yield text_decoder.decode(raw_chunk)
    yield text_decoder.flush()
|
||||
|
||||
def raw(self) -> typing.Iterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the raw response content.
|
||||
|
||||
@ -288,3 +288,18 @@ def test_json_without_specified_encoding_decode_error():
|
||||
response = httpx.Response(200, content=content, headers=headers)
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
response.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_text():
    """stream_text() on an already-read async response yields the full text."""

    async def body():
        yield b"Hello, world!"

    response = httpx.AsyncResponse(200, content=body().__aiter__())

    await response.read()

    # Collect every yielded piece, then compare the concatenation.
    pieces = [piece async for piece in response.stream_text()]
    assert "".join(pieces) == "Hello, world!"
|
||||
|
||||
@ -4,6 +4,7 @@ import brotli
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
from httpx.decoders import TextDecoder
|
||||
|
||||
|
||||
def test_deflate():
|
||||
@ -88,6 +89,57 @@ def test_decoding_errors(header_value):
|
||||
response.content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ["data", "encoding"],
    [
        ((b"Hello,", b" world!"), "ascii"),
        ((b"\xe3\x83", b"\x88\xe3\x83\xa9", b"\xe3", b"\x83\x99\xe3\x83\xab"), "utf-8"),
        ((b"\x83g\x83\x89\x83x\x83\x8b",) * 64, "shift-jis"),
        ((b"\x83g\x83\x89\x83x\x83\x8b",) * 600, "shift-jis"),
        (
            (b"\xcb\xee\xf0\xe5\xec \xe8\xef\xf1\xf3\xec \xe4\xee\xeb\xee\xf0",) * 64,
            "MacCyrillic",
        ),
        (
            (b"\xa5\xa6\xa5\xa7\xa5\xd6\xa4\xce\xb9\xf1\xba\xdd\xb2\xbd",) * 512,
            "euc-jp",
        ),
    ],
)
def test_text_decoder(data, encoding):
    """
    Streaming text with no declared charset matches a one-shot decode.

    Each case supplies the body as a tuple of byte chunks plus the encoding
    used to compute the expected text.
    """

    def chunks():
        yield from data

    expected = b"".join(data).decode(encoding)
    response = httpx.Response(200, content=chunks())
    assert "".join(response.stream_text()) == expected
|
||||
|
||||
|
||||
def test_text_decoder_known_encoding():
    """
    A charset declared in Content-Type is used to decode the stream.

    The second chunk is a single byte that ends in the middle of a
    two-byte character, so decoding has to carry state across chunks.
    """

    def chunks():
        yield from (b"\x83g", b"\x83", b"\x89\x83x\x83\x8b")

    content_type = (b"Content-Type", b"text/html; charset=shift-jis")
    response = httpx.Response(200, headers=[content_type], content=chunks())

    streamed = "".join(response.stream_text())
    assert streamed == "トラベル"
|
||||
|
||||
|
||||
def test_text_decoder_empty_cases():
    """A TextDecoder with no encoding yields empty text for empty input."""
    # Flushed without ever receiving data.
    never_fed = TextDecoder()
    assert never_fed.flush() == ""

    # Fed a single empty chunk, then flushed.
    fed_empty_chunk = TextDecoder()
    assert fed_empty_chunk.decode(b"") == ""
    assert fed_empty_chunk.flush() == ""
|
||||
|
||||
|
||||
def test_invalid_content_encoding_header():
|
||||
headers = [(b"Content-Encoding", b"invalid-header")]
|
||||
body = b"test 123"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user