Zstandard: fix crash on frame boundaries
Co-authored-by: Michiel W. Beijen <mb@x14.nl>
This commit is contained in:
parent
4d05db33ce
commit
7eee90d1b9
@ -179,28 +179,27 @@ class ZStandardDecoder(ContentDecoder):
|
||||
) from None
|
||||
|
||||
self.decompressor = zstd.ZstdDecompressor()
|
||||
self.seen_data = False
|
||||
self.at_valid_eof = True
|
||||
|
||||
def decode(self, data: bytes) -> bytes:
|
||||
assert zstd is not None
|
||||
self.seen_data = True
|
||||
output = io.BytesIO()
|
||||
try:
|
||||
output.write(self.decompressor.decompress(data))
|
||||
while self.decompressor.eof and self.decompressor.unused_data:
|
||||
unused_data = self.decompressor.unused_data
|
||||
self.decompressor = zstd.ZstdDecompressor()
|
||||
output.write(self.decompressor.decompress(unused_data))
|
||||
self.at_valid_eof = False
|
||||
while data:
|
||||
output.write(self.decompressor.decompress(data))
|
||||
data = self.decompressor.unused_data
|
||||
if self.decompressor.eof:
|
||||
self.decompressor = zstd.ZstdDecompressor()
|
||||
self.at_valid_eof = not data
|
||||
except zstd.ZstdError as exc:
|
||||
raise DecodingError(str(exc)) from exc
|
||||
return output.getvalue()
|
||||
|
||||
def flush(self) -> bytes:
|
||||
if not self.seen_data:
|
||||
if self.at_valid_eof:
|
||||
return b""
|
||||
if not self.decompressor.eof:
|
||||
raise DecodingError("Zstandard data is incomplete") # pragma: no cover
|
||||
return b""
|
||||
raise DecodingError("Zstandard data is incomplete") # pragma: no cover
|
||||
|
||||
|
||||
class MultiDecoder(ContentDecoder):
|
||||
|
||||
@ -120,7 +120,7 @@ def test_zstd_truncated():
|
||||
httpx.Response(
|
||||
200,
|
||||
headers=headers,
|
||||
content=compressed_body[1:3],
|
||||
content=compressed_body[:-1],
|
||||
)
|
||||
|
||||
|
||||
@ -146,6 +146,39 @@ def test_zstd_multiframe():
|
||||
assert response.content == b"foobar"
|
||||
|
||||
|
||||
def test_zstd_truncated_multiframe():
|
||||
body = b"test 123"
|
||||
compressed_body = zstd.compress(body)
|
||||
|
||||
headers = [(b"Content-Encoding", b"zstd")]
|
||||
with pytest.raises(httpx.DecodingError):
|
||||
httpx.Response(
|
||||
200,
|
||||
headers=headers,
|
||||
content=compressed_body + compressed_body[:-1],
|
||||
)
|
||||
|
||||
|
||||
def test_zstd_streaming_multiple_frames():
|
||||
body1 = b"test 123 "
|
||||
body2 = b"another frame"
|
||||
|
||||
# Create two separate complete frames
|
||||
frame1 = zstd.compress(body1)
|
||||
frame2 = zstd.compress(body2)
|
||||
|
||||
# Create an iterator that yields frames separately
|
||||
def content_iterator() -> typing.Iterator[bytes]:
|
||||
yield frame1
|
||||
yield frame2
|
||||
|
||||
headers = [(b"Content-Encoding", b"zstd")]
|
||||
response = httpx.Response(200, headers=headers, content=content_iterator())
|
||||
response.read()
|
||||
|
||||
assert response.content == body1 + body2
|
||||
|
||||
|
||||
def test_multi():
|
||||
body = b"test 123"
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user