From 0d44783eddec96a9ae1a03b148cb4f049dca3e49 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Thu, 10 Aug 2023 16:46:41 -0700 Subject: [PATCH] PYTHON-3821 use overload pattern for _DocumentType (#1352) --- bson/__init__.py | 52 ++++++++++++++++++++++++++++++++++++++----- pymongo/collection.py | 18 --------------- test/test_typing.py | 25 +++++++++++++++++++++ 3 files changed, 71 insertions(+), 24 deletions(-) diff --git a/bson/__init__.py b/bson/__init__.py index f61a7ddf7..4379b9479 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -1106,9 +1106,21 @@ if _USE_C: _decode_all = _cbson._decode_all # noqa: F811 +@overload +def decode_all(data: "_ReadableBuffer", codec_options: None = None) -> "List[Dict[str, Any]]": + ... + + +@overload +def decode_all( + data: "_ReadableBuffer", codec_options: "CodecOptions[_DocumentType]" +) -> "List[_DocumentType]": + ... + + def decode_all( data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None -) -> "List[_DocumentType]": +) -> "Union[List[Dict[str, Any]], List[_DocumentType]]": """Decode BSON data to multiple documents. `data` must be a bytes-like object implementing the buffer protocol that @@ -1131,11 +1143,13 @@ def decode_all( Replaced `as_class`, `tz_aware`, and `uuid_subtype` options with `codec_options`. """ - opts = codec_options or DEFAULT_CODEC_OPTIONS - if not isinstance(opts, CodecOptions): + if codec_options is None: + return _decode_all(data, DEFAULT_CODEC_OPTIONS) + + if not isinstance(codec_options, CodecOptions): raise _CODEC_OPTIONS_TYPE_ERROR - return _decode_all(data, opts) # type:ignore[arg-type] + return _decode_all(data, codec_options) def _decode_selective(rawdoc: Any, fields: Any, codec_options: Any) -> Mapping[Any, Any]: @@ -1242,9 +1256,21 @@ def _decode_all_selective(data: Any, codec_options: CodecOptions, fields: Any) - ] +@overload +def decode_iter(data: bytes, codec_options: None = None) -> "Iterator[Dict[str, Any]]": + ... + + +@overload +def decode_iter( + data: bytes, codec_options: "CodecOptions[_DocumentType]" +) -> "Iterator[_DocumentType]": + ... + + def decode_iter( data: bytes, codec_options: "Optional[CodecOptions[_DocumentType]]" = None -) -> "Iterator[_DocumentType]": +) -> "Union[Iterator[Dict[str, Any]], Iterator[_DocumentType]]": """Decode BSON data to multiple documents as a generator. Works similarly to the decode_all function, but yields one document at a @@ -1278,9 +1304,23 @@ def decode_iter( yield _bson_to_dict(elements, opts) +@overload +def decode_file_iter( + file_obj: Union[BinaryIO, IO], codec_options: None = None +) -> "Iterator[Dict[str, Any]]": + ... + + +@overload +def decode_file_iter( + file_obj: Union[BinaryIO, IO], codec_options: "CodecOptions[_DocumentType]" +) -> "Iterator[_DocumentType]": + ... + + def decode_file_iter( file_obj: Union[BinaryIO, IO], codec_options: "Optional[CodecOptions[_DocumentType]]" = None -) -> "Iterator[_DocumentType]": +) -> "Union[Iterator[Dict[str, Any]], Iterator[_DocumentType]]": """Decode bson data from a file to multiple documents as a generator. Works similarly to the decode_all function, but reads from the file object diff --git a/pymongo/collection.py b/pymongo/collection.py index 6b2b16db7..772e43e95 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -427,24 +427,6 @@ class Collection(common.BaseObject, Generic[_DocumentType]): """ return self.__database - # @overload - # def with_options( - # self, - # codec_options: None = None, - # read_preference: Optional[_ServerMode] = None, - # write_concern: Optional[WriteConcern] = None, - # read_concern: Optional[ReadConcern] = None, - # ) -> Collection[Dict[str, Any]]: ... - - # @overload - # def with_options( - # self, - # codec_options: bson.CodecOptions[_DocumentType], - # read_preference: Optional[_ServerMode] = None, - # write_concern: Optional[WriteConcern] = None, - # read_concern: Optional[ReadConcern] = None, - # ) -> Collection[_DocumentType]: ... - def with_options( self, codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None, diff --git a/test/test_typing.py b/test/test_typing.py index 27597bb2c..b2db4b93b 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -242,6 +242,11 @@ class TestDecode(unittest.TestCase): rt_document3 = decode(bsonbytes2, codec_options=codec_options2) assert rt_document3.raw + def test_bson_decode_no_codec_option(self) -> None: + doc = decode_all(encode({"a": 1})) + assert doc + doc[0]["a"] = 2 + def test_bson_decode_all(self) -> None: doc = {"_id": 1} bsonbytes = encode(doc) @@ -266,6 +271,15 @@ class TestDecode(unittest.TestCase): rt_documents3 = decode_all(bsonbytes3, codec_options3) assert rt_documents3[0].raw + def test_bson_decode_all_no_codec_option(self) -> None: + docs = decode_all(b"") + docs.append({"new": 1}) + + docs = decode_all(encode({"a": 1})) + assert docs + docs[0]["a"] = 2 + docs.append({"new": 1}) + def test_bson_decode_iter(self) -> None: doc = {"_id": 1} bsonbytes = encode(doc) @@ -290,6 +304,11 @@ class TestDecode(unittest.TestCase): rt_documents3 = decode_iter(bsonbytes3, codec_options3) assert next(rt_documents3).raw + def test_bson_decode_iter_no_codec_option(self) -> None: + doc = next(decode_iter(encode({"a": 1}))) + assert doc + doc["a"] = 2 + def make_tempfile(self, content: bytes) -> Any: fileobj = tempfile.TemporaryFile() fileobj.write(content) @@ -324,6 +343,12 @@ class TestDecode(unittest.TestCase): rt_documents3 = decode_file_iter(fileobj3, codec_options3) assert next(rt_documents3).raw + def test_bson_decode_file_iter_none_codec_option(self) -> None: + fileobj = self.make_tempfile(encode({"new": 1})) + doc = next(decode_file_iter(fileobj)) + assert doc + doc["a"] = 2 + class TestDocumentType(unittest.TestCase): @only_type_check