PYTHON-3803 add types to encryption.py (#1296)
This commit is contained in:
parent
fd760c2b66
commit
f813f56362
@ -53,7 +53,6 @@ bytes [#bytes]_ binary both
|
||||
.. [#bytes] The bytes type is encoded as BSON binary with
|
||||
subtype 0. It will be decoded back to bytes.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import itertools
|
||||
import os
|
||||
@ -84,6 +83,7 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from bson.binary import (
|
||||
@ -1025,9 +1025,21 @@ def encode(
|
||||
return _dict_to_bson(document, check_keys, codec_options)
|
||||
|
||||
|
||||
@overload
|
||||
def decode(data: "_ReadableBuffer", codec_options: None = None) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def decode(
|
||||
data: "_ReadableBuffer", codec_options: "CodecOptions[_DocumentType]"
|
||||
) -> "_DocumentType":
|
||||
...
|
||||
|
||||
|
||||
def decode(
|
||||
data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None
|
||||
) -> "_DocumentType":
|
||||
) -> Union[Dict[str, Any], "_DocumentType"]:
|
||||
"""Decode BSON to a document.
|
||||
|
||||
By default, returns a BSON document represented as a Python
|
||||
|
||||
@ -51,7 +51,7 @@ blobs to disk, using raw BSON documents provides better speed and avoids the
|
||||
overhead of decoding or encoding BSON.
|
||||
"""
|
||||
|
||||
from typing import Any, ItemsView, Iterator, Mapping, Optional
|
||||
from typing import Any, Dict, ItemsView, Iterator, Mapping, Optional
|
||||
|
||||
from bson import _get_object_size, _raw_to_dict
|
||||
from bson.codec_options import _RAW_BSON_DOCUMENT_MARKER
|
||||
@ -62,7 +62,7 @@ from bson.son import SON
|
||||
|
||||
def _inflate_bson(
|
||||
bson_bytes: bytes, codec_options: CodecOptions, raw_array: bool = False
|
||||
) -> Mapping[Any, Any]:
|
||||
) -> Dict[Any, Any]:
|
||||
"""Inflates the top level fields of a BSON document.
|
||||
|
||||
:Parameters:
|
||||
|
||||
@ -250,7 +250,7 @@ class _OIDCAuthenticator:
|
||||
except OperationFailure as exc:
|
||||
self.clear()
|
||||
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
|
||||
if "jwt" in bson.decode(cmd["payload"]): # type:ignore[attr-defined]
|
||||
if "jwt" in bson.decode(cmd["payload"]):
|
||||
if self.idp_info_gen_id > self.reauth_gen_id:
|
||||
raise
|
||||
return self.authenticate(sock_info, reauthenticate=True)
|
||||
|
||||
@ -13,13 +13,24 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Support for explicit client-side field level encryption."""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import enum
|
||||
import socket
|
||||
import weakref
|
||||
from copy import deepcopy
|
||||
from typing import Any, Generic, Mapping, Optional, Sequence, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Iterator,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
try:
|
||||
from pymongocrypt.auto_encrypter import AutoEncrypter
|
||||
@ -65,6 +76,11 @@ from pymongo.typings import _DocumentType
|
||||
from pymongo.uri_parser import parse_host
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongocrypt.mongocrypt import MongoCryptKmsContext
|
||||
|
||||
from pymongo.response import Response
|
||||
|
||||
_HTTPS_PORT = 443
|
||||
_KMS_CONNECT_TIMEOUT = CONNECT_TIMEOUT # CDRIVER-3262 redefined this value to CONNECT_TIMEOUT
|
||||
_MONGOCRYPTD_TIMEOUT_MS = 10000
|
||||
@ -77,7 +93,7 @@ _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _wrap_encryption_errors():
|
||||
def _wrap_encryption_errors() -> Iterator[None]:
|
||||
"""Context manager to wrap encryption related errors."""
|
||||
try:
|
||||
yield
|
||||
@ -89,8 +105,14 @@ def _wrap_encryption_errors():
|
||||
raise EncryptionError(exc)
|
||||
|
||||
|
||||
class _EncryptionIO(MongoCryptCallback): # type: ignore
|
||||
def __init__(self, client, key_vault_coll, mongocryptd_client, opts):
|
||||
class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
def __init__(
|
||||
self,
|
||||
client: Optional[MongoClient],
|
||||
key_vault_coll: Collection,
|
||||
mongocryptd_client: Optional[MongoClient],
|
||||
opts: AutoEncryptionOpts,
|
||||
):
|
||||
"""Internal class to perform I/O on behalf of pymongocrypt."""
|
||||
self.client_ref: Any
|
||||
# Use a weak ref to break reference cycle.
|
||||
@ -98,7 +120,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
|
||||
self.client_ref = weakref.ref(client)
|
||||
else:
|
||||
self.client_ref = None
|
||||
self.key_vault_coll = key_vault_coll.with_options(
|
||||
self.key_vault_coll: Optional[Collection] = key_vault_coll.with_options(
|
||||
codec_options=_KEY_VAULT_OPTS,
|
||||
read_concern=ReadConcern(level="majority"),
|
||||
write_concern=WriteConcern(w="majority"),
|
||||
@ -107,7 +129,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
|
||||
self.opts = opts
|
||||
self._spawned = False
|
||||
|
||||
def kms_request(self, kms_context):
|
||||
def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
|
||||
"""Complete a KMS request.
|
||||
|
||||
:Parameters:
|
||||
@ -161,7 +183,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
|
||||
# Wrap I/O errors in PyMongo exceptions.
|
||||
_raise_connection_failure((host, port), error)
|
||||
|
||||
def collection_info(self, database, filter):
|
||||
def collection_info(self, database: Database, filter: bytes) -> Optional[bytes]:
|
||||
"""Get the collection info for a namespace.
|
||||
|
||||
The returned collection info is passed to libmongocrypt which reads
|
||||
@ -179,7 +201,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
|
||||
return _dict_to_bson(doc, False, _DATA_KEY_OPTS)
|
||||
return None
|
||||
|
||||
def spawn(self):
|
||||
def spawn(self) -> None:
|
||||
"""Spawn mongocryptd.
|
||||
|
||||
Note this method is thread safe; at most one mongocryptd will start
|
||||
@ -190,7 +212,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
|
||||
args.extend(self.opts._mongocryptd_spawn_args)
|
||||
_spawn_daemon(args)
|
||||
|
||||
def mark_command(self, database, cmd):
|
||||
def mark_command(self, database: str, cmd: bytes) -> bytes:
|
||||
"""Mark a command for encryption.
|
||||
|
||||
:Parameters:
|
||||
@ -205,6 +227,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
|
||||
# Database.command only supports mutable mappings so we need to decode
|
||||
# the raw BSON command first.
|
||||
inflated_cmd = _inflate_bson(cmd, DEFAULT_RAW_BSON_OPTIONS)
|
||||
assert self.mongocryptd_client is not None
|
||||
try:
|
||||
res = self.mongocryptd_client[database].command(
|
||||
inflated_cmd, codec_options=DEFAULT_RAW_BSON_OPTIONS
|
||||
@ -218,7 +241,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
|
||||
)
|
||||
return res.raw
|
||||
|
||||
def fetch_keys(self, filter):
|
||||
def fetch_keys(self, filter: bytes) -> Iterator[bytes]:
|
||||
"""Yields one or more keys from the key vault.
|
||||
|
||||
:Parameters:
|
||||
@ -227,11 +250,12 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
|
||||
:Returns:
|
||||
A generator which yields the requested keys from the key vault.
|
||||
"""
|
||||
assert self.key_vault_coll is not None
|
||||
with self.key_vault_coll.find(RawBSONDocument(filter)) as cursor:
|
||||
for key in cursor:
|
||||
yield key.raw
|
||||
|
||||
def insert_data_key(self, data_key):
|
||||
def insert_data_key(self, data_key: bytes) -> Binary:
|
||||
"""Insert a data key into the key vault.
|
||||
|
||||
:Parameters:
|
||||
@ -245,10 +269,11 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
|
||||
if not isinstance(data_key_id, Binary) or data_key_id.subtype != UUID_SUBTYPE:
|
||||
raise TypeError("data_key _id must be Binary with a UUID subtype")
|
||||
|
||||
assert self.key_vault_coll is not None
|
||||
self.key_vault_coll.insert_one(raw_doc)
|
||||
return data_key_id
|
||||
|
||||
def bson_encode(self, doc):
|
||||
def bson_encode(self, doc: MutableMapping[str, Any]) -> bytes:
|
||||
"""Encode a document to BSON.
|
||||
|
||||
A document can be any mapping type (like :class:`dict`).
|
||||
@ -261,7 +286,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore
|
||||
"""
|
||||
return encode(doc)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Release resources.
|
||||
|
||||
Note it is not safe to call this method from __del__ or any GC hooks.
|
||||
@ -300,7 +325,7 @@ class _Encrypter:
|
||||
MongoDB commands.
|
||||
"""
|
||||
|
||||
def __init__(self, client, opts):
|
||||
def __init__(self, client: MongoClient, opts: AutoEncryptionOpts):
|
||||
"""Create a _Encrypter for a client.
|
||||
|
||||
:Parameters:
|
||||
@ -319,7 +344,7 @@ class _Encrypter:
|
||||
self._bypass_auto_encryption = opts._bypass_auto_encryption
|
||||
self._internal_client = None
|
||||
|
||||
def _get_internal_client(encrypter, mongo_client):
|
||||
def _get_internal_client(encrypter: _Encrypter, mongo_client: MongoClient) -> MongoClient:
|
||||
if mongo_client.options.pool_options.max_pool_size is None:
|
||||
# Unlimited pool size, use the same client.
|
||||
return mongo_client
|
||||
@ -362,7 +387,9 @@ class _Encrypter:
|
||||
)
|
||||
self._closed = False
|
||||
|
||||
def encrypt(self, database, cmd, codec_options):
|
||||
def encrypt(
|
||||
self, database: Database, cmd: Mapping[str, Any], codec_options: CodecOptions
|
||||
) -> Mapping[Any, Any]:
|
||||
"""Encrypt a MongoDB command.
|
||||
|
||||
:Parameters:
|
||||
@ -381,7 +408,7 @@ class _Encrypter:
|
||||
encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
|
||||
return encrypt_cmd
|
||||
|
||||
def decrypt(self, response):
|
||||
def decrypt(self, response: Response) -> Optional[bytes]:
|
||||
"""Decrypt a MongoDB command response.
|
||||
|
||||
:Parameters:
|
||||
@ -394,11 +421,11 @@ class _Encrypter:
|
||||
with _wrap_encryption_errors():
|
||||
return self._auto_encrypter.decrypt(response)
|
||||
|
||||
def _check_closed(self):
|
||||
def _check_closed(self) -> None:
|
||||
if self._closed:
|
||||
raise InvalidOperation("Cannot use MongoClient after close")
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Cleanup resources."""
|
||||
self._closed = True
|
||||
self._auto_encrypter.close()
|
||||
@ -733,15 +760,15 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
|
||||
def _encrypt_helper(
|
||||
self,
|
||||
value,
|
||||
algorithm,
|
||||
key_id=None,
|
||||
key_alt_name=None,
|
||||
query_type=None,
|
||||
contention_factor=None,
|
||||
range_opts=None,
|
||||
is_expression=False,
|
||||
):
|
||||
value: Any,
|
||||
algorithm: str,
|
||||
key_id: Optional[Binary] = None,
|
||||
key_alt_name: Optional[str] = None,
|
||||
query_type: Optional[str] = None,
|
||||
contention_factor: Optional[int] = None,
|
||||
range_opts: Optional[RangeOpts] = None,
|
||||
is_expression: bool = False,
|
||||
) -> Any:
|
||||
self._check_closed()
|
||||
if key_id is not None and not (
|
||||
isinstance(key_id, Binary) and key_id.subtype == UUID_SUBTYPE
|
||||
@ -752,8 +779,9 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
{"v": value},
|
||||
codec_options=self._codec_options,
|
||||
)
|
||||
range_opts_bytes = None
|
||||
if range_opts:
|
||||
range_opts = encode(
|
||||
range_opts_bytes = encode(
|
||||
range_opts.document,
|
||||
codec_options=self._codec_options,
|
||||
)
|
||||
@ -765,10 +793,10 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
key_alt_name=key_alt_name,
|
||||
query_type=query_type,
|
||||
contention_factor=contention_factor,
|
||||
range_opts=range_opts,
|
||||
range_opts=range_opts_bytes,
|
||||
is_expression=is_expression,
|
||||
)
|
||||
return decode(encrypted_doc)["v"] # type: ignore[index]
|
||||
return decode(encrypted_doc)["v"]
|
||||
|
||||
def encrypt(
|
||||
self,
|
||||
@ -897,6 +925,7 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
.. versionadded:: 4.2
|
||||
"""
|
||||
self._check_closed()
|
||||
assert self._key_vault_coll is not None
|
||||
return self._key_vault_coll.find_one({"_id": id})
|
||||
|
||||
def get_keys(self) -> Cursor[RawBSONDocument]:
|
||||
@ -909,6 +938,7 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
.. versionadded:: 4.2
|
||||
"""
|
||||
self._check_closed()
|
||||
assert self._key_vault_coll is not None
|
||||
return self._key_vault_coll.find({})
|
||||
|
||||
def delete_key(self, id: Binary) -> DeleteResult:
|
||||
@ -925,6 +955,7 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
.. versionadded:: 4.2
|
||||
"""
|
||||
self._check_closed()
|
||||
assert self._key_vault_coll is not None
|
||||
return self._key_vault_coll.delete_one({"_id": id})
|
||||
|
||||
def add_key_alt_name(self, id: Binary, key_alt_name: str) -> Any:
|
||||
@ -943,6 +974,7 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
"""
|
||||
self._check_closed()
|
||||
update = {"$addToSet": {"keyAltNames": key_alt_name}}
|
||||
assert self._key_vault_coll is not None
|
||||
return self._key_vault_coll.find_one_and_update({"_id": id}, update)
|
||||
|
||||
def get_key_by_alt_name(self, key_alt_name: str) -> Optional[RawBSONDocument]:
|
||||
@ -957,6 +989,7 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
.. versionadded:: 4.2
|
||||
"""
|
||||
self._check_closed()
|
||||
assert self._key_vault_coll is not None
|
||||
return self._key_vault_coll.find_one({"keyAltNames": key_alt_name})
|
||||
|
||||
def remove_key_alt_name(self, id: Binary, key_alt_name: str) -> Optional[RawBSONDocument]:
|
||||
@ -994,6 +1027,7 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
}
|
||||
}
|
||||
]
|
||||
assert self._key_vault_coll is not None
|
||||
return self._key_vault_coll.find_one_and_update({"_id": id}, pipeline)
|
||||
|
||||
def rewrap_many_data_key(
|
||||
@ -1052,6 +1086,7 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
replacements.append(op)
|
||||
if not replacements:
|
||||
return RewrapManyDataKeyResult()
|
||||
assert self._key_vault_coll is not None
|
||||
result = self._key_vault_coll.bulk_write(replacements)
|
||||
return RewrapManyDataKeyResult(result)
|
||||
|
||||
@ -1061,7 +1096,7 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def _check_closed(self):
|
||||
def _check_closed(self) -> None:
|
||||
if self._encryption is None:
|
||||
raise InvalidOperation("Cannot use closed ClientEncryption")
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
"""Support for automatic client-side field level encryption."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
|
||||
|
||||
try:
|
||||
import pymongocrypt # noqa: F401
|
||||
@ -245,7 +245,7 @@ class RangeOpts:
|
||||
self.precision = precision
|
||||
|
||||
@property
|
||||
def document(self) -> Mapping[str, Any]:
|
||||
def document(self) -> Dict[str, Any]:
|
||||
doc = {}
|
||||
for k, v in [
|
||||
("sparsity", int64.Int64(self.sparsity)),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user