Zstandard: fix crash on frame boundaries

Co-authored-by: Michiel W. Beijen <mb@x14.nl>
This commit is contained in:
Rogdham 2025-12-13 17:05:25 +01:00
parent 4d05db33ce
commit 7eee90d1b9
No known key found for this signature in database
GPG Key ID: 3586E588AEF1642D
2 changed files with 44 additions and 12 deletions

View File

@ -179,28 +179,27 @@ class ZStandardDecoder(ContentDecoder):
) from None
self.decompressor = zstd.ZstdDecompressor()
self.seen_data = False
self.at_valid_eof = True
def decode(self, data: bytes) -> bytes:
assert zstd is not None
self.seen_data = True
output = io.BytesIO()
try:
output.write(self.decompressor.decompress(data))
while self.decompressor.eof and self.decompressor.unused_data:
unused_data = self.decompressor.unused_data
self.decompressor = zstd.ZstdDecompressor()
output.write(self.decompressor.decompress(unused_data))
self.at_valid_eof = False
while data:
output.write(self.decompressor.decompress(data))
data = self.decompressor.unused_data
if self.decompressor.eof:
self.decompressor = zstd.ZstdDecompressor()
self.at_valid_eof = not data
except zstd.ZstdError as exc:
raise DecodingError(str(exc)) from exc
return output.getvalue()
def flush(self) -> bytes:
if not self.seen_data:
if self.at_valid_eof:
return b""
if not self.decompressor.eof:
raise DecodingError("Zstandard data is incomplete") # pragma: no cover
return b""
raise DecodingError("Zstandard data is incomplete") # pragma: no cover
class MultiDecoder(ContentDecoder):

View File

@ -120,7 +120,7 @@ def test_zstd_truncated():
httpx.Response(
200,
headers=headers,
content=compressed_body[1:3],
content=compressed_body[:-1],
)
@ -146,6 +146,39 @@ def test_zstd_multiframe():
assert response.content == b"foobar"
def test_zstd_truncated_multiframe():
body = b"test 123"
compressed_body = zstd.compress(body)
headers = [(b"Content-Encoding", b"zstd")]
with pytest.raises(httpx.DecodingError):
httpx.Response(
200,
headers=headers,
content=compressed_body + compressed_body[:-1],
)
def test_zstd_streaming_multiple_frames():
body1 = b"test 123 "
body2 = b"another frame"
# Create two separate complete frames
frame1 = zstd.compress(body1)
frame2 = zstd.compress(body2)
# Create an iterator that yields frames separately
def content_iterator() -> typing.Iterator[bytes]:
yield frame1
yield frame2
headers = [(b"Content-Encoding", b"zstd")]
response = httpx.Response(200, headers=headers, content=content_iterator())
response.read()
assert response.content == body1 + body2
def test_multi():
body = b"test 123"