diff --git a/pymongo/encryption.py b/pymongo/encryption.py index 23e94ca66..7552cf3f7 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -14,6 +14,7 @@ """Client side encryption.""" +import functools import subprocess import uuid import weakref @@ -30,8 +31,9 @@ except ImportError: MongoCryptCallback = object from bson import _bson_to_dict, _dict_to_bson, decode, encode -from bson.binary import STANDARD, Binary from bson.codec_options import CodecOptions +from bson.binary import STANDARD, Binary +from bson.errors import BSONError from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson) @@ -56,6 +58,22 @@ _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument, uuid_representation=STANDARD) +def _wrap_encryption_errors(encryption_func=None): + """Decorator to wrap encryption related errors with EncryptionError.""" + @functools.wraps(encryption_func) + def wrap_encryption_errors(*args, **kwargs): + try: + return encryption_func(*args, **kwargs) + except BSONError: + # BSON encoding/decoding errors are unrelated to encryption so + # we should propagate them unchanged. + raise + except Exception as exc: + raise EncryptionError(exc) + + return wrap_encryption_errors + + class _EncryptionIO(MongoCryptCallback): def __init__(self, client, key_vault_coll, mongocryptd_client, opts): """Internal class to perform I/O on behalf of pymongocrypt.""" @@ -85,14 +103,11 @@ class _EncryptionIO(MongoCryptCallback): opts = PoolOptions(connect_timeout=_KMS_CONNECT_TIMEOUT, socket_timeout=_KMS_CONNECT_TIMEOUT, ssl_context=ctx) - try: - with _configured_socket((endpoint, _HTTPS_PORT), opts) as conn: - conn.sendall(message) - while kms_context.bytes_needed > 0: - data = conn.recv(kms_context.bytes_needed) - kms_context.feed(data) - except Exception as exc: - raise MongoCryptError(str(exc)) + with _configured_socket((endpoint, _HTTPS_PORT), opts) as conn: + conn.sendall(message) + while kms_context.bytes_needed > 0: + data = conn.recv(kms_context.bytes_needed) + kms_context.feed(data) def collection_info(self, database, filter): """Get the collection info for a namespace. @@ -222,6 +237,7 @@ class _Encrypter(object): opts._kms_providers, schema_map)) self._bypass_auto_encryption = opts._bypass_auto_encryption + @_wrap_encryption_errors def encrypt(self, database, cmd, check_keys, codec_options): """Encrypt a MongoDB command. @@ -237,16 +253,14 @@ class _Encrypter(object): # Workaround for $clusterTime which is incompatible with check_keys. cluster_time = check_keys and cmd.pop('$clusterTime', None) encoded_cmd = _dict_to_bson(cmd, check_keys, codec_options) - try: - encrypted_cmd = self._auto_encrypter.encrypt(database, encoded_cmd) - except MongoCryptError as exc: - raise EncryptionError(exc) + encrypted_cmd = self._auto_encrypter.encrypt(database, encoded_cmd) # TODO: PYTHON-1922 avoid decoding the encrypted_cmd. encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS) if cluster_time: encrypt_cmd['$clusterTime'] = cluster_time return encrypt_cmd + @_wrap_encryption_errors def decrypt(self, response): """Decrypt a MongoDB command response. @@ -256,10 +270,7 @@ class _Encrypter(object): :Returns: The decrypted command response. """ - try: - return self._auto_encrypter.decrypt(response) - except MongoCryptError as exc: - raise EncryptionError(exc) + return self._auto_encrypter.decrypt(response) def close(self): """Cleanup resources.""" @@ -349,6 +360,7 @@ class ClientEncryption(object): self._encryption = ExplicitEncrypter( self._io_callbacks, MongoCryptOptions(kms_providers, None)) + @_wrap_encryption_errors def create_data_key(self, kms_provider, master_key=None, key_alt_names=None): """Create and insert a new data key into the key vault collection. @@ -383,6 +395,7 @@ class ClientEncryption(object): return self._encryption.create_data_key( kms_provider, master_key=master_key, key_alt_names=key_alt_names) + @_wrap_encryption_errors def encrypt(self, value, algorithm, key_id=None, key_alt_name=None): """Encrypt a BSON value with a given key and algorithm. @@ -410,6 +423,14 @@ class ClientEncryption(object): doc, algorithm, key_id=raw_key_id, key_alt_name=key_alt_name) return decode(encrypted_doc)['v'] + @_wrap_encryption_errors + def _decrypt(self, value): + """Internal decrypt helper.""" + doc = encode({'v': value}) + decrypted_doc = self._encryption.decrypt(doc) + # TODO: Add a required codec_options argument for decoding? + return decode(decrypted_doc, codec_options=_DATA_KEY_OPTS)['v'] + def decrypt(self, value): """Decrypt an encrypted value. @@ -423,10 +444,8 @@ class ClientEncryption(object): if not (isinstance(value, Binary) and value.subtype == 6): raise TypeError( 'value to decrypt must be a bson.binary.Binary with subtype 6') - doc = encode({'v': value}) - decrypted_doc = self._encryption.decrypt(doc) - # TODO: Add a required codec_options argument for decoding? - return decode(decrypted_doc, codec_options=_DATA_KEY_OPTS)['v'] + + return self._decrypt(value) def close(self): """Release resources.""" diff --git a/pymongo/errors.py b/pymongo/errors.py index f6e6a49c3..1b5bcbdb0 100644 --- a/pymongo/errors.py +++ b/pymongo/errors.py @@ -249,8 +249,20 @@ class DocumentTooLarge(InvalidDocument): pass -class EncryptionError(OperationFailure): +class EncryptionError(PyMongoError): """Raised when encryption or decryption fails. + This error always wraps another exception which can be retrieved via the + :attr:`cause` property. + .. versionadded:: 3.9 """ + + def __init__(self, cause): + super(EncryptionError, self).__init__(str(cause)) + self.__cause = cause + + @property + def cause(self): + """The exception that caused this encryption or decryption error.""" + return self.__cause diff --git a/test/client-side-encryption/external/external-key.json b/test/client-side-encryption/external/external-key.json new file mode 100644 index 000000000..b3fe0723b --- /dev/null +++ b/test/client-side-encryption/external/external-key.json @@ -0,0 +1,31 @@ +{ + "status": { + "$numberInt": "1" + }, + "_id": { + "$binary": { + "base64": "LOCALAAAAAAAAAAAAAAAAA==", + "subType": "04" + } + }, + "masterKey": { + "provider": "local" + }, + "updateDate": { + "$date": { + "$numberLong": "1557827033449" + } + }, + "keyMaterial": { + "$binary": { + "base64": "Ce9HSz/HKKGkIt4uyy+jDuKGA+rLC2cycykMo6vc8jXxqa1UVDYHWq1r+vZKbnnSRBfB981akzRKZCFpC05CTyFqDhXv6OnMjpG97OZEREGIsHEYiJkBW0jJJvfLLgeLsEpBzsro9FztGGXASxyxFRZFhXvHxyiLOKrdWfs7X1O/iK3pEoHMx6uSNSfUOgbebLfIqW7TO++iQS5g1xovXA==", + "subType": "00" + } + }, + "creationDate": { + "$date": { + "$numberLong": "1557827033449" + } + }, + "keyAltNames": [ "local" ] +} \ No newline at end of file diff --git a/test/client-side-encryption/external/external-schema.json b/test/client-side-encryption/external/external-schema.json new file mode 100644 index 000000000..7d8cad8c3 --- /dev/null +++ b/test/client-side-encryption/external/external-schema.json @@ -0,0 +1,19 @@ +{ + "properties": { + "encrypted": { + "encrypt": { + "keyId": [ + { + "$binary": { + "base64": "LOCALAAAAAAAAAAAAAAAAA==", + "subType": "04" + } + } + ], + "bsonType": "string", + "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic" + } + } + }, + "bsonType": "object" +} diff --git a/test/test_encryption.py b/test/test_encryption.py index 369c1453f..5271c28ab 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -27,11 +27,14 @@ sys.path[0:0] = [""] from bson import BSON, json_util from bson.binary import STANDARD, Binary, UUID_SUBTYPE from bson.codec_options import CodecOptions +from bson.errors import BSONError from bson.json_util import JSONOptions from bson.raw_bson import RawBSONDocument from bson.son import SON -from pymongo.errors import ConfigurationError +from pymongo.errors import (ConfigurationError, + EncryptionError, + OperationFailure) from pymongo.encryption_options import AutoEncryptionOpts, _HAVE_PYMONGOCRYPT from pymongo.mongo_client import MongoClient from pymongo.write_concern import WriteConcern @@ -231,6 +234,10 @@ class TestClientSimple(EncryptionIntegrationTest): self.assertIsInstance(encrypted_doc['_id'], int) self.assertEncrypted(encrypted_doc['ssn']) + # Attempt to encrypt an unencodable object. + with self.assertRaises(BSONError): + encrypted_coll.insert_one({'unencodeable': object()}) + def test_auto_encrypt(self): # Configure the encrypted field via jsonSchema. json_schema = json_data('custom', 'schema.json') @@ -298,6 +305,19 @@ class TestExplicitSimple(EncryptionIntegrationTest): with self.assertRaisesRegex(TypeError, msg): client_encryption.decrypt(Binary(b'123')) + def test_bson_errors(self): + client_encryption = ClientEncryption( + KMS_PROVIDERS, 'admin.datakeys', client_context.client) + self.addCleanup(client_encryption.close) + + # Attempt to encrypt an unencodable object. + unencodable_value = object() + with self.assertRaises(BSONError): + client_encryption.encrypt( + unencodable_value, Algorithm.Deterministic, + key_id=Binary(uuid.uuid4().bytes, UUID_SUBTYPE)) + + # Spec tests AWS_CREDS = { @@ -426,6 +446,69 @@ def create_key_vault(vault, *data_keys): return vault +class TestExternalKeyVault(EncryptionIntegrationTest): + + @staticmethod + def kms_providers(): + return {'local': {'key': LOCAL_MASTER_KEY}} + + def _test_external_key_vault(self, with_external_key_vault): + self.client.db.coll.drop() + vault = create_key_vault( + self.client.admin.datakeys, + json_data('corpus', 'corpus-key-local.json'), + json_data('corpus', 'corpus-key-aws.json')) + self.addCleanup(vault.drop) + + # Configure the encrypted field via the local schema_map option. + schemas = {'db.coll': json_data('external', 'external-schema.json')} + if with_external_key_vault: + key_vault_client = rs_or_single_client( + username='fake-user', password='fake-pwd') + self.addCleanup(key_vault_client.close) + else: + key_vault_client = client_context.client + opts = AutoEncryptionOpts( + self.kms_providers(), 'admin.datakeys', schema_map=schemas, + key_vault_client=key_vault_client) + + client_encrypted = rs_or_single_client( + auto_encryption_opts=opts, uuidRepresentation='standard') + self.addCleanup(client_encrypted.close) + + client_encryption = ClientEncryption( + self.kms_providers(), 'admin.datakeys', key_vault_client) + self.addCleanup(client_encryption.close) + + if with_external_key_vault: + # Authentication error. + with self.assertRaises(EncryptionError) as ctx: + client_encrypted.db.coll.insert_one({"encrypted": "test"}) + # AuthenticationFailed error. + self.assertIsInstance(ctx.exception.cause, OperationFailure) + self.assertEqual(ctx.exception.cause.code, 18) + else: + client_encrypted.db.coll.insert_one({"encrypted": "test"}) + + if with_external_key_vault: + # Authentication error. + with self.assertRaises(EncryptionError) as ctx: + client_encryption.encrypt( + "test", Algorithm.Deterministic, key_id=LOCAL_KEY_ID) + # AuthenticationFailed error. + self.assertIsInstance(ctx.exception.cause, OperationFailure) + self.assertEqual(ctx.exception.cause.code, 18) + else: + client_encryption.encrypt( + "test", Algorithm.Deterministic, key_id=LOCAL_KEY_ID) + + def test_external_key_vault_1(self): + self._test_external_key_vault(True) + + def test_external_key_vault_2(self): + self._test_external_key_vault(False) + + class TestCorpus(EncryptionIntegrationTest): @classmethod