diff --git a/bson/__init__.py b/bson/__init__.py index fd11c9952..065e7f7ca 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -53,7 +53,6 @@ bytes [#bytes]_ binary both .. [#bytes] The bytes type is encoded as BSON binary with subtype 0. It will be decoded back to bytes. """ - import datetime import itertools import os @@ -84,6 +83,7 @@ from typing import ( TypeVar, Union, cast, + overload, ) from bson.binary import ( @@ -1025,9 +1025,21 @@ def encode( return _dict_to_bson(document, check_keys, codec_options) +@overload +def decode(data: "_ReadableBuffer", codec_options: None = None) -> Dict[str, Any]: + ... + + +@overload +def decode( + data: "_ReadableBuffer", codec_options: "CodecOptions[_DocumentType]" +) -> "_DocumentType": + ... + + def decode( data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None -) -> "_DocumentType": +) -> Union[Dict[str, Any], "_DocumentType"]: """Decode BSON to a document. By default, returns a BSON document represented as a Python diff --git a/bson/raw_bson.py b/bson/raw_bson.py index bb1dbd22a..f48016909 100644 --- a/bson/raw_bson.py +++ b/bson/raw_bson.py @@ -51,7 +51,7 @@ blobs to disk, using raw BSON documents provides better speed and avoids the overhead of decoding or encoding BSON. """ -from typing import Any, ItemsView, Iterator, Mapping, Optional +from typing import Any, Dict, ItemsView, Iterator, Mapping, Optional from bson import _get_object_size, _raw_to_dict from bson.codec_options import _RAW_BSON_DOCUMENT_MARKER @@ -62,7 +62,7 @@ from bson.son import SON def _inflate_bson( bson_bytes: bytes, codec_options: CodecOptions, raw_array: bool = False -) -> Mapping[Any, Any]: +) -> Dict[Any, Any]: """Inflates the top level fields of a BSON document. :Parameters: diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index a3afbdb3f..62648b2c0 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -250,7 +250,7 @@ class _OIDCAuthenticator: except OperationFailure as exc: self.clear() if exc.code == _REAUTHENTICATION_REQUIRED_CODE: - if "jwt" in bson.decode(cmd["payload"]): # type:ignore[attr-defined] + if "jwt" in bson.decode(cmd["payload"]): if self.idp_info_gen_id > self.reauth_gen_id: raise return self.authenticate(sock_info, reauthenticate=True) diff --git a/pymongo/encryption.py b/pymongo/encryption.py index 1d407fae8..431a6cd84 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -13,13 +13,24 @@ # limitations under the License. """Support for explicit client-side field level encryption.""" +from __future__ import annotations import contextlib import enum import socket import weakref from copy import deepcopy -from typing import Any, Generic, Mapping, Optional, Sequence, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Iterator, + Mapping, + MutableMapping, + Optional, + Sequence, + Tuple, +) try: from pymongocrypt.auto_encrypter import AutoEncrypter @@ -65,6 +76,11 @@ from pymongo.typings import _DocumentType from pymongo.uri_parser import parse_host from pymongo.write_concern import WriteConcern +if TYPE_CHECKING: + from pymongocrypt.mongocrypt import MongoCryptKmsContext + + from pymongo.response import Response + _HTTPS_PORT = 443 _KMS_CONNECT_TIMEOUT = CONNECT_TIMEOUT # CDRIVER-3262 redefined this value to CONNECT_TIMEOUT _MONGOCRYPTD_TIMEOUT_MS = 10000 @@ -77,7 +93,7 @@ _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument) @contextlib.contextmanager -def _wrap_encryption_errors(): +def _wrap_encryption_errors() -> Iterator[None]: """Context manager to wrap encryption related errors.""" try: yield @@ -89,8 +105,14 @@ def _wrap_encryption_errors(): raise EncryptionError(exc) -class _EncryptionIO(MongoCryptCallback): # type: ignore - def __init__(self, client, key_vault_coll, mongocryptd_client, opts): +class _EncryptionIO(MongoCryptCallback): # type: ignore[misc] + def __init__( + self, + client: Optional[MongoClient], + key_vault_coll: Collection, + mongocryptd_client: Optional[MongoClient], + opts: AutoEncryptionOpts, + ): """Internal class to perform I/O on behalf of pymongocrypt.""" self.client_ref: Any # Use a weak ref to break reference cycle. @@ -98,7 +120,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore self.client_ref = weakref.ref(client) else: self.client_ref = None - self.key_vault_coll = key_vault_coll.with_options( + self.key_vault_coll: Optional[Collection] = key_vault_coll.with_options( codec_options=_KEY_VAULT_OPTS, read_concern=ReadConcern(level="majority"), write_concern=WriteConcern(w="majority"), @@ -107,7 +129,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore self.opts = opts self._spawned = False - def kms_request(self, kms_context): + def kms_request(self, kms_context: MongoCryptKmsContext) -> None: """Complete a KMS request. :Parameters: @@ -161,7 +183,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore # Wrap I/O errors in PyMongo exceptions. _raise_connection_failure((host, port), error) - def collection_info(self, database, filter): + def collection_info(self, database: Database, filter: bytes) -> Optional[bytes]: """Get the collection info for a namespace. The returned collection info is passed to libmongocrypt which reads @@ -179,7 +201,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore return _dict_to_bson(doc, False, _DATA_KEY_OPTS) return None - def spawn(self): + def spawn(self) -> None: """Spawn mongocryptd. Note this method is thread safe; at most one mongocryptd will start @@ -190,7 +212,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore args.extend(self.opts._mongocryptd_spawn_args) _spawn_daemon(args) - def mark_command(self, database, cmd): + def mark_command(self, database: str, cmd: bytes) -> bytes: """Mark a command for encryption. :Parameters: @@ -205,6 +227,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore # Database.command only supports mutable mappings so we need to decode # the raw BSON command first. inflated_cmd = _inflate_bson(cmd, DEFAULT_RAW_BSON_OPTIONS) + assert self.mongocryptd_client is not None try: res = self.mongocryptd_client[database].command( inflated_cmd, codec_options=DEFAULT_RAW_BSON_OPTIONS @@ -218,7 +241,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore ) return res.raw - def fetch_keys(self, filter): + def fetch_keys(self, filter: bytes) -> Iterator[bytes]: """Yields one or more keys from the key vault. :Parameters: @@ -227,11 +250,12 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore :Returns: A generator which yields the requested keys from the key vault. """ + assert self.key_vault_coll is not None with self.key_vault_coll.find(RawBSONDocument(filter)) as cursor: for key in cursor: yield key.raw - def insert_data_key(self, data_key): + def insert_data_key(self, data_key: bytes) -> Binary: """Insert a data key into the key vault. :Parameters: @@ -245,10 +269,11 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore if not isinstance(data_key_id, Binary) or data_key_id.subtype != UUID_SUBTYPE: raise TypeError("data_key _id must be Binary with a UUID subtype") + assert self.key_vault_coll is not None self.key_vault_coll.insert_one(raw_doc) return data_key_id - def bson_encode(self, doc): + def bson_encode(self, doc: MutableMapping[str, Any]) -> bytes: """Encode a document to BSON. A document can be any mapping type (like :class:`dict`). @@ -261,7 +286,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore """ return encode(doc) - def close(self): + def close(self) -> None: """Release resources. Note it is not safe to call this method from __del__ or any GC hooks. @@ -300,7 +325,7 @@ class _Encrypter: MongoDB commands. """ - def __init__(self, client, opts): + def __init__(self, client: MongoClient, opts: AutoEncryptionOpts): """Create a _Encrypter for a client. :Parameters: @@ -319,7 +344,7 @@ class _Encrypter: self._bypass_auto_encryption = opts._bypass_auto_encryption self._internal_client = None - def _get_internal_client(encrypter, mongo_client): + def _get_internal_client(encrypter: _Encrypter, mongo_client: MongoClient) -> MongoClient: if mongo_client.options.pool_options.max_pool_size is None: # Unlimited pool size, use the same client. return mongo_client @@ -362,7 +387,9 @@ class _Encrypter: ) self._closed = False - def encrypt(self, database, cmd, codec_options): + def encrypt( + self, database: Database, cmd: Mapping[str, Any], codec_options: CodecOptions + ) -> Mapping[Any, Any]: """Encrypt a MongoDB command. :Parameters: @@ -381,7 +408,7 @@ class _Encrypter: encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS) return encrypt_cmd - def decrypt(self, response): + def decrypt(self, response: Response) -> Optional[bytes]: """Decrypt a MongoDB command response. :Parameters: @@ -394,11 +421,11 @@ class _Encrypter: with _wrap_encryption_errors(): return self._auto_encrypter.decrypt(response) - def _check_closed(self): + def _check_closed(self) -> None: if self._closed: raise InvalidOperation("Cannot use MongoClient after close") - def close(self): + def close(self) -> None: """Cleanup resources.""" self._closed = True self._auto_encrypter.close() @@ -733,15 +760,15 @@ class ClientEncryption(Generic[_DocumentType]): def _encrypt_helper( self, - value, - algorithm, - key_id=None, - key_alt_name=None, - query_type=None, - contention_factor=None, - range_opts=None, - is_expression=False, - ): + value: Any, + algorithm: str, + key_id: Optional[Binary] = None, + key_alt_name: Optional[str] = None, + query_type: Optional[str] = None, + contention_factor: Optional[int] = None, + range_opts: Optional[RangeOpts] = None, + is_expression: bool = False, + ) -> Any: self._check_closed() if key_id is not None and not ( isinstance(key_id, Binary) and key_id.subtype == UUID_SUBTYPE @@ -752,8 +779,9 @@ class ClientEncryption(Generic[_DocumentType]): {"v": value}, codec_options=self._codec_options, ) + range_opts_bytes = None if range_opts: - range_opts = encode( + range_opts_bytes = encode( range_opts.document, codec_options=self._codec_options, ) @@ -765,10 +793,10 @@ class ClientEncryption(Generic[_DocumentType]): key_alt_name=key_alt_name, query_type=query_type, contention_factor=contention_factor, - range_opts=range_opts, + range_opts=range_opts_bytes, is_expression=is_expression, ) - return decode(encrypted_doc)["v"] # type: ignore[index] + return decode(encrypted_doc)["v"] def encrypt( self, @@ -897,6 +925,7 @@ class ClientEncryption(Generic[_DocumentType]): .. versionadded:: 4.2 """ self._check_closed() + assert self._key_vault_coll is not None return self._key_vault_coll.find_one({"_id": id}) def get_keys(self) -> Cursor[RawBSONDocument]: @@ -909,6 +938,7 @@ class ClientEncryption(Generic[_DocumentType]): .. versionadded:: 4.2 """ self._check_closed() + assert self._key_vault_coll is not None return self._key_vault_coll.find({}) def delete_key(self, id: Binary) -> DeleteResult: @@ -925,6 +955,7 @@ class ClientEncryption(Generic[_DocumentType]): .. versionadded:: 4.2 """ self._check_closed() + assert self._key_vault_coll is not None return self._key_vault_coll.delete_one({"_id": id}) def add_key_alt_name(self, id: Binary, key_alt_name: str) -> Any: @@ -943,6 +974,7 @@ class ClientEncryption(Generic[_DocumentType]): """ self._check_closed() update = {"$addToSet": {"keyAltNames": key_alt_name}} + assert self._key_vault_coll is not None return self._key_vault_coll.find_one_and_update({"_id": id}, update) def get_key_by_alt_name(self, key_alt_name: str) -> Optional[RawBSONDocument]: @@ -957,6 +989,7 @@ class ClientEncryption(Generic[_DocumentType]): .. versionadded:: 4.2 """ self._check_closed() + assert self._key_vault_coll is not None return self._key_vault_coll.find_one({"keyAltNames": key_alt_name}) def remove_key_alt_name(self, id: Binary, key_alt_name: str) -> Optional[RawBSONDocument]: @@ -994,6 +1027,7 @@ class ClientEncryption(Generic[_DocumentType]): } } ] + assert self._key_vault_coll is not None return self._key_vault_coll.find_one_and_update({"_id": id}, pipeline) def rewrap_many_data_key( @@ -1052,6 +1086,7 @@ class ClientEncryption(Generic[_DocumentType]): replacements.append(op) if not replacements: return RewrapManyDataKeyResult() + assert self._key_vault_coll is not None result = self._key_vault_coll.bulk_write(replacements) return RewrapManyDataKeyResult(result) @@ -1061,7 +1096,7 @@ class ClientEncryption(Generic[_DocumentType]): def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() - def _check_closed(self): + def _check_closed(self) -> None: if self._encryption is None: raise InvalidOperation("Cannot use closed ClientEncryption") diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index d6f3ca683..b4ffd92a8 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -15,7 +15,7 @@ """Support for automatic client-side field level encryption.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, List, Mapping, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional try: import pymongocrypt # noqa: F401 @@ -245,7 +245,7 @@ class RangeOpts: self.precision = precision @property - def document(self) -> Mapping[str, Any]: + def document(self) -> Dict[str, Any]: doc = {} for k, v in [ ("sparsity", int64.Int64(self.sparsity)),