From b492263826123a78513d94eec186e184eb97f421 Mon Sep 17 00:00:00 2001 From: Julius Park Date: Tue, 31 Jan 2023 14:58:37 -0800 Subject: [PATCH] PYTHON-3357 Automatically create Queryable Encryption keys (#1145) --- pymongo/database.py | 56 ++++++----- pymongo/encryption.py | 106 +++++++++++++++++++- test/test_encryption.py | 214 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 348 insertions(+), 28 deletions(-) diff --git a/pymongo/database.py b/pymongo/database.py index 86754b2c0..b3c6c6085 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -13,6 +13,7 @@ # limitations under the License. """Database level operations.""" +from copy import deepcopy from typing import ( TYPE_CHECKING, Any, @@ -292,6 +293,28 @@ class Database(common.BaseObject, Generic[_DocumentType]): read_concern, ) + def _get_encrypted_fields(self, kwargs, coll_name, ask_db): + encrypted_fields = kwargs.get("encryptedFields") + if encrypted_fields: + return deepcopy(encrypted_fields) + if ( + self.client.options.auto_encryption_opts + and self.client.options.auto_encryption_opts._encrypted_fields_map + and self.client.options.auto_encryption_opts._encrypted_fields_map.get( + f"{self.name}.{coll_name}" + ) + ): + return deepcopy( + self.client.options.auto_encryption_opts._encrypted_fields_map[ + f"{self.name}.{coll_name}" + ] + ) + if ask_db and self.client.options.auto_encryption_opts: + options = self[coll_name].options() + if options.get("encryptedFields"): + return deepcopy(options["encryptedFields"]) + return None + @_csot.apply def create_collection( self, @@ -419,19 +442,10 @@ class Database(common.BaseObject, Generic[_DocumentType]): .. _create collection command: https://mongodb.com/docs/manual/reference/command/create """ - encrypted_fields = kwargs.get("encryptedFields") - if ( - not encrypted_fields - and self.client.options.auto_encryption_opts - and self.client.options.auto_encryption_opts._encrypted_fields_map - ): - encrypted_fields = self.client.options.auto_encryption_opts._encrypted_fields_map.get( - "%s.%s" % (self.name, name) - ) - kwargs["encryptedFields"] = encrypted_fields - + encrypted_fields = self._get_encrypted_fields(kwargs, name, False) if encrypted_fields: common.validate_is_mapping("encryptedFields", encrypted_fields) + kwargs["encryptedFields"] = encrypted_fields clustered_index = kwargs.get("clusteredIndex") if clustered_index: @@ -1038,21 +1052,11 @@ class Database(common.BaseObject, Generic[_DocumentType]): if not isinstance(name, str): raise TypeError("name_or_collection must be an instance of str") - full_name = "%s.%s" % (self.name, name) - if ( - not encrypted_fields - and self.client.options.auto_encryption_opts - and self.client.options.auto_encryption_opts._encrypted_fields_map - ): - encrypted_fields = self.client.options.auto_encryption_opts._encrypted_fields_map.get( - full_name - ) - if not encrypted_fields and self.client.options.auto_encryption_opts: - colls = list( - self.list_collections(filter={"name": name}, session=session, comment=comment) - ) - if colls and colls[0]["options"].get("encryptedFields"): - encrypted_fields = colls[0]["options"]["encryptedFields"] + encrypted_fields = self._get_encrypted_fields( + {"encryptedFields": encrypted_fields}, + name, + True, + ) if encrypted_fields: common.validate_is_mapping("encrypted_fields", encrypted_fields) self._drop_helper( diff --git a/pymongo/encryption.py b/pymongo/encryption.py index 8b51863f9..0e281f7b3 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -18,7 +18,8 @@ import contextlib import enum import socket import weakref -from typing import Any, Generic, Mapping, Optional, Sequence +from copy import deepcopy +from typing import Any, Generic, Mapping, Optional, Sequence, Tuple try: from pymongocrypt.auto_encrypter import AutoEncrypter @@ -39,8 +40,10 @@ from bson.errors import BSONError from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson from bson.son import SON from pymongo import _csot +from pymongo.collection import Collection from pymongo.cursor import Cursor from pymongo.daemon import _spawn_daemon +from pymongo.database import Database from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts from pymongo.errors import ( ConfigurationError, @@ -552,6 +555,107 @@ class ClientEncryption(Generic[_DocumentType]): # Use the same key vault collection as the callback. self._key_vault_coll = self._io_callbacks.key_vault_coll + def create_encrypted_collection( + self, + database: Database, + name: str, + encrypted_fields: Mapping[str, Any], + kms_provider: Optional[str] = None, + master_key: Optional[Mapping[str, Any]] = None, + key_alt_names: Optional[Sequence[str]] = None, + key_material: Optional[bytes] = None, + **kwargs: Any, + ) -> Tuple[Collection[_DocumentType], Mapping[str, Any]]: + """Create a collection with encryptedFields. + + .. warning:: + This function does not update the encryptedFieldsMap in the client's + AutoEncryptionOpts, thus the user must create a new client after calling this function with + the encryptedFields returned. + + Normally collection creation is automatic. This method should + only be used to specify options on + creation. :class:`~pymongo.errors.EncryptionError` will be + raised if the collection already exists. + + :Parameters: + - `name`: the name of the collection to create + - `encrypted_fields` (dict): **(BETA)** Document that describes the encrypted fields for + Queryable Encryption. For example:: + + { + "escCollection": "enxcol_.encryptedCollection.esc", + "eccCollection": "enxcol_.encryptedCollection.ecc", + "ecocCollection": "enxcol_.encryptedCollection.ecoc", + "fields": [ + { + "path": "firstName", + "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')), + "bsonType": "string", + "queries": {"queryType": "equality"} + }, + { + "path": "ssn", + "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')), + "bsonType": "string" + } + ] + } + + The "keyId" may be set to ``None`` to auto-generate the data keys. + - `kms_provider` (optional): the KMS provider to be used + - `master_key` (optional): Identifies a KMS-specific key used to encrypt the + new data key. If the kmsProvider is "local" the `master_key` is + not applicable and may be omitted. + - `key_alt_names` (optional): An optional list of string alternate + names used to reference a key. If a key is created with alternate + names, then encryption may refer to the key by the unique alternate + name instead of by ``key_id``. + - `key_material` (optional): Sets the custom key material to be used + by the data key for encryption and decryption. + - `**kwargs` (optional): additional keyword arguments are the same as "create_collection". + + All optional `create collection command`_ parameters should be passed + as keyword arguments to this method. + See the documentation for :meth:`~pymongo.database.Database.create_collection` for all valid options. + + .. versionadded:: 4.4 + + .. _create collection command: + https://mongodb.com/docs/manual/reference/command/create + + """ + encrypted_fields = deepcopy(encrypted_fields) + for i, field in enumerate(encrypted_fields["fields"]): + if isinstance(field, dict) and field.get("keyId") is None: + try: + encrypted_fields["fields"][i]["keyId"] = self.create_data_key( + kms_provider=kms_provider, # type:ignore[arg-type] + master_key=master_key, + key_alt_names=key_alt_names, + key_material=key_material, + ) + except EncryptionError as exc: + raise EncryptionError( + Exception( + "Error occurred while creating data key for field %s with encryptedFields=%s" + % (field["path"], encrypted_fields) + ) + ) from exc + kwargs["encryptedFields"] = encrypted_fields + kwargs["check_exists"] = False + try: + return ( + database.create_collection(name=name, **kwargs), + encrypted_fields, + ) + except Exception as exc: + raise EncryptionError( + Exception( + f"Error: {str(exc)} occurred while creating collection with encryptedFields={str(encrypted_fields)}" + ) + ) from exc + def create_data_key( self, kms_provider: str, diff --git a/test/test_encryption.py b/test/test_encryption.py index fc6d62c72..0df875d95 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -65,7 +65,7 @@ from bson.codec_options import CodecOptions from bson.errors import BSONError from bson.json_util import JSONOptions from bson.son import SON -from pymongo import encryption +from pymongo import ReadPreference, encryption from pymongo.cursor import CursorType from pymongo.encryption import Algorithm, ClientEncryption, QueryType from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts, RangeOpts @@ -2687,5 +2687,217 @@ class TestRangeQueryProse(EncryptionIntegrationTest): self.run_test_cases("Int", RangeOpts(min=0, max=200, sparsity=1), int) +# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#automatic-data-encryption-keys +class TestAutomaticDecryptionKeys(EncryptionIntegrationTest): + @client_context.require_no_standalone + @client_context.require_version_min(6, 0, -1) + def setUp(self): + super().setUp() + self.key1_document = json_data("etc", "data", "keys", "key1-document.json") + self.key1_id = self.key1_document["_id"] + self.client.drop_database(self.db) + self.key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document) + self.addCleanup(self.key_vault.drop) + self.client_encryption = ClientEncryption( + {"local": {"key": LOCAL_MASTER_KEY}}, + self.key_vault.full_name, + self.client, + OPTS, + ) + self.addCleanup(self.client_encryption.close) + + def test_01_simple_create(self): + coll, _ = self.client_encryption.create_encrypted_collection( + database=self.db, + name="testing1", + encrypted_fields={"fields": [{"path": "ssn", "bsonType": "string", "keyId": None}]}, + kms_provider="local", + ) + with self.assertRaises(WriteError) as exc: + coll.insert_one({"ssn": "123-45-6789"}) + self.assertEqual(exc.exception.code, 121) + + def test_02_no_fields(self): + with self.assertRaisesRegex( + TypeError, + "create_encrypted_collection.* missing 1 required positional argument: 'encrypted_fields'", + ): + self.client_encryption.create_encrypted_collection( # type:ignore[call-arg] + database=self.db, + name="testing1", + ) + + def test_03_invalid_keyid(self): + with self.assertRaisesRegex( + EncryptionError, + "create.encryptedFields.fields.keyId' is the wrong type 'bool', expected type 'binData", + ): + self.client_encryption.create_encrypted_collection( + database=self.db, + name="testing1", + encrypted_fields={ + "fields": [{"path": "ssn", "bsonType": "string", "keyId": False}] + }, + kms_provider="local", + ) + + def test_04_insert_encrypted(self): + coll, ef = self.client_encryption.create_encrypted_collection( + database=self.db, + name="testing1", + encrypted_fields={"fields": [{"path": "ssn", "bsonType": "string", "keyId": None}]}, + kms_provider="local", + ) + key1_id = ef["fields"][0]["keyId"] + encrypted_value = self.client_encryption.encrypt( + "123-45-6789", + key_id=key1_id, + algorithm=Algorithm.UNINDEXED, + ) + coll.insert_one({"ssn": encrypted_value}) + + def test_copy_encrypted_fields(self): + encrypted_fields = { + "fields": [ + { + "path": "ssn", + "bsonType": "string", + "keyId": None, + } + ] + } + _, ef = self.client_encryption.create_encrypted_collection( + database=self.db, + name="testing1", + kms_provider="local", + encrypted_fields=encrypted_fields, + ) + self.assertIsNotNone(ef["fields"][0]["keyId"]) + self.assertIsNone(encrypted_fields["fields"][0]["keyId"]) + + def test_options_forward(self): + coll, ef = self.client_encryption.create_encrypted_collection( + database=self.db, + name="testing1", + kms_provider="local", + encrypted_fields={"fields": [{"path": "ssn", "bsonType": "string", "keyId": None}]}, + read_preference=ReadPreference.NEAREST, + ) + self.assertEqual(coll.read_preference, ReadPreference.NEAREST) + self.assertEqual(coll.name, "testing1") + + def test_mixed_null_keyids(self): + key = self.client_encryption.create_data_key(kms_provider="local") + coll, ef = self.client_encryption.create_encrypted_collection( + database=self.db, + name="testing1", + encrypted_fields={ + "fields": [ + {"path": "ssn", "bsonType": "string", "keyId": None}, + {"path": "dob", "bsonType": "string", "keyId": key}, + {"path": "secrets", "bsonType": "string"}, + {"path": "address", "bsonType": "string", "keyId": None}, + ] + }, + kms_provider="local", + ) + encrypted_values = [ + self.client_encryption.encrypt( + val, + key_id=key, + algorithm=Algorithm.UNINDEXED, + ) + for val, key in zip( + ["123-45-6789", "11/22/1963", "My secret", "New Mexico, 87104"], + [field["keyId"] for field in ef["fields"]], + ) + ] + coll.insert_one( + { + "ssn": encrypted_values[0], + "dob": encrypted_values[1], + "secrets": encrypted_values[2], + "address": encrypted_values[3], + } + ) + + def test_create_datakey_fails(self): + key = self.client_encryption.create_data_key(kms_provider="local") + # Make sure the error message includes the previous keys in the error message even when generating keys fails. + with self.assertRaisesRegex( + EncryptionError, + f"data key for field ssn with encryptedFields=.*{re.escape(repr(key))}.*keyId.*Binary.*keyId.*None", + ): + self.client_encryption.create_encrypted_collection( + database=self.db, + name="testing1", + encrypted_fields={ + "fields": [ + {"path": "address", "bsonType": "string", "keyId": key}, + {"path": "dob", "bsonType": "string", "keyId": None}, + # Because this is the second one to use the altName "1", it will fail when creating the data_key. + {"path": "ssn", "bsonType": "string", "keyId": None}, + ] + }, + kms_provider="local", + key_alt_names=["1"], + ) + + def test_create_failure(self): + key = self.client_encryption.create_data_key(kms_provider="local") + # Make sure the error message includes the previous keys in the error message even when it is the creation + # of the collection that fails. + with self.assertRaisesRegex( + EncryptionError, + f"while creating collection with encryptedFields=.*{re.escape(repr(key))}.*keyId.*Binary", + ): + self.client_encryption.create_encrypted_collection( + database=self.db, + name=1, # type:ignore[arg-type] + encrypted_fields={ + "fields": [ + {"path": "address", "bsonType": "string", "keyId": key}, + {"path": "dob", "bsonType": "string", "keyId": None}, + ] + }, + kms_provider="local", + ) + + def test_collection_name_collision(self): + encrypted_fields = { + "fields": [ + {"path": "address", "bsonType": "string", "keyId": None}, + ] + } + self.db.create_collection("testing1") + with self.assertRaisesRegex( + EncryptionError, + "while creating collection with encryptedFields=.*keyId.*Binary", + ): + self.client_encryption.create_encrypted_collection( + database=self.db, + name="testing1", + encrypted_fields=encrypted_fields, + kms_provider="local", + ) + self.db.drop_collection("testing1", encrypted_fields=encrypted_fields) + self.client_encryption.create_encrypted_collection( + database=self.db, + name="testing1", + encrypted_fields=encrypted_fields, + kms_provider="local", + ) + with self.assertRaisesRegex( + EncryptionError, + "while creating collection with encryptedFields=.*keyId.*Binary", + ): + self.client_encryption.create_encrypted_collection( + database=self.db, + name="testing1", + encrypted_fields=encrypted_fields, + kms_provider="local", + ) + + if __name__ == "__main__": unittest.main()