PYTHON-1884 Implement auto encryption spec tests

Skip test for symbol type which pymongo converts to string.
Fix {} comparison with RawBSONDocument in command events.
Add support for $$type assertions.
Nicer message in check_events.
Support errorContains with empty string.
Move custom data files to custom/.
This commit is contained in:
Shane Harvey 2019-08-01 10:38:51 -07:00
parent 743042d843
commit e6eecb06d1
9 changed files with 236 additions and 41 deletions

1
.gitignore vendored
View File

@ -13,3 +13,4 @@ pymongo.egg-info/
*.so
*.egg
.tox
mongocryptd.pid

View File

@ -30,7 +30,8 @@ from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS,
_inflate_bson)
from bson.son import SON
from pymongo.errors import ServerSelectionTimeoutError
from pymongo.errors import (EncryptionError,
ServerSelectionTimeoutError)
from pymongo.mongo_client import MongoClient
from pymongo.pool import _configured_socket, PoolOptions
from pymongo.ssl_support import get_ssl_context
@ -210,8 +211,11 @@ class _Encrypter(object):
"""
# Workaround for $clusterTime which is incompatible with check_keys.
cluster_time = check_keys and cmd.pop('$clusterTime', None)
encrypted_cmd = self._auto_encrypter.encrypt(
database, _dict_to_bson(cmd, check_keys, codec_options))
encoded_cmd = _dict_to_bson(cmd, check_keys, codec_options)
try:
encrypted_cmd = self._auto_encrypter.encrypt(database, encoded_cmd)
except MongoCryptError as exc:
raise EncryptionError(exc)
# TODO: PYTHON-1922 avoid decoding the encrypted_cmd.
encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
if cluster_time:
@ -227,7 +231,10 @@ class _Encrypter(object):
:Returns:
The decrypted command response.
"""
return self._auto_encrypter.decrypt(response)
try:
return self._auto_encrypter.decrypt(response)
except MongoCryptError as exc:
raise EncryptionError(exc)
def close(self):
"""Cleanup resources."""

View File

@ -247,3 +247,10 @@ class DocumentTooLarge(InvalidDocument):
"""Raised when an encoded document is too large for the connected server.
"""
pass
class EncryptionError(OperationFailure):
"""Raised when encryption or decryption fails.
.. versionadded:: 3.9
"""

View File

@ -33,7 +33,8 @@ from pymongo.encryption_options import AutoEncryptionOpts, _HAVE_PYMONGOCRYPT
from pymongo.write_concern import WriteConcern
from test import unittest, IntegrationTest, PyMongoTestCase, client_context
from test.utils import wait_until
from test.utils import TestCreator, camel_to_snake_args, wait_until
from test.utils_spec_runner import SpecRunner
if _HAVE_PYMONGOCRYPT:
@ -129,8 +130,10 @@ class EncryptionIntegrationTest(IntegrationTest):
# Location of JSON test files.
TEST_PATH = os.path.join(
BASE = os.path.join(
os.path.dirname(os.path.realpath(__file__)), 'client-side-encryption')
CUSTOM_PATH = os.path.join(BASE, 'custom')
SPEC_PATH = os.path.join(BASE, 'spec')
OPTS = CodecOptions(uuid_representation=STANDARD)
@ -139,7 +142,7 @@ JSON_OPTS = JSONOptions(document_class=SON, uuid_representation=STANDARD)
def read(filename):
with open(os.path.join(TEST_PATH, filename)) as fp:
with open(os.path.join(CUSTOM_PATH, filename)) as fp:
return fp.read()
@ -231,5 +234,104 @@ class TestClientSimple(EncryptionIntegrationTest):
self._test_auto_encrypt(opts)
# Spec tests
AWS_CREDS = {
'accessKeyId': os.environ.get('FLE_AWS_KEY', ''),
'secretAccessKey': os.environ.get('FLE_AWS_SECRET', '')
}
class TestSpec(SpecRunner):
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed')
def setUpClass(cls):
super(TestSpec, cls).setUpClass()
def parse_auto_encrypt_opts(self, opts):
"""Parse clientOptions.autoEncryptOpts."""
opts = camel_to_snake_args(opts)
kms_providers = opts['kms_providers']
if 'aws' in kms_providers:
kms_providers['aws'] = AWS_CREDS
if not any(AWS_CREDS.values()):
self.skipTest('AWS environment credentials are not set')
if 'key_vault_namespace' not in opts:
opts['key_vault_namespace'] = 'admin.datakeys'
opts = dict(opts)
return AutoEncryptionOpts(**opts)
def parse_client_options(self, opts):
"""Override clientOptions parsing to support autoEncryptOpts."""
encrypt_opts = opts.pop('autoEncryptOpts')
if encrypt_opts:
opts['auto_encryption_opts'] = self.parse_auto_encrypt_opts(
encrypt_opts)
return super(TestSpec, self).parse_client_options(opts)
def get_object_name(self, op):
"""Default object is collection."""
return op.get('object', 'collection')
def maybe_skip_scenario(self, test):
super(TestSpec, self).maybe_skip_scenario(test)
if 'type=symbol' in test['description'].lower():
raise unittest.SkipTest(
'PyMongo does not support the symbol type')
def setup_scenario(self, scenario_def):
"""Override a test's setup."""
key_vault_data = scenario_def['key_vault_data']
if key_vault_data:
coll = client_context.client.get_database(
'admin',
write_concern=WriteConcern(w='majority'),
codec_options=OPTS)['datakeys']
coll.drop()
coll.insert_many(key_vault_data)
db_name = self.get_scenario_db_name(scenario_def)
coll_name = self.get_scenario_coll_name(scenario_def)
db = client_context.client.get_database(
db_name, write_concern=WriteConcern(w='majority'),
codec_options=OPTS)
coll = db[coll_name]
coll.drop()
json_schema = scenario_def['json_schema']
if json_schema:
db.create_collection(
coll_name,
validator={'$jsonSchema': json_schema}, codec_options=OPTS)
else:
db.create_collection(coll_name)
if scenario_def['data']:
# Load data.
coll.insert_many(scenario_def['data'])
def allowable_errors(self, op):
"""Override expected error classes."""
errors = super(TestSpec, self).allowable_errors(op)
# An updateOne test expects encryption to error when no $ operator
# appears but pymongo raises a client side ValueError in this case.
if op['name'] == 'updateOne':
errors += (ValueError,)
return errors
def create_test(scenario_def, test, name):
@client_context.require_test_commands
def run_scenario(self):
self.run_scenario(scenario_def, test)
return run_scenario
test_creator = TestCreator(create_test, TestSpec, SPEC_PATH)
test_creator.create_tests()
if __name__ == "__main__":
unittest.main()

View File

@ -20,6 +20,7 @@ import sys
sys.path[0:0] = [""]
from pymongo.mongo_client import MongoClient
from pymongo.write_concern import WriteConcern
from test import unittest, client_context, PyMongoTestCase
from test.utils import TestCreator
@ -67,6 +68,28 @@ class TestSpec(SpecRunner):
raise unittest.SkipTest(
'PyMongo does not support %s' % (name,))
def get_scenario_coll_name(self, scenario_def):
"""Override a test's collection name to support GridFS tests."""
if 'bucket_name' in scenario_def:
return scenario_def['bucket_name']
return super(TestSpec, self).get_scenario_coll_name(scenario_def)
def setup_scenario(self, scenario_def):
"""Override a test's setup to support GridFS tests."""
if 'bucket_name' in scenario_def:
db_name = self.get_scenario_db_name(scenario_def)
db = client_context.client.get_database(
db_name, write_concern=WriteConcern(w='majority'))
# Create a bucket for the retryable reads GridFS tests.
client_context.client.drop_database(db_name)
if scenario_def['data']:
data = scenario_def['data']
# Load data.
db['fs.chunks'].insert_many(data['fs.chunks'])
db['fs.files'].insert_many(data['fs.files'])
else:
super(TestSpec, self).setup_scenario(scenario_def)
def create_test(scenario_def, test, name):
@client_context.require_test_commands

View File

@ -321,8 +321,12 @@ class TestCreator(object):
for filename in filenames:
with open(os.path.join(dirpath, filename)) as scenario_stream:
# Use tz_aware=False to match how CodecOptions decodes
# dates.
opts = json_util.JSONOptions(tz_aware=False)
scenario_def = ScenarioDict(
json_util.loads(scenario_stream.read()))
json_util.loads(scenario_stream.read(),
json_options=opts))
test_type = os.path.splitext(filename)[0]

View File

@ -15,10 +15,14 @@
"""Utilities for testing driver specs."""
import copy
import sys
from bson.binary import Binary
from bson.py3compat import iteritems
from bson import decode, encode
from bson.binary import Binary, STANDARD
from bson.codec_options import CodecOptions
from bson.int64 import Int64
from bson.py3compat import iteritems, abc, text_type
from bson.son import SON
from gridfs import GridFSBucket
@ -65,6 +69,7 @@ class SpecRunner(IntegrationTest):
def setUp(self):
super(SpecRunner, self).setUp()
self.listener = None
self.maxDiff = None
def _set_fail_point(self, client, command_args):
cmd = SON([('configureFailPoint', 'failCommand')])
@ -309,12 +314,16 @@ class SpecRunner(IntegrationTest):
return result
def allowable_errors(self, op):
"""Allow encryption spec to override expected error classes."""
return (PyMongoError,)
def run_operations(self, sessions, collection, ops,
in_with_transaction=False):
for op in ops:
expected_result = op.get('result')
if expect_error(op):
with self.assertRaises(PyMongoError,
with self.assertRaises(self.allowable_errors(op),
msg=op['name']) as context:
self.run_operation(sessions, collection, op.copy())
@ -351,9 +360,10 @@ class SpecRunner(IntegrationTest):
if not len(test['expectations']):
return
cmd_names = [event.command_name for event in res['started']]
# Give a nicer message when there are missing or extra events
cmds = decode_raw([event.command for event in res['started']])
self.assertEqual(
len(res['started']), len(test['expectations']), cmd_names)
len(res['started']), len(test['expectations']), cmds)
for i, expectation in enumerate(test['expectations']):
event_type = next(iter(expectation))
event = res['started'][i]
@ -361,9 +371,9 @@ class SpecRunner(IntegrationTest):
# The tests substitute 42 for any number other than 0.
if (event.command_name == 'getMore'
and event.command['getMore']):
event.command['getMore'] = 42
event.command['getMore'] = Int64(42)
elif event.command_name == 'killCursors':
event.command['cursors'] = [42]
event.command['cursors'] = [Int64(42)]
elif event.command_name == 'update':
# TODO: remove this once PYTHON-1744 is done.
# Add upsert and multi fields back into expectations.
@ -396,6 +406,7 @@ class SpecRunner(IntegrationTest):
for attr, expected in expectation[event_type].items():
actual = getattr(event, attr)
expected = wrap_types(expected)
if isinstance(expected, dict):
for key, val in expected.items():
if val is None:
@ -406,7 +417,7 @@ class SpecRunner(IntegrationTest):
self.fail("Expected key [%s] in %r" % (
key, actual))
else:
self.assertEqual(val, actual[key],
self.assertEqual(val, decode_raw(actual[key]),
"Key [%s] in %s" % (key, actual))
else:
self.assertEqual(actual, expected)
@ -432,13 +443,30 @@ class SpecRunner(IntegrationTest):
operation."""
self.run_operations(sessions, collection, test['operations'])
def parse_client_options(self, opts):
"""Allow encryption spec to override a clientOptions parsing."""
# Convert test['clientOptions'] to dict to avoid a Jython bug using
# "**" with ScenarioDict.
return dict(opts)
def setup_scenario(self, scenario_def):
"""Allow specs to override a test's setup."""
db_name = self.get_scenario_db_name(scenario_def)
coll_name = self.get_scenario_coll_name(scenario_def)
db = client_context.client.get_database(
db_name, write_concern=WriteConcern(w='majority'))
coll = db[coll_name]
coll.drop()
db.create_collection(coll_name)
if scenario_def['data']:
# Load data.
coll.insert_many(scenario_def['data'])
def run_scenario(self, scenario_def, test):
self.maybe_skip_scenario(test)
listener = OvertCommandListener()
# Create a new client, to avoid interference from pooled sessions.
# Convert test['clientOptions'] to dict to avoid a Jython bug using
# "**" with ScenarioDict.
client_options = dict(test['clientOptions'])
client_options = self.parse_client_options(test['clientOptions'])
use_multi_mongos = test['useMultipleMongoses']
if client_context.is_mongos and use_multi_mongos:
client = rs_client(client_context.mongos_seeds(),
@ -456,25 +484,8 @@ class SpecRunner(IntegrationTest):
self.addCleanup(self.kill_all_sessions)
database_name = self.get_scenario_db_name(scenario_def)
write_concern_db = client_context.client.get_database(
database_name, write_concern=WriteConcern(w='majority'))
if 'bucket_name' in scenario_def:
# Create a bucket for the retryable reads GridFS tests.
collection_name = scenario_def['bucket_name']
client_context.client.drop_database(database_name)
if scenario_def['data']:
data = scenario_def['data']
# Load data.
write_concern_db['fs.chunks'].insert_many(data['fs.chunks'])
write_concern_db['fs.files'].insert_many(data['fs.files'])
else:
collection_name = self.get_scenario_coll_name(scenario_def)
write_concern_coll = write_concern_db[collection_name]
write_concern_coll.drop()
write_concern_db.create_collection(collection_name)
if scenario_def['data']:
# Load data.
write_concern_coll.insert_many(scenario_def['data'])
collection_name = self.get_scenario_coll_name(scenario_def)
self.setup_scenario(scenario_def)
# SPEC-1245 workaround StaleDbVersion on distinct
for c in self.mongos_clients:
@ -530,11 +541,16 @@ class SpecRunner(IntegrationTest):
# Read from the primary with local read concern to ensure causal
# consistency.
outcome_coll = collection.database.get_collection(
outcome_coll = client_context.client[
collection.database.name].get_collection(
outcome_coll_name,
read_preference=ReadPreference.PRIMARY,
read_concern=ReadConcern('local'))
self.assertEqual(list(outcome_coll.find()), expected_c['data'])
# The expected data needs to be the left hand side here otherwise
# CompareType(Binary) doesn't work.
self.assertEqual(
wrap_types(expected_c['data']), list(outcome_coll.find()))
def expect_any_error(op):
@ -546,7 +562,7 @@ def expect_any_error(op):
def expect_error_message(expected_result):
if isinstance(expected_result, dict):
return expected_result['errorContains']
return isinstance(expected_result['errorContains'], text_type)
return False
@ -585,3 +601,38 @@ def end_sessions(sessions):
for s in sessions.values():
# Aborts the transaction if it's open.
s.end_session()
if sys.version_info[:2] >= (3, 6):
DOC_CLASS = dict
else:
DOC_CLASS = SON
OPTS = CodecOptions(document_class=DOC_CLASS, uuid_representation=STANDARD)
def decode_raw(val):
"""Decode RawBSONDocuments in the given container."""
if isinstance(val, (list, abc.Mapping)):
return decode(encode({'v': val}, codec_options=OPTS), OPTS)['v']
return val
TYPES = {
'binData': Binary,
'long': Int64,
}
def wrap_types(val):
"""Support $$type assertion in command results."""
if isinstance(val, list):
return [wrap_types(v) for v in val]
if isinstance(val, abc.Mapping):
typ = val.get('$$type')
if typ:
return CompareType(TYPES[typ])
d = {}
for key in val:
d[key] = wrap_types(val[key])
return d
return val