From e6eecb06d15c1a24511f662428ad2a9d10edfd9e Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 1 Aug 2019 10:38:51 -0700 Subject: [PATCH] PYTHON-1884 Implement auto encryption spec tests Skip test for symbol type which pymongo converts to string. Fix {} comparison with RawBSONDocument in command events. Add support for $$type assertions. Nicer message in check_events. Support errorContains with empty string. Move custom data files to custom/. --- .gitignore | 1 + pymongo/encryption.py | 15 ++- pymongo/errors.py | 7 ++ .../{ => custom}/key-document-local.json | 0 .../{ => custom}/schema.json | 0 test/test_encryption.py | 108 +++++++++++++++- test/test_retryable_reads.py | 23 ++++ test/utils.py | 6 +- test/utils_spec_runner.py | 117 +++++++++++++----- 9 files changed, 236 insertions(+), 41 deletions(-) rename test/client-side-encryption/{ => custom}/key-document-local.json (100%) rename test/client-side-encryption/{ => custom}/schema.json (100%) diff --git a/.gitignore b/.gitignore index a3da2b8ad..385160b01 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ pymongo.egg-info/ *.so *.egg .tox +mongocryptd.pid diff --git a/pymongo/encryption.py b/pymongo/encryption.py index bf552b907..d69188081 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -30,7 +30,8 @@ from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS, _inflate_bson) from bson.son import SON -from pymongo.errors import ServerSelectionTimeoutError +from pymongo.errors import (EncryptionError, + ServerSelectionTimeoutError) from pymongo.mongo_client import MongoClient from pymongo.pool import _configured_socket, PoolOptions from pymongo.ssl_support import get_ssl_context @@ -210,8 +211,11 @@ class _Encrypter(object): """ # Workaround for $clusterTime which is incompatible with check_keys. cluster_time = check_keys and cmd.pop('$clusterTime', None) - encrypted_cmd = self._auto_encrypter.encrypt( - database, _dict_to_bson(cmd, check_keys, codec_options)) + encoded_cmd = _dict_to_bson(cmd, check_keys, codec_options) + try: + encrypted_cmd = self._auto_encrypter.encrypt(database, encoded_cmd) + except MongoCryptError as exc: + raise EncryptionError(exc) # TODO: PYTHON-1922 avoid decoding the encrypted_cmd. encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS) if cluster_time: @@ -227,7 +231,10 @@ class _Encrypter(object): :Returns: The decrypted command response. """ - return self._auto_encrypter.decrypt(response) + try: + return self._auto_encrypter.decrypt(response) + except MongoCryptError as exc: + raise EncryptionError(exc) def close(self): """Cleanup resources.""" diff --git a/pymongo/errors.py b/pymongo/errors.py index e5c2fc812..f6e6a49c3 100644 --- a/pymongo/errors.py +++ b/pymongo/errors.py @@ -247,3 +247,10 @@ class DocumentTooLarge(InvalidDocument): """Raised when an encoded document is too large for the connected server. """ pass + + +class EncryptionError(OperationFailure): + """Raised when encryption or decryption fails. + + .. versionadded:: 3.9 + """ diff --git a/test/client-side-encryption/key-document-local.json b/test/client-side-encryption/custom/key-document-local.json similarity index 100% rename from test/client-side-encryption/key-document-local.json rename to test/client-side-encryption/custom/key-document-local.json diff --git a/test/client-side-encryption/schema.json b/test/client-side-encryption/custom/schema.json similarity index 100% rename from test/client-side-encryption/schema.json rename to test/client-side-encryption/custom/schema.json diff --git a/test/test_encryption.py b/test/test_encryption.py index 6485f813c..1b2c680f9 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -33,7 +33,8 @@ from pymongo.encryption_options import AutoEncryptionOpts, _HAVE_PYMONGOCRYPT from pymongo.write_concern import WriteConcern from test import unittest, IntegrationTest, PyMongoTestCase, client_context -from test.utils import wait_until +from test.utils import TestCreator, camel_to_snake_args, wait_until +from test.utils_spec_runner import SpecRunner if _HAVE_PYMONGOCRYPT: @@ -129,8 +130,10 @@ class EncryptionIntegrationTest(IntegrationTest): # Location of JSON test files. -TEST_PATH = os.path.join( +BASE = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'client-side-encryption') +CUSTOM_PATH = os.path.join(BASE, 'custom') +SPEC_PATH = os.path.join(BASE, 'spec') OPTS = CodecOptions(uuid_representation=STANDARD) @@ -139,7 +142,7 @@ JSON_OPTS = JSONOptions(document_class=SON, uuid_representation=STANDARD) def read(filename): - with open(os.path.join(TEST_PATH, filename)) as fp: + with open(os.path.join(CUSTOM_PATH, filename)) as fp: return fp.read() @@ -231,5 +234,104 @@ class TestClientSimple(EncryptionIntegrationTest): self._test_auto_encrypt(opts) +# Spec tests + +AWS_CREDS = { + 'accessKeyId': os.environ.get('FLE_AWS_KEY', ''), + 'secretAccessKey': os.environ.get('FLE_AWS_SECRET', '') +} + + +class TestSpec(SpecRunner): + + @classmethod + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') + def setUpClass(cls): + super(TestSpec, cls).setUpClass() + + def parse_auto_encrypt_opts(self, opts): + """Parse clientOptions.autoEncryptOpts.""" + opts = camel_to_snake_args(opts) + kms_providers = opts['kms_providers'] + if 'aws' in kms_providers: + kms_providers['aws'] = AWS_CREDS + if not any(AWS_CREDS.values()): + self.skipTest('AWS environment credentials are not set') + if 'key_vault_namespace' not in opts: + opts['key_vault_namespace'] = 'admin.datakeys' + opts = dict(opts) + return AutoEncryptionOpts(**opts) + + def parse_client_options(self, opts): + """Override clientOptions parsing to support autoEncryptOpts.""" + encrypt_opts = opts.pop('autoEncryptOpts') + if encrypt_opts: + opts['auto_encryption_opts'] = self.parse_auto_encrypt_opts( + encrypt_opts) + + return super(TestSpec, self).parse_client_options(opts) + + def get_object_name(self, op): + """Default object is collection.""" + return op.get('object', 'collection') + + def maybe_skip_scenario(self, test): + super(TestSpec, self).maybe_skip_scenario(test) + if 'type=symbol' in test['description'].lower(): + raise unittest.SkipTest( + 'PyMongo does not support the symbol type') + + def setup_scenario(self, scenario_def): + """Override a test's setup.""" + key_vault_data = scenario_def['key_vault_data'] + if key_vault_data: + coll = client_context.client.get_database( + 'admin', + write_concern=WriteConcern(w='majority'), + codec_options=OPTS)['datakeys'] + coll.drop() + coll.insert_many(key_vault_data) + + db_name = self.get_scenario_db_name(scenario_def) + coll_name = self.get_scenario_coll_name(scenario_def) + db = client_context.client.get_database( + db_name, write_concern=WriteConcern(w='majority'), + codec_options=OPTS) + coll = db[coll_name] + coll.drop() + json_schema = scenario_def['json_schema'] + if json_schema: + db.create_collection( + coll_name, + validator={'$jsonSchema': json_schema}, codec_options=OPTS) + else: + db.create_collection(coll_name) + + if scenario_def['data']: + # Load data. + coll.insert_many(scenario_def['data']) + + def allowable_errors(self, op): + """Override expected error classes.""" + errors = super(TestSpec, self).allowable_errors(op) + # An updateOne test expects encryption to error when no $ operator + # appears but pymongo raises a client side ValueError in this case. + if op['name'] == 'updateOne': + errors += (ValueError,) + return errors + + +def create_test(scenario_def, test, name): + @client_context.require_test_commands + def run_scenario(self): + self.run_scenario(scenario_def, test) + + return run_scenario + + +test_creator = TestCreator(create_test, TestSpec, SPEC_PATH) +test_creator.create_tests() + + if __name__ == "__main__": unittest.main() diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index 90346f763..94088a706 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -20,6 +20,7 @@ import sys sys.path[0:0] = [""] from pymongo.mongo_client import MongoClient +from pymongo.write_concern import WriteConcern from test import unittest, client_context, PyMongoTestCase from test.utils import TestCreator @@ -67,6 +68,28 @@ class TestSpec(SpecRunner): raise unittest.SkipTest( 'PyMongo does not support %s' % (name,)) + def get_scenario_coll_name(self, scenario_def): + """Override a test's collection name to support GridFS tests.""" + if 'bucket_name' in scenario_def: + return scenario_def['bucket_name'] + return super(TestSpec, self).get_scenario_coll_name(scenario_def) + + def setup_scenario(self, scenario_def): + """Override a test's setup to support GridFS tests.""" + if 'bucket_name' in scenario_def: + db_name = self.get_scenario_db_name(scenario_def) + db = client_context.client.get_database( + db_name, write_concern=WriteConcern(w='majority')) + # Create a bucket for the retryable reads GridFS tests. + client_context.client.drop_database(db_name) + if scenario_def['data']: + data = scenario_def['data'] + # Load data. + db['fs.chunks'].insert_many(data['fs.chunks']) + db['fs.files'].insert_many(data['fs.files']) + else: + super(TestSpec, self).setup_scenario(scenario_def) + def create_test(scenario_def, test, name): @client_context.require_test_commands diff --git a/test/utils.py b/test/utils.py index 608aa7f1e..eb849ebb0 100644 --- a/test/utils.py +++ b/test/utils.py @@ -321,8 +321,12 @@ class TestCreator(object): for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: + # Use tz_aware=False to match how CodecOptions decodes + # dates. + opts = json_util.JSONOptions(tz_aware=False) scenario_def = ScenarioDict( - json_util.loads(scenario_stream.read())) + json_util.loads(scenario_stream.read(), + json_options=opts)) test_type = os.path.splitext(filename)[0] diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index ce5110f88..d258228ae 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -15,10 +15,14 @@ """Utilities for testing driver specs.""" import copy +import sys -from bson.binary import Binary -from bson.py3compat import iteritems +from bson import decode, encode +from bson.binary import Binary, STANDARD +from bson.codec_options import CodecOptions +from bson.int64 import Int64 +from bson.py3compat import iteritems, abc, text_type from bson.son import SON from gridfs import GridFSBucket @@ -65,6 +69,7 @@ class SpecRunner(IntegrationTest): def setUp(self): super(SpecRunner, self).setUp() self.listener = None + self.maxDiff = None def _set_fail_point(self, client, command_args): cmd = SON([('configureFailPoint', 'failCommand')]) @@ -309,12 +314,16 @@ class SpecRunner(IntegrationTest): return result + def allowable_errors(self, op): + """Allow encryption spec to override expected error classes.""" + return (PyMongoError,) + def run_operations(self, sessions, collection, ops, in_with_transaction=False): for op in ops: expected_result = op.get('result') if expect_error(op): - with self.assertRaises(PyMongoError, + with self.assertRaises(self.allowable_errors(op), msg=op['name']) as context: self.run_operation(sessions, collection, op.copy()) @@ -351,9 +360,10 @@ class SpecRunner(IntegrationTest): if not len(test['expectations']): return - cmd_names = [event.command_name for event in res['started']] + # Give a nicer message when there are missing or extra events + cmds = decode_raw([event.command for event in res['started']]) self.assertEqual( - len(res['started']), len(test['expectations']), cmd_names) + len(res['started']), len(test['expectations']), cmds) for i, expectation in enumerate(test['expectations']): event_type = next(iter(expectation)) event = res['started'][i] @@ -361,9 +371,9 @@ class SpecRunner(IntegrationTest): # The tests substitute 42 for any number other than 0. if (event.command_name == 'getMore' and event.command['getMore']): - event.command['getMore'] = 42 + event.command['getMore'] = Int64(42) elif event.command_name == 'killCursors': - event.command['cursors'] = [42] + event.command['cursors'] = [Int64(42)] elif event.command_name == 'update': # TODO: remove this once PYTHON-1744 is done. # Add upsert and multi fields back into expectations. @@ -396,6 +406,7 @@ class SpecRunner(IntegrationTest): for attr, expected in expectation[event_type].items(): actual = getattr(event, attr) + expected = wrap_types(expected) if isinstance(expected, dict): for key, val in expected.items(): if val is None: @@ -406,7 +417,7 @@ class SpecRunner(IntegrationTest): self.fail("Expected key [%s] in %r" % ( key, actual)) else: - self.assertEqual(val, actual[key], + self.assertEqual(val, decode_raw(actual[key]), "Key [%s] in %s" % (key, actual)) else: self.assertEqual(actual, expected) @@ -432,13 +443,30 @@ class SpecRunner(IntegrationTest): operation.""" self.run_operations(sessions, collection, test['operations']) + def parse_client_options(self, opts): + """Allow encryption spec to override a clientOptions parsing.""" + # Convert test['clientOptions'] to dict to avoid a Jython bug using + # "**" with ScenarioDict. + return dict(opts) + + def setup_scenario(self, scenario_def): + """Allow specs to override a test's setup.""" + db_name = self.get_scenario_db_name(scenario_def) + coll_name = self.get_scenario_coll_name(scenario_def) + db = client_context.client.get_database( + db_name, write_concern=WriteConcern(w='majority')) + coll = db[coll_name] + coll.drop() + db.create_collection(coll_name) + if scenario_def['data']: + # Load data. + coll.insert_many(scenario_def['data']) + def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) listener = OvertCommandListener() # Create a new client, to avoid interference from pooled sessions. - # Convert test['clientOptions'] to dict to avoid a Jython bug using - # "**" with ScenarioDict. - client_options = dict(test['clientOptions']) + client_options = self.parse_client_options(test['clientOptions']) use_multi_mongos = test['useMultipleMongoses'] if client_context.is_mongos and use_multi_mongos: client = rs_client(client_context.mongos_seeds(), @@ -456,25 +484,8 @@ class SpecRunner(IntegrationTest): self.addCleanup(self.kill_all_sessions) database_name = self.get_scenario_db_name(scenario_def) - write_concern_db = client_context.client.get_database( - database_name, write_concern=WriteConcern(w='majority')) - if 'bucket_name' in scenario_def: - # Create a bucket for the retryable reads GridFS tests. - collection_name = scenario_def['bucket_name'] - client_context.client.drop_database(database_name) - if scenario_def['data']: - data = scenario_def['data'] - # Load data. - write_concern_db['fs.chunks'].insert_many(data['fs.chunks']) - write_concern_db['fs.files'].insert_many(data['fs.files']) - else: - collection_name = self.get_scenario_coll_name(scenario_def) - write_concern_coll = write_concern_db[collection_name] - write_concern_coll.drop() - write_concern_db.create_collection(collection_name) - if scenario_def['data']: - # Load data. - write_concern_coll.insert_many(scenario_def['data']) + collection_name = self.get_scenario_coll_name(scenario_def) + self.setup_scenario(scenario_def) # SPEC-1245 workaround StaleDbVersion on distinct for c in self.mongos_clients: @@ -530,11 +541,16 @@ class SpecRunner(IntegrationTest): # Read from the primary with local read concern to ensure causal # consistency. - outcome_coll = collection.database.get_collection( + outcome_coll = client_context.client[ + collection.database.name].get_collection( outcome_coll_name, read_preference=ReadPreference.PRIMARY, read_concern=ReadConcern('local')) - self.assertEqual(list(outcome_coll.find()), expected_c['data']) + + # The expected data needs to be the left hand side here otherwise + # CompareType(Binary) doesn't work. + self.assertEqual( + wrap_types(expected_c['data']), list(outcome_coll.find())) def expect_any_error(op): @@ -546,7 +562,7 @@ def expect_any_error(op): def expect_error_message(expected_result): if isinstance(expected_result, dict): - return expected_result['errorContains'] + return isinstance(expected_result['errorContains'], text_type) return False @@ -585,3 +601,38 @@ def end_sessions(sessions): for s in sessions.values(): # Aborts the transaction if it's open. s.end_session() + + +if sys.version_info[:2] >= (3, 6): + DOC_CLASS = dict +else: + DOC_CLASS = SON +OPTS = CodecOptions(document_class=DOC_CLASS, uuid_representation=STANDARD) + + +def decode_raw(val): + """Decode RawBSONDocuments in the given container.""" + if isinstance(val, (list, abc.Mapping)): + return decode(encode({'v': val}, codec_options=OPTS), OPTS)['v'] + return val + + +TYPES = { + 'binData': Binary, + 'long': Int64, +} + + +def wrap_types(val): + """Support $$type assertion in command results.""" + if isinstance(val, list): + return [wrap_types(v) for v in val] + if isinstance(val, abc.Mapping): + typ = val.get('$$type') + if typ: + return CompareType(TYPES[typ]) + d = {} + for key in val: + d[key] = wrap_types(val[key]) + return d + return val