From 7eee90d1b9e0c1590dcd46de9d8ea99b2dffbe99 Mon Sep 17 00:00:00 2001 From: Rogdham Date: Sat, 13 Dec 2025 17:05:25 +0100 Subject: [PATCH] Zstandard: fix crash on frame boundaries Co-authored-by: Michiel W. Beijen --- httpx/_decoders.py | 21 ++++++++++----------- tests/test_decoders.py | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/httpx/_decoders.py b/httpx/_decoders.py index f35de81f..9a898577 100644 --- a/httpx/_decoders.py +++ b/httpx/_decoders.py @@ -179,28 +179,27 @@ class ZStandardDecoder(ContentDecoder): ) from None self.decompressor = zstd.ZstdDecompressor() - self.seen_data = False + self.at_valid_eof = True def decode(self, data: bytes) -> bytes: assert zstd is not None - self.seen_data = True output = io.BytesIO() try: - output.write(self.decompressor.decompress(data)) - while self.decompressor.eof and self.decompressor.unused_data: - unused_data = self.decompressor.unused_data - self.decompressor = zstd.ZstdDecompressor() - output.write(self.decompressor.decompress(unused_data)) + self.at_valid_eof = False + while data: + output.write(self.decompressor.decompress(data)) + data = self.decompressor.unused_data + if self.decompressor.eof: + self.decompressor = zstd.ZstdDecompressor() + self.at_valid_eof = not data except zstd.ZstdError as exc: raise DecodingError(str(exc)) from exc return output.getvalue() def flush(self) -> bytes: - if not self.seen_data: + if self.at_valid_eof: return b"" - if not self.decompressor.eof: - raise DecodingError("Zstandard data is incomplete") # pragma: no cover - return b"" + raise DecodingError("Zstandard data is incomplete") # pragma: no cover class MultiDecoder(ContentDecoder): diff --git a/tests/test_decoders.py b/tests/test_decoders.py index e1192ac0..4c14bb52 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -120,7 +120,7 @@ def test_zstd_truncated(): httpx.Response( 200, headers=headers, - content=compressed_body[1:3], + content=compressed_body[:-1], ) @@ -146,6 +146,39 @@ def test_zstd_multiframe(): assert response.content == b"foobar" +def test_zstd_truncated_multiframe(): + body = b"test 123" + compressed_body = zstd.compress(body) + + headers = [(b"Content-Encoding", b"zstd")] + with pytest.raises(httpx.DecodingError): + httpx.Response( + 200, + headers=headers, + content=compressed_body + compressed_body[:-1], + ) + + +def test_zstd_streaming_multiple_frames(): + body1 = b"test 123 " + body2 = b"another frame" + + # Create two separate complete frames + frame1 = zstd.compress(body1) + frame2 = zstd.compress(body2) + + # Create an iterator that yields frames separately + def content_iterator() -> typing.Iterator[bytes]: + yield frame1 + yield frame2 + + headers = [(b"Content-Encoding", b"zstd")] + response = httpx.Response(200, headers=headers, content=content_iterator()) + response.read() + + assert response.content == body1 + body2 + + def test_multi(): body = b"test 123"