diff --git a/test/test_encryption.py b/test/test_encryption.py index 209308aba..f2a02780b 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -61,6 +61,7 @@ from pymongo.cursor import CursorType from pymongo.encryption import Algorithm, ClientEncryption, QueryType from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts from pymongo.errors import ( + AutoReconnect, BulkWriteError, ConfigurationError, EncryptionError, @@ -1769,6 +1770,83 @@ class TestDeadlockProse(EncryptionIntegrationTest): self.assertEqual(len(self.topology_listener.results["opened"]), 1) +# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#14-decryption-events +class TestDecryptProse(EncryptionIntegrationTest): + def setUp(self): + self.client = client_context.client + self.client.db.drop_collection("decryption_events") + self.client.keyvault.drop_collection("datakeys") + self.client.keyvault.datakeys.create_index( + "keyAltNames", unique=True, partialFilterExpression={"keyAltNames": {"$exists": True}} + ) + kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} + + self.client_encryption = ClientEncryption( + kms_providers_map, "keyvault.datakeys", self.client, CodecOptions() + ) + keyID = self.client_encryption.create_data_key("local") + self.cipher_text = self.client_encryption.encrypt( + "hello", key_id=keyID, algorithm=Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic + ) + if self.cipher_text[-1] == 0: + self.malformed_cipher_text = self.cipher_text[:-1] + b"1" + else: + self.malformed_cipher_text = self.cipher_text[:-1] + b"0" + self.malformed_cipher_text = Binary(self.malformed_cipher_text, 6) + opts = AutoEncryptionOpts( + key_vault_namespace="keyvault.datakeys", kms_providers=kms_providers_map + ) + self.listener = AllowListEventListener("aggregate") + self.encrypted_client = MongoClient( + auto_encryption_opts=opts, retryReads=False, event_listeners=[self.listener] + ) + self.addCleanup(self.encrypted_client.close) + + def test_01_command_error(self): + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"errorCode": 123, "failCommands": ["aggregate"]}, + } + ): + with self.assertRaises(OperationFailure): + self.encrypted_client.db.decryption_events.aggregate([]) + self.assertEqual(len(self.listener.results["failed"]), 1) + for event in self.listener.results["failed"]: + self.assertEqual(event.failure["code"], 123) + + def test_02_network_error(self): + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"errorCode": 123, "closeConnection": True, "failCommands": ["aggregate"]}, + } + ): + with self.assertRaises(AutoReconnect): + self.encrypted_client.db.decryption_events.aggregate([]) + self.assertEqual(len(self.listener.results["failed"]), 1) + self.assertEqual(self.listener.results["failed"][0].command_name, "aggregate") + + def test_03_decrypt_error(self): + self.encrypted_client.db.decryption_events.insert_one( + {"encrypted": self.malformed_cipher_text} + ) + with self.assertRaises(EncryptionError): + next(self.encrypted_client.db.decryption_events.aggregate([])) + event = self.listener.results["succeeded"][0] + self.assertEqual(len(self.listener.results["failed"]), 0) + self.assertEqual( + event.reply["cursor"]["firstBatch"][0]["encrypted"], self.malformed_cipher_text + ) + + def test_04_decrypt_success(self): + self.encrypted_client.db.decryption_events.insert_one({"encrypted": self.cipher_text}) + next(self.encrypted_client.db.decryption_events.aggregate([])) + event = self.listener.results["succeeded"][0] + self.assertEqual(len(self.listener.results["failed"]), 0) + self.assertEqual(event.reply["cursor"]["firstBatch"][0]["encrypted"], self.cipher_text) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#bypass-spawning-mongocryptd class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest): @unittest.skipIf(