From 2e6e9a85070a766c2eb53351dabdc85956367f1f Mon Sep 17 00:00:00 2001 From: Julius Park Date: Tue, 7 Feb 2023 10:23:59 -0800 Subject: [PATCH] PYTHON-3592 createEncryptedCollection should raise a specialized exception to report the intermediate encryptedFields (#1148) --- pymongo/encryption.py | 17 +++++-------- pymongo/errors.py | 25 ++++++++++++++++++ test/test_encryption.py | 56 ++++++++++++++++++++++------------------- 3 files changed, 61 insertions(+), 37 deletions(-) diff --git a/pymongo/encryption.py b/pymongo/encryption.py index cf76cbe14..6a6150d0c 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -47,6 +47,7 @@ from pymongo.database import Database from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts from pymongo.errors import ( ConfigurationError, + EncryptedCollectionError, EncryptionError, InvalidOperation, ServerSelectionTimeoutError, @@ -614,6 +615,9 @@ class ClientEncryption(Generic[_DocumentType]): as keyword arguments to this method. See the documentation for :meth:`~pymongo.database.Database.create_collection` for all valid options. + :Raises: + - :class:`~pymongo.errors.EncryptedCollectionError`: When either data-key creation or creating the collection fails. + .. versionadded:: 4.4 .. _create collection command: @@ -629,12 +633,7 @@ class ClientEncryption(Generic[_DocumentType]): master_key=master_key, ) except EncryptionError as exc: - raise EncryptionError( - Exception( - "Error occurred while creating data key for field %s with encryptedFields=%s" - % (field["path"], encrypted_fields) - ) - ) from exc + raise EncryptedCollectionError(exc, encrypted_fields) from exc kwargs["encryptedFields"] = encrypted_fields kwargs["check_exists"] = False try: @@ -643,11 +642,7 @@ class ClientEncryption(Generic[_DocumentType]): encrypted_fields, ) except Exception as exc: - raise EncryptionError( - Exception( - f"Error: {str(exc)} occurred while creating collection with encryptedFields={str(encrypted_fields)}" - ) - ) from exc + raise EncryptedCollectionError(exc, encrypted_fields) from exc def create_data_key( self, diff --git a/pymongo/errors.py b/pymongo/errors.py index efc7e2eca..192eec99d 100644 --- a/pymongo/errors.py +++ b/pymongo/errors.py @@ -359,6 +359,31 @@ class EncryptionError(PyMongoError): return False +class EncryptedCollectionError(EncryptionError): + """Raised when creating a collection with encrypted_fields fails. + + .. note:: EncryptedCollectionError and `create_encrypted_collection` are both part of the + Queryable Encryption beta. Backwards-breaking changes may be made before the final release. + + .. versionadded:: 4.4 + """ + + def __init__(self, cause: Exception, encrypted_fields: Mapping[str, Any]) -> None: + super(EncryptedCollectionError, self).__init__(cause) + self.__encrypted_fields = encrypted_fields + + @property + def encrypted_fields(self) -> Mapping[str, Any]: + """The encrypted_fields document that allows inferring which data keys are *known* to be created. + + Note that the returned document is not guaranteed to contain information about *all* of the data keys that + were created, for example in the case of an indefinite error like a timeout. Use the `cause` property to + determine whether a definite or indefinite error caused this error, and only rely on the accuracy of the + encrypted_fields if the error is definite. + """ + return self.__encrypted_fields + + class _OperationCancelled(AutoReconnect): """Internal error raised when a socket operation is cancelled.""" diff --git a/test/test_encryption.py b/test/test_encryption.py index eb9bf8e98..dcfb63916 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -74,6 +74,7 @@ from pymongo.errors import ( BulkWriteError, ConfigurationError, DuplicateKeyError, + EncryptedCollectionError, EncryptionError, InvalidOperation, OperationFailure, @@ -2729,7 +2730,7 @@ class TestAutomaticDecryptionKeys(EncryptionIntegrationTest): def test_03_invalid_keyid(self): with self.assertRaisesRegex( - EncryptionError, + EncryptedCollectionError, "create.encryptedFields.fields.keyId' is the wrong type 'bool', expected type 'binData", ): self.client_encryption.create_encrypted_collection( @@ -2823,31 +2824,32 @@ class TestAutomaticDecryptionKeys(EncryptionIntegrationTest): def test_create_datakey_fails(self): key = self.client_encryption.create_data_key(kms_provider="local") - # Make sure the error message includes the previous keys in the error message even when generating keys fails. - with self.assertRaisesRegex( - EncryptionError, - f"data key for field dob with encryptedFields=.*{re.escape(repr(key))}.*keyId.*None", - ): + encrypted_fields = { + "fields": [ + {"path": "address", "bsonType": "string", "keyId": key}, + {"path": "dob", "bsonType": "string", "keyId": None}, + ] + } + # Make sure the exception's encrypted_fields object includes the previous keys in the error message even when + # generating keys fails. + with self.assertRaises( + EncryptedCollectionError, + ) as exc: self.client_encryption.create_encrypted_collection( database=self.db, name="testing1", - encrypted_fields={ - "fields": [ - {"path": "address", "bsonType": "string", "keyId": key}, - {"path": "dob", "bsonType": "string", "keyId": None}, - ] - }, + encrypted_fields=encrypted_fields, kms_provider="does not exist", ) + self.assertEqual(exc.exception.encrypted_fields, encrypted_fields) def test_create_failure(self): key = self.client_encryption.create_data_key(kms_provider="local") - # Make sure the error message includes the previous keys in the error message even when it is the creation - # of the collection that fails. - with self.assertRaisesRegex( - EncryptionError, - f"while creating collection with encryptedFields=.*{re.escape(repr(key))}.*keyId.*Binary", - ): + # Make sure the exception's encrypted_fields object includes the previous keys in the error message even when + # it is the creation of the collection that fails. + with self.assertRaises( + EncryptedCollectionError, + ) as exc: self.client_encryption.create_encrypted_collection( database=self.db, name=1, # type:ignore[arg-type] @@ -2859,6 +2861,8 @@ class TestAutomaticDecryptionKeys(EncryptionIntegrationTest): }, kms_provider="local", ) + for field in exc.exception.encrypted_fields["fields"]: + self.assertIsInstance(field["keyId"], Binary) def test_collection_name_collision(self): encrypted_fields = { @@ -2867,16 +2871,16 @@ class TestAutomaticDecryptionKeys(EncryptionIntegrationTest): ] } self.db.create_collection("testing1") - with self.assertRaisesRegex( - EncryptionError, - "while creating collection with encryptedFields=.*keyId.*Binary", - ): + with self.assertRaises( + EncryptedCollectionError, + ) as exc: self.client_encryption.create_encrypted_collection( database=self.db, name="testing1", encrypted_fields=encrypted_fields, kms_provider="local", ) + self.assertIsInstance(exc.exception.encrypted_fields["fields"][0]["keyId"], Binary) self.db.drop_collection("testing1", encrypted_fields=encrypted_fields) self.client_encryption.create_encrypted_collection( database=self.db, @@ -2884,16 +2888,16 @@ class TestAutomaticDecryptionKeys(EncryptionIntegrationTest): encrypted_fields=encrypted_fields, kms_provider="local", ) - with self.assertRaisesRegex( - EncryptionError, - "while creating collection with encryptedFields=.*keyId.*Binary", - ): + with self.assertRaises( + EncryptedCollectionError, + ) as exc: self.client_encryption.create_encrypted_collection( database=self.db, name="testing1", encrypted_fields=encrypted_fields, kms_provider="local", ) + self.assertIsInstance(exc.exception.encrypted_fields["fields"][0]["keyId"], Binary) if __name__ == "__main__":