201 lines
7.9 KiB
Python
201 lines
7.9 KiB
Python
# Copyright 2021-present MongoDB, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Test Explicit Encryption with AsyncIOMotorClient."""
|
|
|
|
import unittest
|
|
import uuid
|
|
from test import env
|
|
from test.asyncio_tests import AsyncIOTestCase, asyncio_test
|
|
|
|
from bson.binary import JAVA_LEGACY, STANDARD, UUID_SUBTYPE, Binary
|
|
from bson.codec_options import CodecOptions
|
|
from bson.errors import BSONError
|
|
from pymongo.encryption import Algorithm
|
|
from pymongo.errors import InvalidOperation
|
|
|
|
from motor.motor_asyncio import AsyncIOMotorClientEncryption
|
|
|
|
KMS_PROVIDERS = {"local": {"key": b"\x00" * 96}}
|
|
|
|
OPTS = CodecOptions(uuid_representation=STANDARD)
|
|
|
|
|
|
class TestExplicitSimple(AsyncIOTestCase):
|
|
@env.require_csfle
|
|
def setUp(self):
|
|
super().setUp()
|
|
|
|
def assertEncrypted(self, val):
|
|
self.assertIsInstance(val, Binary)
|
|
self.assertEqual(val.subtype, 6)
|
|
|
|
def assertBinaryUUID(self, val):
|
|
self.assertIsInstance(val, Binary)
|
|
self.assertEqual(val.subtype, UUID_SUBTYPE)
|
|
|
|
@asyncio_test
|
|
async def test_encrypt_decrypt(self):
|
|
client = self.asyncio_client()
|
|
client_encryption = AsyncIOMotorClientEncryption(
|
|
KMS_PROVIDERS, "keyvault.datakeys", client, OPTS
|
|
)
|
|
# Use standard UUID representation.
|
|
key_vault = client.keyvault.get_collection("datakeys", codec_options=OPTS)
|
|
|
|
# Create the encrypted field's data key.
|
|
key_id = await client_encryption.create_data_key("local", key_alt_names=["name"])
|
|
self.assertBinaryUUID(key_id)
|
|
self.assertTrue(await key_vault.find_one({"_id": key_id}))
|
|
|
|
# Create an unused data key to make sure filtering works.
|
|
unused_key_id = await client_encryption.create_data_key("local", key_alt_names=["unused"])
|
|
self.assertBinaryUUID(unused_key_id)
|
|
self.assertTrue(await key_vault.find_one({"_id": unused_key_id}))
|
|
|
|
doc = {"_id": 0, "ssn": "000"}
|
|
encrypted_ssn = await client_encryption.encrypt(
|
|
doc["ssn"], Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id
|
|
)
|
|
|
|
# Ensure encryption via key_alt_name for the same key produces the
|
|
# same output.
|
|
encrypted_ssn2 = await client_encryption.encrypt(
|
|
doc["ssn"], Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="name"
|
|
)
|
|
self.assertEqual(encrypted_ssn, encrypted_ssn2)
|
|
|
|
# Test decryption.
|
|
decrypted_ssn = await client_encryption.decrypt(encrypted_ssn)
|
|
self.assertEqual(decrypted_ssn, doc["ssn"])
|
|
|
|
await key_vault.drop()
|
|
await client_encryption.close()
|
|
|
|
@asyncio_test
|
|
async def test_validation(self):
|
|
client = self.asyncio_client()
|
|
client_encryption = AsyncIOMotorClientEncryption(
|
|
KMS_PROVIDERS, "keyvault.datakeys", client, OPTS
|
|
)
|
|
|
|
msg = "value to decrypt must be a bson.binary.Binary with subtype 6"
|
|
with self.assertRaisesRegex(TypeError, msg):
|
|
await client_encryption.decrypt("str")
|
|
with self.assertRaisesRegex(TypeError, msg):
|
|
await client_encryption.decrypt(Binary(b"123"))
|
|
|
|
msg = "key_id must be a bson.binary.Binary with subtype 4"
|
|
algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic
|
|
with self.assertRaisesRegex(TypeError, msg):
|
|
await client_encryption.encrypt("str", algo, key_id="str")
|
|
with self.assertRaisesRegex(TypeError, msg):
|
|
await client_encryption.encrypt("str", algo, key_id=Binary(b"123"))
|
|
|
|
await client_encryption.close()
|
|
|
|
@asyncio_test
|
|
async def test_bson_errors(self):
|
|
client = self.asyncio_client()
|
|
client_encryption = AsyncIOMotorClientEncryption(
|
|
KMS_PROVIDERS, "keyvault.datakeys", client, OPTS
|
|
)
|
|
|
|
# Attempt to encrypt an unencodable object.
|
|
unencodable_value = object()
|
|
with self.assertRaises(BSONError):
|
|
await client_encryption.encrypt(
|
|
unencodable_value,
|
|
Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic,
|
|
key_id=Binary(uuid.uuid4().bytes, UUID_SUBTYPE),
|
|
)
|
|
|
|
await client_encryption.close()
|
|
|
|
@asyncio_test
|
|
async def test_codec_options(self):
|
|
client = self.asyncio_client()
|
|
with self.assertRaisesRegex(TypeError, "codec_options must be"):
|
|
AsyncIOMotorClientEncryption(KMS_PROVIDERS, "keyvault.datakeys", client, None)
|
|
|
|
opts = CodecOptions(uuid_representation=JAVA_LEGACY)
|
|
client_encryption_legacy = AsyncIOMotorClientEncryption(
|
|
KMS_PROVIDERS, "keyvault.datakeys", client, opts
|
|
)
|
|
|
|
# Create the encrypted field's data key.
|
|
key_id = await client_encryption_legacy.create_data_key("local")
|
|
|
|
# Encrypt a UUID with JAVA_LEGACY codec options.
|
|
value = uuid.uuid4()
|
|
encrypted_legacy = await client_encryption_legacy.encrypt(
|
|
value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id
|
|
)
|
|
decrypted_value_legacy = await client_encryption_legacy.decrypt(encrypted_legacy)
|
|
self.assertEqual(decrypted_value_legacy, value)
|
|
|
|
# Encrypt the same UUID with STANDARD codec options.
|
|
client_encryption = AsyncIOMotorClientEncryption(
|
|
KMS_PROVIDERS, "keyvault.datakeys", client, OPTS
|
|
)
|
|
encrypted_standard = await client_encryption.encrypt(
|
|
value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id
|
|
)
|
|
decrypted_standard = await client_encryption.decrypt(encrypted_standard)
|
|
self.assertEqual(decrypted_standard, value)
|
|
|
|
# Test that codec_options is applied during encryption.
|
|
self.assertNotEqual(encrypted_standard, encrypted_legacy)
|
|
# Test that codec_options is applied during decryption.
|
|
self.assertEqual(
|
|
await client_encryption_legacy.decrypt(encrypted_standard),
|
|
Binary.from_uuid(value, uuid_representation=STANDARD),
|
|
)
|
|
self.assertNotEqual(await client_encryption.decrypt(encrypted_legacy), value)
|
|
|
|
await client_encryption_legacy.close()
|
|
await client_encryption.close()
|
|
|
|
@asyncio_test
|
|
async def test_close(self):
|
|
client = self.asyncio_client()
|
|
client_encryption = AsyncIOMotorClientEncryption(
|
|
KMS_PROVIDERS, "keyvault.datakeys", client, OPTS
|
|
)
|
|
await client_encryption.close()
|
|
# Close can be called multiple times.
|
|
await client_encryption.close()
|
|
algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic
|
|
msg = "Cannot use closed ClientEncryption"
|
|
with self.assertRaisesRegex(InvalidOperation, msg):
|
|
await client_encryption.create_data_key("local")
|
|
with self.assertRaisesRegex(InvalidOperation, msg):
|
|
await client_encryption.encrypt("val", algo, key_alt_name="name")
|
|
with self.assertRaisesRegex(InvalidOperation, msg):
|
|
await client_encryption.decrypt(Binary(b"", 6))
|
|
|
|
@asyncio_test
|
|
async def test_with_statement(self):
|
|
client = self.asyncio_client()
|
|
async with AsyncIOMotorClientEncryption(
|
|
KMS_PROVIDERS, "keyvault.datakeys", client, OPTS
|
|
) as client_encryption:
|
|
pass
|
|
with self.assertRaisesRegex(InvalidOperation, "Cannot use closed ClientEncryption"):
|
|
await client_encryption.create_data_key("local")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|