Raise error if PNG transparency has incorrect type or length when saving (#9536)

This commit is contained in:
Hugo van Kemenade 2026-05-03 13:25:49 +03:00 committed by GitHub
commit 82614324ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 20 deletions

View File

@ -502,8 +502,9 @@ class TestFilePng:
im = roundtrip(im)
assert im.info["transparency"] == (248, 248, 248)
im = roundtrip(im, transparency=(0, 1, 2))
assert im.info["transparency"] == (0, 1, 2)
for transparency in ((0, 1, 2), [0, 1, 2]):
im = roundtrip(im, transparency=transparency)
assert im.info["transparency"] == (0, 1, 2)
def test_trns_p(self, tmp_path: Path) -> None:
# Check writing a transparency of 0, issue #528
@ -518,6 +519,36 @@ class TestFilePng:
assert_image_equal(im2.convert("RGBA"), im.convert("RGBA"))
def test_trns_invalid(self, tmp_path: Path) -> None:
out = tmp_path / "temp.png"
for mode in ("1", "L", "I;16"):
im = Image.new(mode, (1, 1))
with pytest.raises(
ValueError, match=f"transparency for {mode} must be an integer"
):
im.save(out, transparency="invalid")
im = Image.new("I", (1, 1))
with pytest.warns(DeprecationWarning, match="Saving I mode images as PNG"):
with pytest.raises(ValueError):
im.save(out, transparency="invalid")
im = Image.new("P", (1, 1))
with pytest.raises(
ValueError, match="transparency for P must be an integer or bytes"
):
im.save(out, transparency="invalid")
im = Image.new("RGB", (1, 1))
with pytest.raises(
ValueError, match="transparency for RGB must be list or tuple"
):
im.save(out, transparency="invalid")
with pytest.raises(ValueError, match="transparency for RGB must have length 3"):
im.save(out, transparency=(1, 2))
def test_trns_null(self) -> None:
# Check reading images with null tRNS value, issue #1239
test_file = "Tests/images/tRNS_null_1x1.png"

View File

@ -1443,35 +1443,47 @@ def _save(
palette_bytes += b"\0"
chunk(fp, b"PLTE", palette_bytes)
transparency = im.encoderinfo.get("transparency", im.info.get("transparency", None))
transparency = im.encoderinfo.get("transparency", im.info.get("transparency"))
if transparency or transparency == 0:
if transparency is not None:
if im.mode == "P":
# limit to actual palette size
alpha_bytes = colors
if isinstance(transparency, bytes):
chunk(fp, b"tRNS", transparency[:alpha_bytes])
else:
elif isinstance(transparency, int):
transparency = max(0, min(255, transparency))
alpha = b"\xff" * transparency + b"\0"
chunk(fp, b"tRNS", alpha[:alpha_bytes])
else:
msg = "transparency for P must be an integer or bytes"
raise ValueError(msg)
elif im.mode in ("1", "L", "I", "I;16"):
transparency = max(0, min(65535, transparency))
chunk(fp, b"tRNS", o16(transparency))
if isinstance(transparency, int):
transparency = max(0, min(65535, transparency))
chunk(fp, b"tRNS", o16(transparency))
else:
msg = f"transparency for {im.mode} must be an integer"
raise ValueError(msg)
elif im.mode == "RGB":
red, green, blue = transparency
chunk(fp, b"tRNS", o16(red) + o16(green) + o16(blue))
else:
if "transparency" in im.encoderinfo:
# don't bother with transparency if it's an RGBA
# and it's in the info dict. It's probably just stale.
msg = "cannot use transparency for this mode"
raise OSError(msg)
else:
if im.mode == "P" and im.im.getpalettemode() == "RGBA":
alpha = im.im.getpalette("RGBA", "A")
alpha_bytes = colors
chunk(fp, b"tRNS", alpha[:alpha_bytes])
if not isinstance(transparency, (list, tuple)):
msg = "transparency for RGB must be list or tuple"
raise ValueError(msg)
elif len(transparency) != 3:
msg = "transparency for RGB must have length 3"
raise ValueError(msg)
else:
red, green, blue = transparency
chunk(fp, b"tRNS", o16(red) + o16(green) + o16(blue))
elif im.encoderinfo.get("transparency") is not None:
# don't bother with transparency if it's an RGBA
# and it's in the info dict. It's probably just stale.
msg = "cannot use transparency for this mode"
raise OSError(msg)
elif im.mode == "P" and im.im.getpalettemode() == "RGBA":
alpha = im.im.getpalette("RGBA", "A")
alpha_bytes = colors
chunk(fp, b"tRNS", alpha[:alpha_bytes])
if dpi := im.encoderinfo.get("dpi"):
chunk(