From a0af3a61d7811ab98024b705a18a98a1cf350eaa Mon Sep 17 00:00:00 2001 From: Jeffrey 'Alex' Clark Date: Wed, 6 May 2026 18:42:27 -0400 Subject: [PATCH] Copilot feedback --- test/test_compression_support.py | 98 +++++++++++++++++--------------- 1 file changed, 51 insertions(+), 47 deletions(-) diff --git a/test/test_compression_support.py b/test/test_compression_support.py index 8ad65cd50..aa767de24 100644 --- a/test/test_compression_support.py +++ b/test/test_compression_support.py @@ -17,7 +17,6 @@ from __future__ import annotations import sys -import zlib from unittest.mock import patch sys.path[0:0] = [""] @@ -38,52 +37,22 @@ from pymongo.compression_support import ( ) -class TestHaveSnappy(unittest.TestCase): - def test_returns_true_when_available(self): - try: - import snappy - except ImportError: - self.skipTest("python-snappy not installed") - self.assertTrue(_have_snappy()) - - def test_returns_false_on_import_error(self): - with patch.dict(sys.modules, {"snappy": None}): - self.assertFalse(_have_snappy()) - - -class TestHaveZlib(unittest.TestCase): - def test_returns_false_on_import_error(self): - with patch.dict(sys.modules, {"zlib": None}): - self.assertFalse(_have_zlib()) - - -class TestHaveZstd(unittest.TestCase): - def test_returns_false_when_unavailable_pre_314(self): - if sys.version_info >= (3, 14): - self.skipTest("Python 3.14+ uses compression.zstd") - with patch.dict(sys.modules, {"backports": None, "backports.zstd": None}): - self.assertFalse(_have_zstd()) - - def test_returns_false_when_unavailable_314_plus(self): - if sys.version_info < (3, 14): - self.skipTest("Only applies to Python 3.14+") - with patch.dict(sys.modules, {"compression": None, "compression.zstd": None}): - self.assertFalse(_have_zstd()) - - class TestValidateCompressors(unittest.TestCase): def test_string_input_single(self): - result = validate_compressors(None, "zlib") + with patch("pymongo.compression_support._have_zlib", return_value=True): + result = validate_compressors(None, "zlib") self.assertEqual(result, ["zlib"]) def test_string_input_comma_separated(self): - with patch("pymongo.compression_support._have_snappy", return_value=True): + with patch("pymongo.compression_support._have_zlib", return_value=True), patch( + "pymongo.compression_support._have_snappy", return_value=True + ): result = validate_compressors(None, "zlib,snappy") - self.assertIn("zlib", result) - self.assertIn("snappy", result) + self.assertEqual(result, ["zlib", "snappy"]) def test_iterable_input(self): - result = validate_compressors(None, ["zlib"]) + with patch("pymongo.compression_support._have_zlib", return_value=True): + result = validate_compressors(None, ["zlib"]) self.assertEqual(result, ["zlib"]) def test_unsupported_compressor_warns_and_removes(self): @@ -125,7 +94,9 @@ class TestValidateCompressors(unittest.TestCase): self.assertIn("compression.zstd", str(ctx.warning)) def test_multiple_valid_compressors_preserves_order(self): - with patch("pymongo.compression_support._have_snappy", return_value=True): + with patch("pymongo.compression_support._have_zlib", return_value=True), patch( + "pymongo.compression_support._have_snappy", return_value=True + ): result = validate_compressors(None, ["zlib", "snappy"]) self.assertEqual(result, ["zlib", "snappy"]) @@ -178,10 +149,9 @@ class TestCompressionSettings(unittest.TestCase): self.assertIsInstance(ctx, SnappyContext) def test_get_context_zlib(self): - settings = self._make(level=6) + settings = self._make() ctx = settings.get_compression_context(["zlib"]) self.assertIsInstance(ctx, ZlibContext) - self.assertEqual(ctx.level, 6) def test_get_context_zstd(self): settings = self._make() @@ -189,7 +159,7 @@ class TestCompressionSettings(unittest.TestCase): self.assertIsInstance(ctx, ZstdContext) def test_get_context_uses_first_compressor(self): - settings = self._make(level=1) + settings = self._make() ctx = settings.get_compression_context(["zlib", "snappy"]) self.assertIsInstance(ctx, ZlibContext) @@ -200,16 +170,18 @@ class TestCompressionSettings(unittest.TestCase): class TestZlibContext(unittest.TestCase): + def setUp(self): + if not _have_zlib(): + self.skipTest("zlib not available") + def test_compress_and_decompress_roundtrip(self): + import zlib + ctx = ZlibContext(level=-1) data = b"hello world" * 100 compressed = ctx.compress(data) self.assertEqual(zlib.decompress(compressed), data) - def test_compress_level_stored(self): - ctx = ZlibContext(level=6) - self.assertEqual(ctx.level, 6) - class TestDecompress(unittest.TestCase): def test_unknown_compressor_id_raises(self): @@ -218,17 +190,49 @@ class TestDecompress(unittest.TestCase): self.assertIn("Unknown compressorId 99", str(ctx.exception)) def test_zlib_roundtrip(self): + if not _have_zlib(): + self.skipTest("zlib not available") + import zlib + data = b"hello world" compressed = zlib.compress(data) result = decompress(compressed, ZlibContext.compressor_id) self.assertEqual(result, data) def test_zlib_with_memoryview(self): + if not _have_zlib(): + self.skipTest("zlib not available") + import zlib + data = b"test data" compressed = zlib.compress(data) result = decompress(memoryview(compressed), ZlibContext.compressor_id) self.assertEqual(result, data) + def test_snappy_roundtrip(self): + if not _have_snappy(): + self.skipTest("python-snappy not installed") + data = b"hello world" * 50 + compressed = SnappyContext.compress(data) + result = decompress(compressed, SnappyContext.compressor_id) + self.assertEqual(result, data) + + def test_snappy_with_memoryview(self): + if not _have_snappy(): + self.skipTest("python-snappy not installed") + data = b"hello world" * 50 + compressed = SnappyContext.compress(data) + result = decompress(memoryview(compressed), SnappyContext.compressor_id) + self.assertEqual(result, data) + + def test_zstd_roundtrip(self): + if not _have_zstd(): + self.skipTest("zstd not available") + data = b"hello world" * 50 + compressed = ZstdContext.compress(data) + result = decompress(compressed, ZstdContext.compressor_id) + self.assertEqual(result, data) + if __name__ == "__main__": unittest.main()