PYTHON-3357 Automatically create Queryable Encryption keys (#1145)
This commit is contained in:
parent
b3099c62de
commit
b492263826
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Database level operations."""
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -292,6 +293,28 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
read_concern,
|
||||
)
|
||||
|
||||
def _get_encrypted_fields(self, kwargs, coll_name, ask_db):
    """Resolve the ``encryptedFields`` document that applies to a collection.

    The sources are consulted in order and the first truthy value wins:

    1. ``kwargs["encryptedFields"]`` supplied by the caller.
    2. The client's ``auto_encryption_opts._encrypted_fields_map`` entry
       keyed by ``"<database name>.<coll_name>"``.
    3. If ``ask_db`` is true and the client has auto encryption options,
       the ``encryptedFields`` entry of ``self[coll_name].options()``.

    :Parameters:
      - `kwargs`: mapping of options that may contain "encryptedFields".
      - `coll_name`: name of the collection within this database.
      - `ask_db`: whether the collection's own options may be consulted
        as a last resort.

    :Returns:
      A deep copy of the resolved document, so callers may mutate it
      without affecting the source it came from, or ``None`` if no source
      provides one.
    """
    encrypted_fields = kwargs.get("encryptedFields")
    if encrypted_fields:
        return deepcopy(encrypted_fields)

    # Sources 2 and 3 both require auto encryption to be configured.
    auto_opts = self.client.options.auto_encryption_opts
    if not auto_opts:
        return None

    fields_map = auto_opts._encrypted_fields_map
    if fields_map:
        mapped = fields_map.get(f"{self.name}.{coll_name}")
        if mapped:
            return deepcopy(mapped)

    if ask_db:
        stored = self[coll_name].options().get("encryptedFields")
        if stored:
            return deepcopy(stored)
    return None
|
||||
|
||||
@_csot.apply
|
||||
def create_collection(
|
||||
self,
|
||||
@ -419,19 +442,10 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
.. _create collection command:
|
||||
https://mongodb.com/docs/manual/reference/command/create
|
||||
"""
|
||||
encrypted_fields = kwargs.get("encryptedFields")
|
||||
if (
|
||||
not encrypted_fields
|
||||
and self.client.options.auto_encryption_opts
|
||||
and self.client.options.auto_encryption_opts._encrypted_fields_map
|
||||
):
|
||||
encrypted_fields = self.client.options.auto_encryption_opts._encrypted_fields_map.get(
|
||||
"%s.%s" % (self.name, name)
|
||||
)
|
||||
kwargs["encryptedFields"] = encrypted_fields
|
||||
|
||||
encrypted_fields = self._get_encrypted_fields(kwargs, name, False)
|
||||
if encrypted_fields:
|
||||
common.validate_is_mapping("encryptedFields", encrypted_fields)
|
||||
kwargs["encryptedFields"] = encrypted_fields
|
||||
|
||||
clustered_index = kwargs.get("clusteredIndex")
|
||||
if clustered_index:
|
||||
@ -1038,21 +1052,11 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name_or_collection must be an instance of str")
|
||||
full_name = "%s.%s" % (self.name, name)
|
||||
if (
|
||||
not encrypted_fields
|
||||
and self.client.options.auto_encryption_opts
|
||||
and self.client.options.auto_encryption_opts._encrypted_fields_map
|
||||
):
|
||||
encrypted_fields = self.client.options.auto_encryption_opts._encrypted_fields_map.get(
|
||||
full_name
|
||||
)
|
||||
if not encrypted_fields and self.client.options.auto_encryption_opts:
|
||||
colls = list(
|
||||
self.list_collections(filter={"name": name}, session=session, comment=comment)
|
||||
)
|
||||
if colls and colls[0]["options"].get("encryptedFields"):
|
||||
encrypted_fields = colls[0]["options"]["encryptedFields"]
|
||||
encrypted_fields = self._get_encrypted_fields(
|
||||
{"encryptedFields": encrypted_fields},
|
||||
name,
|
||||
True,
|
||||
)
|
||||
if encrypted_fields:
|
||||
common.validate_is_mapping("encrypted_fields", encrypted_fields)
|
||||
self._drop_helper(
|
||||
|
||||
@ -18,7 +18,8 @@ import contextlib
|
||||
import enum
|
||||
import socket
|
||||
import weakref
|
||||
from typing import Any, Generic, Mapping, Optional, Sequence
|
||||
from copy import deepcopy
|
||||
from typing import Any, Generic, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
try:
|
||||
from pymongocrypt.auto_encrypter import AutoEncrypter
|
||||
@ -39,8 +40,10 @@ from bson.errors import BSONError
|
||||
from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson
|
||||
from bson.son import SON
|
||||
from pymongo import _csot
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.cursor import Cursor
|
||||
from pymongo.daemon import _spawn_daemon
|
||||
from pymongo.database import Database
|
||||
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
@ -552,6 +555,107 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
# Use the same key vault collection as the callback.
|
||||
self._key_vault_coll = self._io_callbacks.key_vault_coll
|
||||
|
||||
def create_encrypted_collection(
    self,
    database: Database,
    name: str,
    encrypted_fields: Mapping[str, Any],
    kms_provider: Optional[str] = None,
    master_key: Optional[Mapping[str, Any]] = None,
    key_alt_names: Optional[Sequence[str]] = None,
    key_material: Optional[bytes] = None,
    **kwargs: Any,
) -> Tuple[Collection[_DocumentType], Mapping[str, Any]]:
    """Create a collection with encryptedFields.

    A data key is created with :meth:`create_data_key` for every entry of
    ``encrypted_fields["fields"]`` whose "keyId" is ``None`` or missing.
    The caller's `encrypted_fields` is never modified: the keys are filled
    into a deep copy, which is what is sent to the server and returned.

    .. warning::
        This function does not update the encryptedFieldsMap in the client's
        AutoEncryptionOpts, thus the user must create a new client after calling this function with
        the encryptedFields returned.

    Normally collection creation is automatic. This method should
    only be used to specify options on
    creation. :class:`~pymongo.errors.EncryptionError` will be
    raised if the collection already exists.

    :Parameters:
      - `database`: the database in which to create the collection
      - `name`: the name of the collection to create
      - `encrypted_fields` (dict): **(BETA)** Document that describes the encrypted fields for
        Queryable Encryption. For example::

          {
            "escCollection": "enxcol_.encryptedCollection.esc",
            "eccCollection": "enxcol_.encryptedCollection.ecc",
            "ecocCollection": "enxcol_.encryptedCollection.ecoc",
            "fields": [
                {
                    "path": "firstName",
                    "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')),
                    "bsonType": "string",
                    "queries": {"queryType": "equality"}
                },
                {
                    "path": "ssn",
                    "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')),
                    "bsonType": "string"
                }
              ]
          }

        The "keyId" may be set to ``None`` to auto-generate the data keys.
      - `kms_provider` (optional): the KMS provider to be used
      - `master_key` (optional): Identifies a KMS-specific key used to encrypt the
        new data key. If the kmsProvider is "local" the `master_key` is
        not applicable and may be omitted.
      - `key_alt_names` (optional): An optional list of string alternate
        names used to reference a key. If a key is created with alternate
        names, then encryption may refer to the key by the unique alternate
        name instead of by ``key_id``. The same names are passed to every
        data key this method creates.
      - `key_material` (optional): Sets the custom key material to be used
        by the data key for encryption and decryption.
      - `**kwargs` (optional): additional keyword arguments are the same as "create_collection".

    All optional `create collection command`_ parameters should be passed
    as keyword arguments to this method.
    See the documentation for :meth:`~pymongo.database.Database.create_collection` for all valid options.

    :Returns:
      A ``(collection, encrypted_fields)`` tuple: the created collection and
      the copy of `encrypted_fields` with the generated key ids filled in.

    :Raises:
      - :class:`~pymongo.errors.EncryptionError`: if creating a data key or
        creating the collection fails. The message includes the
        encryptedFields document as populated so far, so keys that were
        already created are not lost to the caller.

    .. versionadded:: 4.4

    .. _create collection command:
        https://mongodb.com/docs/manual/reference/command/create

    """
    # Work on a private copy so the caller's document keeps its None keyIds.
    encrypted_fields = deepcopy(encrypted_fields)
    for field in encrypted_fields["fields"]:
        # Only dict entries without a usable keyId get a generated key;
        # anything else is passed through for the server to validate.
        if not isinstance(field, dict) or field.get("keyId") is not None:
            continue
        try:
            field["keyId"] = self.create_data_key(
                kms_provider=kms_provider,  # type:ignore[arg-type]
                master_key=master_key,
                key_alt_names=key_alt_names,
                key_material=key_material,
            )
        except EncryptionError as exc:
            # .get() rather than [] so a field without "path" cannot raise
            # KeyError here and mask the real failure.
            raise EncryptionError(
                Exception(
                    "Error occurred while creating data key for field %s with encryptedFields=%s"
                    % (field.get("path"), encrypted_fields)
                )
            ) from exc
    kwargs["encryptedFields"] = encrypted_fields
    # NOTE(review): presumably disables create_collection's client-side
    # existence pre-check so a name collision surfaces as a server error
    # wrapped below -- confirm against Database.create_collection.
    kwargs["check_exists"] = False
    try:
        return (
            database.create_collection(name=name, **kwargs),
            encrypted_fields,
        )
    except Exception as exc:
        # Deliberately broad: any failure is re-raised with the populated
        # encryptedFields attached to the message.
        raise EncryptionError(
            Exception(
                f"Error: {str(exc)} occurred while creating collection with encryptedFields={str(encrypted_fields)}"
            )
        ) from exc
|
||||
|
||||
def create_data_key(
|
||||
self,
|
||||
kms_provider: str,
|
||||
|
||||
@ -65,7 +65,7 @@ from bson.codec_options import CodecOptions
|
||||
from bson.errors import BSONError
|
||||
from bson.json_util import JSONOptions
|
||||
from bson.son import SON
|
||||
from pymongo import encryption
|
||||
from pymongo import ReadPreference, encryption
|
||||
from pymongo.cursor import CursorType
|
||||
from pymongo.encryption import Algorithm, ClientEncryption, QueryType
|
||||
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts, RangeOpts
|
||||
@ -2687,5 +2687,217 @@ class TestRangeQueryProse(EncryptionIntegrationTest):
|
||||
self.run_test_cases("Int", RangeOpts(min=0, max=200, sparsity=1), int)
|
||||
|
||||
|
||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#automatic-data-encryption-keys
|
||||
class TestAutomaticDecryptionKeys(EncryptionIntegrationTest):
    """Tests for ClientEncryption.create_encrypted_collection.

    Covers automatic data key creation for fields whose "keyId" is None or
    missing, forwarding of create_collection options, and the contents of
    the EncryptionError raised when key or collection creation fails.
    """

    @client_context.require_no_standalone
    @client_context.require_version_min(6, 0, -1)
    def setUp(self):
        super().setUp()
        self.key1_document = json_data("etc", "data", "keys", "key1-document.json")
        self.key1_id = self.key1_document["_id"]
        # Every test starts from an empty database and a fresh key vault.
        self.client.drop_database(self.db)
        self.key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document)
        self.addCleanup(self.key_vault.drop)
        self.client_encryption = ClientEncryption(
            {"local": {"key": LOCAL_MASTER_KEY}},
            self.key_vault.full_name,
            self.client,
            OPTS,
        )
        self.addCleanup(self.client_encryption.close)

    def test_01_simple_create(self):
        # A None keyId is replaced by a generated key; inserting a plaintext
        # value into the encrypted field must then be rejected by the server.
        coll, _ = self.client_encryption.create_encrypted_collection(
            database=self.db,
            name="testing1",
            encrypted_fields={"fields": [{"path": "ssn", "bsonType": "string", "keyId": None}]},
            kms_provider="local",
        )
        with self.assertRaises(WriteError) as exc:
            coll.insert_one({"ssn": "123-45-6789"})
        # 121 is the server's document validation failure code.
        self.assertEqual(exc.exception.code, 121)

    def test_02_no_fields(self):
        # encrypted_fields is a required argument.
        with self.assertRaisesRegex(
            TypeError,
            "create_encrypted_collection.* missing 1 required positional argument: 'encrypted_fields'",
        ):
            self.client_encryption.create_encrypted_collection(  # type:ignore[call-arg]
                database=self.db,
                name="testing1",
            )

    def test_03_invalid_keyid(self):
        # A non-None keyId of the wrong type is sent as-is; the server's
        # rejection is surfaced as an EncryptionError.
        with self.assertRaisesRegex(
            EncryptionError,
            "create.encryptedFields.fields.keyId' is the wrong type 'bool', expected type 'binData",
        ):
            self.client_encryption.create_encrypted_collection(
                database=self.db,
                name="testing1",
                encrypted_fields={
                    "fields": [{"path": "ssn", "bsonType": "string", "keyId": False}]
                },
                kms_provider="local",
            )

    def test_04_insert_encrypted(self):
        # The generated key id returned in encrypted_fields can be used to
        # explicitly encrypt a value that the collection then accepts.
        coll, ef = self.client_encryption.create_encrypted_collection(
            database=self.db,
            name="testing1",
            encrypted_fields={"fields": [{"path": "ssn", "bsonType": "string", "keyId": None}]},
            kms_provider="local",
        )
        key1_id = ef["fields"][0]["keyId"]
        encrypted_value = self.client_encryption.encrypt(
            "123-45-6789",
            key_id=key1_id,
            algorithm=Algorithm.UNINDEXED,
        )
        coll.insert_one({"ssn": encrypted_value})

    def test_copy_encrypted_fields(self):
        # The caller's encrypted_fields must not be mutated; only the
        # returned copy carries the generated key id.
        encrypted_fields = {
            "fields": [
                {
                    "path": "ssn",
                    "bsonType": "string",
                    "keyId": None,
                }
            ]
        }
        _, ef = self.client_encryption.create_encrypted_collection(
            database=self.db,
            name="testing1",
            kms_provider="local",
            encrypted_fields=encrypted_fields,
        )
        self.assertIsNotNone(ef["fields"][0]["keyId"])
        self.assertIsNone(encrypted_fields["fields"][0]["keyId"])

    def test_options_forward(self):
        # Extra keyword arguments are forwarded to create_collection.
        coll, _ = self.client_encryption.create_encrypted_collection(
            database=self.db,
            name="testing1",
            kms_provider="local",
            encrypted_fields={"fields": [{"path": "ssn", "bsonType": "string", "keyId": None}]},
            read_preference=ReadPreference.NEAREST,
        )
        self.assertEqual(coll.read_preference, ReadPreference.NEAREST)
        self.assertEqual(coll.name, "testing1")

    def test_mixed_null_keyids(self):
        # Mix of explicit None, a pre-existing key, and a missing keyId:
        # every field must end up with a usable key.
        key = self.client_encryption.create_data_key(kms_provider="local")
        coll, ef = self.client_encryption.create_encrypted_collection(
            database=self.db,
            name="testing1",
            encrypted_fields={
                "fields": [
                    {"path": "ssn", "bsonType": "string", "keyId": None},
                    {"path": "dob", "bsonType": "string", "keyId": key},
                    {"path": "secrets", "bsonType": "string"},
                    {"path": "address", "bsonType": "string", "keyId": None},
                ]
            },
            kms_provider="local",
        )
        plaintexts = ["123-45-6789", "11/22/1963", "My secret", "New Mexico, 87104"]
        field_key_ids = [field["keyId"] for field in ef["fields"]]
        # Distinct loop names so the outer ``key`` is not shadowed.
        encrypted_values = [
            self.client_encryption.encrypt(
                plaintext,
                key_id=field_key_id,
                algorithm=Algorithm.UNINDEXED,
            )
            for plaintext, field_key_id in zip(plaintexts, field_key_ids)
        ]
        coll.insert_one(
            {
                "ssn": encrypted_values[0],
                "dob": encrypted_values[1],
                "secrets": encrypted_values[2],
                "address": encrypted_values[3],
            }
        )

    def test_create_datakey_fails(self):
        key = self.client_encryption.create_data_key(kms_provider="local")
        # The error message must include the keys created so far, even when
        # generating a later key fails.
        with self.assertRaisesRegex(
            EncryptionError,
            f"data key for field ssn with encryptedFields=.*{re.escape(repr(key))}.*keyId.*Binary.*keyId.*None",
        ):
            self.client_encryption.create_encrypted_collection(
                database=self.db,
                name="testing1",
                encrypted_fields={
                    "fields": [
                        {"path": "address", "bsonType": "string", "keyId": key},
                        {"path": "dob", "bsonType": "string", "keyId": None},
                        # Because this is the second one to use the altName "1",
                        # it will fail when creating the data_key.
                        {"path": "ssn", "bsonType": "string", "keyId": None},
                    ]
                },
                kms_provider="local",
                key_alt_names=["1"],
            )

    def test_create_failure(self):
        key = self.client_encryption.create_data_key(kms_provider="local")
        # The error message must include the keys created so far, even when
        # it is the creation of the collection that fails.
        with self.assertRaisesRegex(
            EncryptionError,
            f"while creating collection with encryptedFields=.*{re.escape(repr(key))}.*keyId.*Binary",
        ):
            self.client_encryption.create_encrypted_collection(
                database=self.db,
                name=1,  # type:ignore[arg-type]
                encrypted_fields={
                    "fields": [
                        {"path": "address", "bsonType": "string", "keyId": key},
                        {"path": "dob", "bsonType": "string", "keyId": None},
                    ]
                },
                kms_provider="local",
            )

    def test_collection_name_collision(self):
        encrypted_fields = {
            "fields": [
                {"path": "address", "bsonType": "string", "keyId": None},
            ]
        }
        # Colliding with an existing plain collection must fail.
        self.db.create_collection("testing1")
        with self.assertRaisesRegex(
            EncryptionError,
            "while creating collection with encryptedFields=.*keyId.*Binary",
        ):
            self.client_encryption.create_encrypted_collection(
                database=self.db,
                name="testing1",
                encrypted_fields=encrypted_fields,
                kms_provider="local",
            )
        # After dropping it, creation succeeds ...
        self.db.drop_collection("testing1", encrypted_fields=encrypted_fields)
        self.client_encryption.create_encrypted_collection(
            database=self.db,
            name="testing1",
            encrypted_fields=encrypted_fields,
            kms_provider="local",
        )
        # ... and a second attempt collides with the encrypted collection.
        with self.assertRaisesRegex(
            EncryptionError,
            "while creating collection with encryptedFields=.*keyId.*Binary",
        ):
            self.client_encryption.create_encrypted_collection(
                database=self.db,
                name="testing1",
                encrypted_fields=encrypted_fields,
                kms_provider="local",
            )
|
||||
|
||||
|
||||
# Run the test suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user