PYTHON-4673 - Add Async Encryption Tests (#1818)
This commit is contained in:
parent
a2059dc9cb
commit
e6b95f6595
@ -50,7 +50,7 @@ try:
|
||||
_HAVE_PYMONGOCRYPT = True
|
||||
except ImportError:
|
||||
_HAVE_PYMONGOCRYPT = False
|
||||
MongoCryptCallback = object
|
||||
AsyncMongoCryptCallback = object
|
||||
|
||||
from bson import _dict_to_bson, decode, encode
|
||||
from bson.binary import STANDARD, UUID_SUBTYPE, Binary
|
||||
@ -207,10 +207,10 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
|
||||
:return: The first document from the listCollections command response as BSON.
|
||||
"""
|
||||
async with self.client_ref()[database].list_collections(
|
||||
async with await self.client_ref()[database].list_collections(
|
||||
filter=RawBSONDocument(filter)
|
||||
) as cursor:
|
||||
for doc in cursor:
|
||||
async for doc in cursor:
|
||||
return _dict_to_bson(doc, False, _DATA_KEY_OPTS)
|
||||
return None
|
||||
|
||||
|
||||
@ -297,7 +297,7 @@ async def command(
|
||||
)
|
||||
|
||||
if client and client._encrypter and reply:
|
||||
decrypted = client._encrypter.decrypt(reply.raw_command_response())
|
||||
decrypted = await client._encrypter.decrypt(reply.raw_command_response())
|
||||
response_doc = cast(
|
||||
"_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0]
|
||||
)
|
||||
|
||||
@ -309,7 +309,7 @@ class Server:
|
||||
client = operation.client # type: ignore[assignment]
|
||||
if client and client._encrypter:
|
||||
if use_cmd:
|
||||
decrypted = client._encrypter.decrypt(reply.raw_command_response())
|
||||
decrypted = await client._encrypter.decrypt(reply.raw_command_response())
|
||||
docs = _decode_all_selective(decrypted, operation.codec_options, user_fields)
|
||||
|
||||
response: Response
|
||||
|
||||
@ -1538,7 +1538,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
if not _IS_SYNC:
|
||||
# Add support for contextlib.closing.
|
||||
aclose = close
|
||||
close = close
|
||||
|
||||
def _get_topology(self) -> Topology:
|
||||
"""Get the internal :class:`~pymongo.topology.Topology` object.
|
||||
|
||||
3155
test/asynchronous/test_encryption.py
Normal file
3155
test/asynchronous/test_encryption.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
import base64
|
||||
import copy
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import socket
|
||||
import socketserver
|
||||
@ -27,6 +28,7 @@ import textwrap
|
||||
import traceback
|
||||
import uuid
|
||||
import warnings
|
||||
from test import IntegrationTest, PyMongoTestCase, client_context
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Mapping
|
||||
|
||||
@ -34,13 +36,11 @@ import pytest
|
||||
|
||||
from pymongo.daemon import _spawn_daemon
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.helpers import next
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import (
|
||||
IntegrationTest,
|
||||
PyMongoTestCase,
|
||||
client_context,
|
||||
unittest,
|
||||
)
|
||||
from test.helpers import (
|
||||
@ -93,6 +93,8 @@ from pymongo.synchronous.encryption import Algorithm, ClientEncryption, QueryTyp
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
pytestmark = pytest.mark.encryption
|
||||
|
||||
KMS_PROVIDERS = {"local": {"key": b"\x00" * 96}}
|
||||
@ -216,8 +218,8 @@ class EncryptionIntegrationTest(IntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
@client_context.require_version_min(4, 2, -1)
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
def assertEncrypted(self, val):
|
||||
self.assertIsInstance(val, Binary)
|
||||
@ -229,7 +231,11 @@ class EncryptionIntegrationTest(IntegrationTest):
|
||||
|
||||
|
||||
# Location of JSON test files.
|
||||
BASE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "client-side-encryption")
|
||||
if _IS_SYNC:
|
||||
BASE = os.path.join(pathlib.Path(__file__).resolve().parent, "client-side-encryption")
|
||||
else:
|
||||
BASE = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "client-side-encryption")
|
||||
|
||||
SPEC_PATH = os.path.join(BASE, "spec")
|
||||
|
||||
OPTS = CodecOptions()
|
||||
@ -278,9 +284,11 @@ class TestClientSimple(EncryptionIntegrationTest):
|
||||
unack = encrypted_coll.with_options(write_concern=WriteConcern(w=0))
|
||||
unack.insert_one(docs[3])
|
||||
unack.insert_many(docs[4:], ordered=False)
|
||||
wait_until(
|
||||
lambda: self.db.test.count_documents({}) == len(docs), "insert documents with w=0"
|
||||
)
|
||||
|
||||
def count_documents():
|
||||
return self.db.test.count_documents({}) == len(docs)
|
||||
|
||||
wait_until(count_documents, "insert documents with w=0")
|
||||
|
||||
# Database.command auto decrypts.
|
||||
res = client.pymongo_test.command("find", "test", filter={"ssn": "000"})
|
||||
@ -288,19 +296,19 @@ class TestClientSimple(EncryptionIntegrationTest):
|
||||
self.assertEqual(decrypted_docs, [{"_id": 0, "ssn": "000"}])
|
||||
|
||||
# Collection.find auto decrypts.
|
||||
decrypted_docs = list(encrypted_coll.find())
|
||||
decrypted_docs = encrypted_coll.find().to_list()
|
||||
self.assertEqual(decrypted_docs, docs)
|
||||
|
||||
# Collection.find auto decrypts getMores.
|
||||
decrypted_docs = list(encrypted_coll.find(batch_size=1))
|
||||
decrypted_docs = encrypted_coll.find(batch_size=1).to_list()
|
||||
self.assertEqual(decrypted_docs, docs)
|
||||
|
||||
# Collection.aggregate auto decrypts.
|
||||
decrypted_docs = list(encrypted_coll.aggregate([]))
|
||||
decrypted_docs = (encrypted_coll.aggregate([])).to_list()
|
||||
self.assertEqual(decrypted_docs, docs)
|
||||
|
||||
# Collection.aggregate auto decrypts getMores.
|
||||
decrypted_docs = list(encrypted_coll.aggregate([], batchSize=1))
|
||||
decrypted_docs = (encrypted_coll.aggregate([], batchSize=1)).to_list()
|
||||
self.assertEqual(decrypted_docs, docs)
|
||||
|
||||
# Collection.distinct auto decrypts.
|
||||
@ -402,8 +410,8 @@ class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest):
|
||||
class TestClientMaxWireVersion(IntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
@client_context.require_version_max(4, 0, 99)
|
||||
def test_raise_max_wire_version_error(self):
|
||||
@ -601,131 +609,130 @@ AWS_TEMP_NO_SESSION_CREDS = {
|
||||
KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}
|
||||
|
||||
|
||||
class TestSpec(SpecRunner):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
if _IS_SYNC:
|
||||
# TODO: Add synchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700)
|
||||
class TestSpec(SpecRunner):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
def parse_auto_encrypt_opts(self, opts):
|
||||
"""Parse clientOptions.autoEncryptOpts."""
|
||||
opts = camel_to_snake_args(opts)
|
||||
kms_providers = opts["kms_providers"]
|
||||
if "aws" in kms_providers:
|
||||
kms_providers["aws"] = AWS_CREDS
|
||||
if not any(AWS_CREDS.values()):
|
||||
self.skipTest("AWS environment credentials are not set")
|
||||
if "awsTemporary" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_CREDS
|
||||
del kms_providers["awsTemporary"]
|
||||
if not any(AWS_TEMP_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "awsTemporaryNoSessionToken" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
|
||||
del kms_providers["awsTemporaryNoSessionToken"]
|
||||
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "azure" in kms_providers:
|
||||
kms_providers["azure"] = AZURE_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("Azure environment credentials are not set")
|
||||
if "gcp" in kms_providers:
|
||||
kms_providers["gcp"] = GCP_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("GCP environment credentials are not set")
|
||||
if "kmip" in kms_providers:
|
||||
kms_providers["kmip"] = KMIP_CREDS
|
||||
opts["kms_tls_options"] = KMS_TLS_OPTS
|
||||
if "key_vault_namespace" not in opts:
|
||||
opts["key_vault_namespace"] = "keyvault.datakeys"
|
||||
if "extra_options" in opts:
|
||||
opts.update(camel_to_snake_args(opts.pop("extra_options")))
|
||||
def parse_auto_encrypt_opts(self, opts):
|
||||
"""Parse clientOptions.autoEncryptOpts."""
|
||||
opts = camel_to_snake_args(opts)
|
||||
kms_providers = opts["kms_providers"]
|
||||
if "aws" in kms_providers:
|
||||
kms_providers["aws"] = AWS_CREDS
|
||||
if not any(AWS_CREDS.values()):
|
||||
self.skipTest("AWS environment credentials are not set")
|
||||
if "awsTemporary" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_CREDS
|
||||
del kms_providers["awsTemporary"]
|
||||
if not any(AWS_TEMP_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "awsTemporaryNoSessionToken" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
|
||||
del kms_providers["awsTemporaryNoSessionToken"]
|
||||
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "azure" in kms_providers:
|
||||
kms_providers["azure"] = AZURE_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("Azure environment credentials are not set")
|
||||
if "gcp" in kms_providers:
|
||||
kms_providers["gcp"] = GCP_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("GCP environment credentials are not set")
|
||||
if "kmip" in kms_providers:
|
||||
kms_providers["kmip"] = KMIP_CREDS
|
||||
opts["kms_tls_options"] = KMS_TLS_OPTS
|
||||
if "key_vault_namespace" not in opts:
|
||||
opts["key_vault_namespace"] = "keyvault.datakeys"
|
||||
if "extra_options" in opts:
|
||||
opts.update(camel_to_snake_args(opts.pop("extra_options")))
|
||||
|
||||
opts = dict(opts)
|
||||
return AutoEncryptionOpts(**opts)
|
||||
opts = dict(opts)
|
||||
return AutoEncryptionOpts(**opts)
|
||||
|
||||
def parse_client_options(self, opts):
|
||||
"""Override clientOptions parsing to support autoEncryptOpts."""
|
||||
encrypt_opts = opts.pop("autoEncryptOpts", None)
|
||||
if encrypt_opts:
|
||||
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
|
||||
def parse_client_options(self, opts):
|
||||
"""Override clientOptions parsing to support autoEncryptOpts."""
|
||||
encrypt_opts = opts.pop("autoEncryptOpts", None)
|
||||
if encrypt_opts:
|
||||
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
|
||||
|
||||
return super().parse_client_options(opts)
|
||||
return super().parse_client_options(opts)
|
||||
|
||||
def get_object_name(self, op):
|
||||
"""Default object is collection."""
|
||||
return op.get("object", "collection")
|
||||
def get_object_name(self, op):
|
||||
"""Default object is collection."""
|
||||
return op.get("object", "collection")
|
||||
|
||||
def maybe_skip_scenario(self, test):
|
||||
super().maybe_skip_scenario(test)
|
||||
desc = test["description"].lower()
|
||||
if (
|
||||
"timeoutms applied to listcollections to get collection schema" in desc
|
||||
and sys.platform in ("win32", "darwin")
|
||||
):
|
||||
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
|
||||
if "type=symbol" in desc:
|
||||
self.skipTest("PyMongo does not support the symbol type")
|
||||
def maybe_skip_scenario(self, test):
|
||||
super().maybe_skip_scenario(test)
|
||||
desc = test["description"].lower()
|
||||
if (
|
||||
"timeoutms applied to listcollections to get collection schema" in desc
|
||||
and sys.platform in ("win32", "darwin")
|
||||
):
|
||||
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
|
||||
if "type=symbol" in desc:
|
||||
self.skipTest("PyMongo does not support the symbol type")
|
||||
|
||||
def setup_scenario(self, scenario_def):
|
||||
"""Override a test's setup."""
|
||||
key_vault_data = scenario_def["key_vault_data"]
|
||||
encrypted_fields = scenario_def["encrypted_fields"]
|
||||
json_schema = scenario_def["json_schema"]
|
||||
data = scenario_def["data"]
|
||||
coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
|
||||
coll.delete_many({})
|
||||
if key_vault_data:
|
||||
coll.insert_many(key_vault_data)
|
||||
def setup_scenario(self, scenario_def):
|
||||
"""Override a test's setup."""
|
||||
key_vault_data = scenario_def["key_vault_data"]
|
||||
encrypted_fields = scenario_def["encrypted_fields"]
|
||||
json_schema = scenario_def["json_schema"]
|
||||
data = scenario_def["data"]
|
||||
coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
|
||||
coll.delete_many({})
|
||||
if key_vault_data:
|
||||
coll.insert_many(key_vault_data)
|
||||
|
||||
db_name = self.get_scenario_db_name(scenario_def)
|
||||
coll_name = self.get_scenario_coll_name(scenario_def)
|
||||
db = client_context.client.get_database(db_name, codec_options=OPTS)
|
||||
coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
|
||||
wc = WriteConcern(w="majority")
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if json_schema:
|
||||
kwargs["validator"] = {"$jsonSchema": json_schema}
|
||||
kwargs["codec_options"] = OPTS
|
||||
if not data:
|
||||
kwargs["write_concern"] = wc
|
||||
if encrypted_fields:
|
||||
kwargs["encryptedFields"] = encrypted_fields
|
||||
db.create_collection(coll_name, **kwargs)
|
||||
coll = db[coll_name]
|
||||
if data:
|
||||
# Load data.
|
||||
coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
|
||||
db_name = self.get_scenario_db_name(scenario_def)
|
||||
coll_name = self.get_scenario_coll_name(scenario_def)
|
||||
db = client_context.client.get_database(db_name, codec_options=OPTS)
|
||||
coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
|
||||
wc = WriteConcern(w="majority")
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if json_schema:
|
||||
kwargs["validator"] = {"$jsonSchema": json_schema}
|
||||
kwargs["codec_options"] = OPTS
|
||||
if not data:
|
||||
kwargs["write_concern"] = wc
|
||||
if encrypted_fields:
|
||||
kwargs["encryptedFields"] = encrypted_fields
|
||||
db.create_collection(coll_name, **kwargs)
|
||||
coll = db[coll_name]
|
||||
if data:
|
||||
# Load data.
|
||||
coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
|
||||
|
||||
def allowable_errors(self, op):
|
||||
"""Override expected error classes."""
|
||||
errors = super().allowable_errors(op)
|
||||
# An updateOne test expects encryption to error when no $ operator
|
||||
# appears but pymongo raises a client side ValueError in this case.
|
||||
if op["name"] == "updateOne":
|
||||
errors += (ValueError,)
|
||||
return errors
|
||||
def allowable_errors(self, op):
|
||||
"""Override expected error classes."""
|
||||
errors = super().allowable_errors(op)
|
||||
# An updateOne test expects encryption to error when no $ operator
|
||||
# appears but pymongo raises a client side ValueError in this case.
|
||||
if op["name"] == "updateOne":
|
||||
errors += (ValueError,)
|
||||
return errors
|
||||
|
||||
def create_test(scenario_def, test, name):
|
||||
@client_context.require_test_commands
|
||||
def run_scenario(self):
|
||||
self.run_scenario(scenario_def, test)
|
||||
|
||||
def create_test(scenario_def, test, name):
|
||||
@client_context.require_test_commands
|
||||
def run_scenario(self):
|
||||
self.run_scenario(scenario_def, test)
|
||||
return run_scenario
|
||||
|
||||
return run_scenario
|
||||
test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
|
||||
test_creator.create_tests()
|
||||
|
||||
|
||||
test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
|
||||
test_creator.create_tests()
|
||||
|
||||
|
||||
if _HAVE_PYMONGOCRYPT:
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(SPEC_PATH, "unified"),
|
||||
module=__name__,
|
||||
if _HAVE_PYMONGOCRYPT:
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(SPEC_PATH, "unified"),
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Prose Tests
|
||||
ALL_KMS_PROVIDERS = {
|
||||
@ -797,8 +804,8 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
|
||||
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
||||
"No environment credentials are set",
|
||||
)
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.client.db.coll.drop()
|
||||
@ -830,7 +837,7 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
def _tearDown_class(cls):
|
||||
cls.vault.drop()
|
||||
cls.client.close()
|
||||
cls.client_encrypted.close()
|
||||
@ -849,7 +856,7 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
|
||||
cmd = self.listener.started_events[-1]
|
||||
self.assertEqual("insert", cmd.command_name)
|
||||
self.assertEqual({"w": "majority"}, cmd.command.get("writeConcern"))
|
||||
docs = list(self.vault.find({"_id": datakey_id}))
|
||||
docs = self.vault.find({"_id": datakey_id}).to_list()
|
||||
self.assertEqual(len(docs), 1)
|
||||
self.assertEqual(docs[0]["masterKey"]["provider"], provider_name)
|
||||
|
||||
@ -989,8 +996,8 @@ class TestViews(EncryptionIntegrationTest):
|
||||
class TestCorpus(EncryptionIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
@staticmethod
|
||||
def kms_providers():
|
||||
@ -1167,8 +1174,8 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
|
||||
listener: OvertCommandListener
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
db = client_context.client.db
|
||||
cls.coll = db.coll
|
||||
cls.coll.drop()
|
||||
@ -1196,10 +1203,10 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
|
||||
cls.coll_encrypted = cls.client_encrypted.db.coll
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
def _tearDown_class(cls):
|
||||
cls.coll_encrypted.drop()
|
||||
cls.client_encrypted.close()
|
||||
super().tearDownClass()
|
||||
super()._tearDown_class()
|
||||
|
||||
def test_01_insert_succeeds_under_2MiB(self):
|
||||
doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB}
|
||||
@ -1268,8 +1275,8 @@ class TestCustomEndpoint(EncryptionIntegrationTest):
|
||||
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
||||
"No environment credentials are set",
|
||||
)
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
def setUp(self):
|
||||
kms_providers = {
|
||||
@ -1537,11 +1544,11 @@ class AzureGCPEncryptionTestMixin:
|
||||
class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set")
|
||||
def setUpClass(cls):
|
||||
def _setup_class(cls):
|
||||
cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
|
||||
cls.DEK = json_data(BASE, "custom", "azure-dek.json")
|
||||
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
super().setUpClass()
|
||||
super()._setup_class()
|
||||
|
||||
def test_explicit(self):
|
||||
return self._test_explicit(
|
||||
@ -1563,11 +1570,11 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest
|
||||
class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set")
|
||||
def setUpClass(cls):
|
||||
def _setup_class(cls):
|
||||
cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
|
||||
cls.DEK = json_data(BASE, "custom", "gcp-dek.json")
|
||||
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
super().setUpClass()
|
||||
super()._setup_class()
|
||||
|
||||
def test_explicit(self):
|
||||
return self._test_explicit(
|
||||
@ -1944,7 +1951,8 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest):
|
||||
@unittest.skipUnless(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is not installed")
|
||||
def test_via_loading_shared_library(self):
|
||||
create_key_vault(
|
||||
client_context.client.keyvault.datakeys, json_data("external", "external-key.json")
|
||||
client_context.client.keyvault.datakeys,
|
||||
json_data("external", "external-key.json"),
|
||||
)
|
||||
schemas = {"db.coll": json_data("external", "external-schema.json")}
|
||||
opts = AutoEncryptionOpts(
|
||||
@ -1962,7 +1970,7 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest):
|
||||
self.addCleanup(client_encrypted.close)
|
||||
client_encrypted.db.coll.drop()
|
||||
client_encrypted.db.coll.insert_one({"encrypted": "test"})
|
||||
self.assertEncrypted(client_context.client.db.coll.find_one({})["encrypted"])
|
||||
self.assertEncrypted((client_context.client.db.coll.find_one({}))["encrypted"])
|
||||
no_mongocryptd_client = MongoClient(
|
||||
host="mongodb://localhost:47021/db?serverSelectionTimeoutMS=1000"
|
||||
)
|
||||
@ -1989,7 +1997,8 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest):
|
||||
listener_t = Thread(target=listener)
|
||||
listener_t.start()
|
||||
create_key_vault(
|
||||
client_context.client.keyvault.datakeys, json_data("external", "external-key.json")
|
||||
client_context.client.keyvault.datakeys,
|
||||
json_data("external", "external-key.json"),
|
||||
)
|
||||
schemas = {"db.coll": json_data("external", "external-schema.json")}
|
||||
opts = AutoEncryptionOpts(
|
||||
@ -2326,11 +2335,12 @@ class TestExplicitQueryableEncryption(EncryptionIntegrationTest):
|
||||
find_payload = self.client_encryption.encrypt(
|
||||
val, Algorithm.INDEXED, self.key1_id, query_type=QueryType.EQUALITY, contention_factor=0
|
||||
)
|
||||
docs = list(
|
||||
self.encrypted_client[self.db.name].explicit_encryption.find(
|
||||
{"encryptedIndexed": find_payload}
|
||||
)
|
||||
docs = (
|
||||
self.encrypted_client[self.db.name]
|
||||
.explicit_encryption.find({"encryptedIndexed": find_payload})
|
||||
.to_list()
|
||||
)
|
||||
|
||||
self.assertEqual(len(docs), 1)
|
||||
self.assertEqual(docs[0]["encryptedIndexed"], val)
|
||||
|
||||
@ -2348,11 +2358,12 @@ class TestExplicitQueryableEncryption(EncryptionIntegrationTest):
|
||||
find_payload = self.client_encryption.encrypt(
|
||||
val, Algorithm.INDEXED, self.key1_id, query_type=QueryType.EQUALITY, contention_factor=0
|
||||
)
|
||||
docs = list(
|
||||
self.encrypted_client[self.db.name].explicit_encryption.find(
|
||||
{"encryptedIndexed": find_payload}
|
||||
)
|
||||
docs = (
|
||||
self.encrypted_client[self.db.name]
|
||||
.explicit_encryption.find({"encryptedIndexed": find_payload})
|
||||
.to_list()
|
||||
)
|
||||
|
||||
self.assertLessEqual(len(docs), 10)
|
||||
for doc in docs:
|
||||
self.assertEqual(doc["encryptedIndexed"], val)
|
||||
@ -2365,11 +2376,12 @@ class TestExplicitQueryableEncryption(EncryptionIntegrationTest):
|
||||
query_type=QueryType.EQUALITY,
|
||||
contention_factor=contention,
|
||||
)
|
||||
docs = list(
|
||||
self.encrypted_client[self.db.name].explicit_encryption.find(
|
||||
{"encryptedIndexed": find_payload}
|
||||
)
|
||||
docs = (
|
||||
self.encrypted_client[self.db.name]
|
||||
.explicit_encryption.find({"encryptedIndexed": find_payload})
|
||||
.to_list()
|
||||
)
|
||||
|
||||
self.assertEqual(len(docs), 10)
|
||||
for doc in docs:
|
||||
self.assertEqual(doc["encryptedIndexed"], val)
|
||||
@ -2381,7 +2393,7 @@ class TestExplicitQueryableEncryption(EncryptionIntegrationTest):
|
||||
{"_id": 1, "encryptedUnindexed": insert_payload}
|
||||
)
|
||||
|
||||
docs = list(self.encrypted_client[self.db.name].explicit_encryption.find({"_id": 1}))
|
||||
docs = self.encrypted_client[self.db.name].explicit_encryption.find({"_id": 1}).to_list()
|
||||
self.assertEqual(len(docs), 1)
|
||||
self.assertEqual(docs[0]["encryptedUnindexed"], val)
|
||||
|
||||
@ -2461,7 +2473,7 @@ class TestRewrapWithSeparateClientEncryption(EncryptionIntegrationTest):
|
||||
kms_tls_options=KMS_TLS_OPTS,
|
||||
codec_options=OPTS,
|
||||
)
|
||||
self.addCleanup(client_encryption1.close)
|
||||
self.addCleanup(client_encryption2.close)
|
||||
|
||||
# Step 6. Call ``client_encryption2.rewrap_many_data_key`` with an empty ``filter``.
|
||||
rewrap_many_data_key_result = client_encryption2.rewrap_many_data_key(
|
||||
@ -2647,7 +2659,8 @@ class TestRangeQueryProse(EncryptionIntegrationTest):
|
||||
if use_expr:
|
||||
find_payload = {"$expr": find_payload}
|
||||
sorted_find = sorted(
|
||||
self.encrypted_client.db.explicit_encryption.find(find_payload), key=lambda x: x["_id"]
|
||||
self.encrypted_client.db.explicit_encryption.find(find_payload).to_list(),
|
||||
key=lambda x: x["_id"],
|
||||
)
|
||||
for elem, expected in zip(sorted_find, expected_elems):
|
||||
self.assertEqual(elem[f"encrypted{name}"], expected)
|
||||
@ -3073,13 +3086,13 @@ class TestNoSessionsSupport(EncryptionIntegrationTest):
|
||||
|
||||
@classmethod
|
||||
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
start_mongocryptd(cls.MONGOCRYPTD_PORT)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super().tearDownClass()
|
||||
def _tearDown_class(cls):
|
||||
super()._tearDown_class()
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.listener = OvertCommandListener()
|
||||
|
||||
@ -96,6 +96,7 @@ replacements = {
|
||||
"async-transactions-ref": "transactions-ref",
|
||||
"async-snapshot-reads-ref": "snapshot-reads-ref",
|
||||
"default_async": "default",
|
||||
"aclose": "close",
|
||||
"PyMongo|async": "PyMongo",
|
||||
}
|
||||
|
||||
@ -158,6 +159,7 @@ converted_tests = [
|
||||
"test_collection.py",
|
||||
"test_cursor.py",
|
||||
"test_database.py",
|
||||
"test_encryption.py",
|
||||
"test_logger.py",
|
||||
"test_session.py",
|
||||
"test_transactions.py",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user