From 91711ee366b268d59dc517ca94f67bcb7e948bb8 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Wed, 28 Jun 2023 09:28:46 -0700 Subject: [PATCH] PYTHON-3783 add types to compression_support.py (#1272) --- pymongo/compression_support.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index 40bad403f..030376fbd 100644 --- a/pymongo/compression_support.py +++ b/pymongo/compression_support.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import warnings +from typing import Any, Iterable, List, Union try: import snappy @@ -45,10 +47,10 @@ _NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD} _NO_COMPRESSION.update(_SENSITIVE_COMMANDS) -def validate_compressors(dummy, value): +def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> List[str]: try: # `value` is string. - compressors = value.split(",") + compressors = value.split(",") # type: ignore[union-attr] except AttributeError: # `value` is an iterable. compressors = list(value) @@ -78,7 +80,7 @@ def validate_compressors(dummy, value): return compressors -def validate_zlib_compression_level(option, value): +def validate_zlib_compression_level(option: str, value: Any) -> int: try: level = int(value) except Exception: @@ -89,11 +91,13 @@ def validate_zlib_compression_level(option, value): class CompressionSettings: - def __init__(self, compressors, zlib_compression_level): + def __init__(self, compressors: List[str], zlib_compression_level: int): self.compressors = compressors self.zlib_compression_level = zlib_compression_level - def get_compression_context(self, compressors): + def get_compression_context( + self, compressors: List[str] + ) -> Union[SnappyContext, ZlibContext, ZstdContext, None]: if compressors: chosen = compressors[0] if chosen == "snappy": @@ -110,7 +114,7 @@ class SnappyContext: compressor_id = 1 @staticmethod - def compress(data): + def compress(data: bytes) -> bytes: return snappy.compress(data) @@ -128,13 +132,13 @@ class ZstdContext: compressor_id = 3 @staticmethod - def compress(data): + def compress(data: bytes) -> bytes: # ZstdCompressor is not thread safe. # TODO: Use a pool? return ZstdCompressor().compress(data) -def decompress(data, compressor_id): +def decompress(data: bytes, compressor_id: int) -> bytes: if compressor_id == SnappyContext.compressor_id: # python-snappy doesn't support the buffer interface. # https://github.com/andrix/python-snappy/issues/65