Merge fba12be30a into b5addb64f0
This commit is contained in:
commit
71d054331b
@ -176,6 +176,7 @@ class ZStandardDecoder(ContentDecoder):
|
||||
|
||||
self.decompressor = zstandard.ZstdDecompressor().decompressobj()
|
||||
self.seen_data = False
|
||||
self.seen_eof = False
|
||||
|
||||
def decode(self, data: bytes) -> bytes:
|
||||
assert zstandard is not None
|
||||
@ -187,6 +188,11 @@ class ZStandardDecoder(ContentDecoder):
|
||||
unused_data = self.decompressor.unused_data
|
||||
self.decompressor = zstandard.ZstdDecompressor().decompressobj()
|
||||
output.write(self.decompressor.decompress(unused_data))
|
||||
# If the decompressor reached EOF, create a new one for the next call
|
||||
# since zstd decompressors cannot be reused after EOF
|
||||
if self.decompressor.eof:
|
||||
self.seen_eof = True
|
||||
self.decompressor = zstandard.ZstdDecompressor().decompressobj()
|
||||
except zstandard.ZstdError as exc:
|
||||
raise DecodingError(str(exc)) from exc
|
||||
return output.getvalue()
|
||||
@ -195,7 +201,7 @@ class ZStandardDecoder(ContentDecoder):
|
||||
if not self.seen_data:
|
||||
return b""
|
||||
ret = self.decompressor.flush() # note: this is a no-op
|
||||
if not self.decompressor.eof:
|
||||
if not self.decompressor.eof and not self.seen_eof:
|
||||
raise DecodingError("Zstandard data is incomplete") # pragma: no cover
|
||||
return bytes(ret)
|
||||
|
||||
|
||||
@ -141,6 +141,26 @@ def test_zstd_multiframe():
|
||||
assert response.content == b"foobar"
|
||||
|
||||
|
||||
def test_zstd_streaming_multiple_frames():
|
||||
body1 = b"test 123 "
|
||||
body2 = b"another frame"
|
||||
|
||||
# Create two separate complete frames
|
||||
frame1 = zstd.compress(body1)
|
||||
frame2 = zstd.compress(body2)
|
||||
|
||||
# Create an iterator that yields frames separately
|
||||
def content_iterator() -> typing.Iterator[bytes]:
|
||||
yield frame1
|
||||
yield frame2
|
||||
|
||||
headers = [(b"Content-Encoding", b"zstd")]
|
||||
response = httpx.Response(200, headers=headers, content=content_iterator())
|
||||
response.read()
|
||||
|
||||
assert response.content == body1 + body2
|
||||
|
||||
|
||||
def test_multi():
|
||||
body = b"test 123"
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user