PYTHON-3803 add types to encryption.py (#1296)

This commit is contained in:
Iris 2023-07-11 08:24:15 -07:00 committed by GitHub
parent fd760c2b66
commit f813f56362
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 86 additions and 39 deletions

View File

@ -53,7 +53,6 @@ bytes [#bytes]_ binary both
.. [#bytes] The bytes type is encoded as BSON binary with
subtype 0. It will be decoded back to bytes.
"""
import datetime
import itertools
import os
@ -84,6 +83,7 @@ from typing import (
TypeVar,
Union,
cast,
overload,
)
from bson.binary import (
@ -1025,9 +1025,21 @@ def encode(
return _dict_to_bson(document, check_keys, codec_options)
@overload
def decode(data: "_ReadableBuffer", codec_options: None = None) -> Dict[str, Any]:
...
@overload
def decode(
data: "_ReadableBuffer", codec_options: "CodecOptions[_DocumentType]"
) -> "_DocumentType":
...
def decode(
data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None
) -> "_DocumentType":
) -> Union[Dict[str, Any], "_DocumentType"]:
"""Decode BSON to a document.
By default, returns a BSON document represented as a Python

View File

@ -51,7 +51,7 @@ blobs to disk, using raw BSON documents provides better speed and avoids the
overhead of decoding or encoding BSON.
"""
from typing import Any, ItemsView, Iterator, Mapping, Optional
from typing import Any, Dict, ItemsView, Iterator, Mapping, Optional
from bson import _get_object_size, _raw_to_dict
from bson.codec_options import _RAW_BSON_DOCUMENT_MARKER
@ -62,7 +62,7 @@ from bson.son import SON
def _inflate_bson(
bson_bytes: bytes, codec_options: CodecOptions, raw_array: bool = False
) -> Mapping[Any, Any]:
) -> Dict[Any, Any]:
"""Inflates the top level fields of a BSON document.
:Parameters:

View File

@ -250,7 +250,7 @@ class _OIDCAuthenticator:
except OperationFailure as exc:
self.clear()
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
if "jwt" in bson.decode(cmd["payload"]): # type:ignore[attr-defined]
if "jwt" in bson.decode(cmd["payload"]):
if self.idp_info_gen_id > self.reauth_gen_id:
raise
return self.authenticate(sock_info, reauthenticate=True)

View File

@ -13,13 +13,24 @@
# limitations under the License.
"""Support for explicit client-side field level encryption."""
from __future__ import annotations
import contextlib
import enum
import socket
import weakref
from copy import deepcopy
from typing import Any, Generic, Mapping, Optional, Sequence, Tuple
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterator,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
)
try:
from pymongocrypt.auto_encrypter import AutoEncrypter
@ -65,6 +76,11 @@ from pymongo.typings import _DocumentType
from pymongo.uri_parser import parse_host
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from pymongocrypt.mongocrypt import MongoCryptKmsContext
from pymongo.response import Response
_HTTPS_PORT = 443
_KMS_CONNECT_TIMEOUT = CONNECT_TIMEOUT # CDRIVER-3262 redefined this value to CONNECT_TIMEOUT
_MONGOCRYPTD_TIMEOUT_MS = 10000
@ -77,7 +93,7 @@ _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
@contextlib.contextmanager
def _wrap_encryption_errors():
def _wrap_encryption_errors() -> Iterator[None]:
"""Context manager to wrap encryption related errors."""
try:
yield
@ -89,8 +105,14 @@ def _wrap_encryption_errors():
raise EncryptionError(exc)
class _EncryptionIO(MongoCryptCallback): # type: ignore
def __init__(self, client, key_vault_coll, mongocryptd_client, opts):
class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
def __init__(
self,
client: Optional[MongoClient],
key_vault_coll: Collection,
mongocryptd_client: Optional[MongoClient],
opts: AutoEncryptionOpts,
):
"""Internal class to perform I/O on behalf of pymongocrypt."""
self.client_ref: Any
# Use a weak ref to break reference cycle.
@ -98,7 +120,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
self.client_ref = weakref.ref(client)
else:
self.client_ref = None
self.key_vault_coll = key_vault_coll.with_options(
self.key_vault_coll: Optional[Collection] = key_vault_coll.with_options(
codec_options=_KEY_VAULT_OPTS,
read_concern=ReadConcern(level="majority"),
write_concern=WriteConcern(w="majority"),
@ -107,7 +129,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
self.opts = opts
self._spawned = False
def kms_request(self, kms_context):
def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
"""Complete a KMS request.
:Parameters:
@ -161,7 +183,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)
def collection_info(self, database, filter):
def collection_info(self, database: Database, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.
The returned collection info is passed to libmongocrypt which reads
@ -179,7 +201,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
return _dict_to_bson(doc, False, _DATA_KEY_OPTS)
return None
def spawn(self):
def spawn(self) -> None:
"""Spawn mongocryptd.
Note this method is thread safe; at most one mongocryptd will start
@ -190,7 +212,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
args.extend(self.opts._mongocryptd_spawn_args)
_spawn_daemon(args)
def mark_command(self, database, cmd):
def mark_command(self, database: str, cmd: bytes) -> bytes:
"""Mark a command for encryption.
:Parameters:
@ -205,6 +227,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
# Database.command only supports mutable mappings so we need to decode
# the raw BSON command first.
inflated_cmd = _inflate_bson(cmd, DEFAULT_RAW_BSON_OPTIONS)
assert self.mongocryptd_client is not None
try:
res = self.mongocryptd_client[database].command(
inflated_cmd, codec_options=DEFAULT_RAW_BSON_OPTIONS
@ -218,7 +241,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
)
return res.raw
def fetch_keys(self, filter):
def fetch_keys(self, filter: bytes) -> Iterator[bytes]:
"""Yields one or more keys from the key vault.
:Parameters:
@ -227,11 +250,12 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
:Returns:
A generator which yields the requested keys from the key vault.
"""
assert self.key_vault_coll is not None
with self.key_vault_coll.find(RawBSONDocument(filter)) as cursor:
for key in cursor:
yield key.raw
def insert_data_key(self, data_key):
def insert_data_key(self, data_key: bytes) -> Binary:
"""Insert a data key into the key vault.
:Parameters:
@ -245,10 +269,11 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
if not isinstance(data_key_id, Binary) or data_key_id.subtype != UUID_SUBTYPE:
raise TypeError("data_key _id must be Binary with a UUID subtype")
assert self.key_vault_coll is not None
self.key_vault_coll.insert_one(raw_doc)
return data_key_id
def bson_encode(self, doc):
def bson_encode(self, doc: MutableMapping[str, Any]) -> bytes:
"""Encode a document to BSON.
A document can be any mapping type (like :class:`dict`).
@ -261,7 +286,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
"""
return encode(doc)
def close(self):
def close(self) -> None:
"""Release resources.
Note it is not safe to call this method from __del__ or any GC hooks.
@ -300,7 +325,7 @@ class _Encrypter:
MongoDB commands.
"""
def __init__(self, client, opts):
def __init__(self, client: MongoClient, opts: AutoEncryptionOpts):
"""Create a _Encrypter for a client.
:Parameters:
@ -319,7 +344,7 @@ class _Encrypter:
self._bypass_auto_encryption = opts._bypass_auto_encryption
self._internal_client = None
def _get_internal_client(encrypter, mongo_client):
def _get_internal_client(encrypter: _Encrypter, mongo_client: MongoClient) -> MongoClient:
if mongo_client.options.pool_options.max_pool_size is None:
# Unlimited pool size, use the same client.
return mongo_client
@ -362,7 +387,9 @@ class _Encrypter:
)
self._closed = False
def encrypt(self, database, cmd, codec_options):
def encrypt(
self, database: Database, cmd: Mapping[str, Any], codec_options: CodecOptions
) -> Mapping[Any, Any]:
"""Encrypt a MongoDB command.
:Parameters:
@ -381,7 +408,7 @@ class _Encrypter:
encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
return encrypt_cmd
def decrypt(self, response):
def decrypt(self, response: Response) -> Optional[bytes]:
"""Decrypt a MongoDB command response.
:Parameters:
@ -394,11 +421,11 @@ class _Encrypter:
with _wrap_encryption_errors():
return self._auto_encrypter.decrypt(response)
def _check_closed(self):
def _check_closed(self) -> None:
if self._closed:
raise InvalidOperation("Cannot use MongoClient after close")
def close(self):
def close(self) -> None:
"""Cleanup resources."""
self._closed = True
self._auto_encrypter.close()
@ -733,15 +760,15 @@ class ClientEncryption(Generic[_DocumentType]):
def _encrypt_helper(
self,
value,
algorithm,
key_id=None,
key_alt_name=None,
query_type=None,
contention_factor=None,
range_opts=None,
is_expression=False,
):
value: Any,
algorithm: str,
key_id: Optional[Binary] = None,
key_alt_name: Optional[str] = None,
query_type: Optional[str] = None,
contention_factor: Optional[int] = None,
range_opts: Optional[RangeOpts] = None,
is_expression: bool = False,
) -> Any:
self._check_closed()
if key_id is not None and not (
isinstance(key_id, Binary) and key_id.subtype == UUID_SUBTYPE
@ -752,8 +779,9 @@ class ClientEncryption(Generic[_DocumentType]):
{"v": value},
codec_options=self._codec_options,
)
range_opts_bytes = None
if range_opts:
range_opts = encode(
range_opts_bytes = encode(
range_opts.document,
codec_options=self._codec_options,
)
@ -765,10 +793,10 @@ class ClientEncryption(Generic[_DocumentType]):
key_alt_name=key_alt_name,
query_type=query_type,
contention_factor=contention_factor,
range_opts=range_opts,
range_opts=range_opts_bytes,
is_expression=is_expression,
)
return decode(encrypted_doc)["v"] # type: ignore[index]
return decode(encrypted_doc)["v"]
def encrypt(
self,
@ -897,6 +925,7 @@ class ClientEncryption(Generic[_DocumentType]):
.. versionadded:: 4.2
"""
self._check_closed()
assert self._key_vault_coll is not None
return self._key_vault_coll.find_one({"_id": id})
def get_keys(self) -> Cursor[RawBSONDocument]:
@ -909,6 +938,7 @@ class ClientEncryption(Generic[_DocumentType]):
.. versionadded:: 4.2
"""
self._check_closed()
assert self._key_vault_coll is not None
return self._key_vault_coll.find({})
def delete_key(self, id: Binary) -> DeleteResult:
@ -925,6 +955,7 @@ class ClientEncryption(Generic[_DocumentType]):
.. versionadded:: 4.2
"""
self._check_closed()
assert self._key_vault_coll is not None
return self._key_vault_coll.delete_one({"_id": id})
def add_key_alt_name(self, id: Binary, key_alt_name: str) -> Any:
@ -943,6 +974,7 @@ class ClientEncryption(Generic[_DocumentType]):
"""
self._check_closed()
update = {"$addToSet": {"keyAltNames": key_alt_name}}
assert self._key_vault_coll is not None
return self._key_vault_coll.find_one_and_update({"_id": id}, update)
def get_key_by_alt_name(self, key_alt_name: str) -> Optional[RawBSONDocument]:
@ -957,6 +989,7 @@ class ClientEncryption(Generic[_DocumentType]):
.. versionadded:: 4.2
"""
self._check_closed()
assert self._key_vault_coll is not None
return self._key_vault_coll.find_one({"keyAltNames": key_alt_name})
def remove_key_alt_name(self, id: Binary, key_alt_name: str) -> Optional[RawBSONDocument]:
@ -994,6 +1027,7 @@ class ClientEncryption(Generic[_DocumentType]):
}
}
]
assert self._key_vault_coll is not None
return self._key_vault_coll.find_one_and_update({"_id": id}, pipeline)
def rewrap_many_data_key(
@ -1052,6 +1086,7 @@ class ClientEncryption(Generic[_DocumentType]):
replacements.append(op)
if not replacements:
return RewrapManyDataKeyResult()
assert self._key_vault_coll is not None
result = self._key_vault_coll.bulk_write(replacements)
return RewrapManyDataKeyResult(result)
@ -1061,7 +1096,7 @@ class ClientEncryption(Generic[_DocumentType]):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
def _check_closed(self):
def _check_closed(self) -> None:
if self._encryption is None:
raise InvalidOperation("Cannot use closed ClientEncryption")

View File

@ -15,7 +15,7 @@
"""Support for automatic client-side field level encryption."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
try:
import pymongocrypt # noqa: F401
@ -245,7 +245,7 @@ class RangeOpts:
self.precision = precision
@property
def document(self) -> Mapping[str, Any]:
def document(self) -> Dict[str, Any]:
doc = {}
for k, v in [
("sparsity", int64.Int64(self.sparsity)),