PYTHON-1884 Implement auto encryption spec tests
Skip test for symbol type which pymongo converts to string.
Fix {} comparison with RawBSONDocument in command events.
Add support for $$type assertions.
Nicer message in check_events.
Support errorContains with empty string.
Move custom data files to custom/.
This commit is contained in:
parent
743042d843
commit
e6eecb06d1
1
.gitignore
vendored
1
.gitignore
vendored
@ -13,3 +13,4 @@ pymongo.egg-info/
|
||||
*.so
|
||||
*.egg
|
||||
.tox
|
||||
mongocryptd.pid
|
||||
|
||||
@ -30,7 +30,8 @@ from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS,
|
||||
_inflate_bson)
|
||||
from bson.son import SON
|
||||
|
||||
from pymongo.errors import ServerSelectionTimeoutError
|
||||
from pymongo.errors import (EncryptionError,
|
||||
ServerSelectionTimeoutError)
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.pool import _configured_socket, PoolOptions
|
||||
from pymongo.ssl_support import get_ssl_context
|
||||
@ -210,8 +211,11 @@ class _Encrypter(object):
|
||||
"""
|
||||
# Workaround for $clusterTime which is incompatible with check_keys.
|
||||
cluster_time = check_keys and cmd.pop('$clusterTime', None)
|
||||
encrypted_cmd = self._auto_encrypter.encrypt(
|
||||
database, _dict_to_bson(cmd, check_keys, codec_options))
|
||||
encoded_cmd = _dict_to_bson(cmd, check_keys, codec_options)
|
||||
try:
|
||||
encrypted_cmd = self._auto_encrypter.encrypt(database, encoded_cmd)
|
||||
except MongoCryptError as exc:
|
||||
raise EncryptionError(exc)
|
||||
# TODO: PYTHON-1922 avoid decoding the encrypted_cmd.
|
||||
encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
|
||||
if cluster_time:
|
||||
@ -227,7 +231,10 @@ class _Encrypter(object):
|
||||
:Returns:
|
||||
The decrypted command response.
|
||||
"""
|
||||
return self._auto_encrypter.decrypt(response)
|
||||
try:
|
||||
return self._auto_encrypter.decrypt(response)
|
||||
except MongoCryptError as exc:
|
||||
raise EncryptionError(exc)
|
||||
|
||||
def close(self):
|
||||
"""Cleanup resources."""
|
||||
|
||||
@ -247,3 +247,10 @@ class DocumentTooLarge(InvalidDocument):
|
||||
"""Raised when an encoded document is too large for the connected server.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class EncryptionError(OperationFailure):
|
||||
"""Raised when encryption or decryption fails.
|
||||
|
||||
.. versionadded:: 3.9
|
||||
"""
|
||||
|
||||
@ -33,7 +33,8 @@ from pymongo.encryption_options import AutoEncryptionOpts, _HAVE_PYMONGOCRYPT
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
from test import unittest, IntegrationTest, PyMongoTestCase, client_context
|
||||
from test.utils import wait_until
|
||||
from test.utils import TestCreator, camel_to_snake_args, wait_until
|
||||
from test.utils_spec_runner import SpecRunner
|
||||
|
||||
|
||||
if _HAVE_PYMONGOCRYPT:
|
||||
@ -129,8 +130,10 @@ class EncryptionIntegrationTest(IntegrationTest):
|
||||
|
||||
|
||||
# Location of JSON test files.
|
||||
TEST_PATH = os.path.join(
|
||||
BASE = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), 'client-side-encryption')
|
||||
CUSTOM_PATH = os.path.join(BASE, 'custom')
|
||||
SPEC_PATH = os.path.join(BASE, 'spec')
|
||||
|
||||
OPTS = CodecOptions(uuid_representation=STANDARD)
|
||||
|
||||
@ -139,7 +142,7 @@ JSON_OPTS = JSONOptions(document_class=SON, uuid_representation=STANDARD)
|
||||
|
||||
|
||||
def read(filename):
|
||||
with open(os.path.join(TEST_PATH, filename)) as fp:
|
||||
with open(os.path.join(CUSTOM_PATH, filename)) as fp:
|
||||
return fp.read()
|
||||
|
||||
|
||||
@ -231,5 +234,104 @@ class TestClientSimple(EncryptionIntegrationTest):
|
||||
self._test_auto_encrypt(opts)
|
||||
|
||||
|
||||
# Spec tests
|
||||
|
||||
AWS_CREDS = {
|
||||
'accessKeyId': os.environ.get('FLE_AWS_KEY', ''),
|
||||
'secretAccessKey': os.environ.get('FLE_AWS_SECRET', '')
|
||||
}
|
||||
|
||||
|
||||
class TestSpec(SpecRunner):
|
||||
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed')
|
||||
def setUpClass(cls):
|
||||
super(TestSpec, cls).setUpClass()
|
||||
|
||||
def parse_auto_encrypt_opts(self, opts):
|
||||
"""Parse clientOptions.autoEncryptOpts."""
|
||||
opts = camel_to_snake_args(opts)
|
||||
kms_providers = opts['kms_providers']
|
||||
if 'aws' in kms_providers:
|
||||
kms_providers['aws'] = AWS_CREDS
|
||||
if not any(AWS_CREDS.values()):
|
||||
self.skipTest('AWS environment credentials are not set')
|
||||
if 'key_vault_namespace' not in opts:
|
||||
opts['key_vault_namespace'] = 'admin.datakeys'
|
||||
opts = dict(opts)
|
||||
return AutoEncryptionOpts(**opts)
|
||||
|
||||
def parse_client_options(self, opts):
|
||||
"""Override clientOptions parsing to support autoEncryptOpts."""
|
||||
encrypt_opts = opts.pop('autoEncryptOpts')
|
||||
if encrypt_opts:
|
||||
opts['auto_encryption_opts'] = self.parse_auto_encrypt_opts(
|
||||
encrypt_opts)
|
||||
|
||||
return super(TestSpec, self).parse_client_options(opts)
|
||||
|
||||
def get_object_name(self, op):
|
||||
"""Default object is collection."""
|
||||
return op.get('object', 'collection')
|
||||
|
||||
def maybe_skip_scenario(self, test):
|
||||
super(TestSpec, self).maybe_skip_scenario(test)
|
||||
if 'type=symbol' in test['description'].lower():
|
||||
raise unittest.SkipTest(
|
||||
'PyMongo does not support the symbol type')
|
||||
|
||||
def setup_scenario(self, scenario_def):
|
||||
"""Override a test's setup."""
|
||||
key_vault_data = scenario_def['key_vault_data']
|
||||
if key_vault_data:
|
||||
coll = client_context.client.get_database(
|
||||
'admin',
|
||||
write_concern=WriteConcern(w='majority'),
|
||||
codec_options=OPTS)['datakeys']
|
||||
coll.drop()
|
||||
coll.insert_many(key_vault_data)
|
||||
|
||||
db_name = self.get_scenario_db_name(scenario_def)
|
||||
coll_name = self.get_scenario_coll_name(scenario_def)
|
||||
db = client_context.client.get_database(
|
||||
db_name, write_concern=WriteConcern(w='majority'),
|
||||
codec_options=OPTS)
|
||||
coll = db[coll_name]
|
||||
coll.drop()
|
||||
json_schema = scenario_def['json_schema']
|
||||
if json_schema:
|
||||
db.create_collection(
|
||||
coll_name,
|
||||
validator={'$jsonSchema': json_schema}, codec_options=OPTS)
|
||||
else:
|
||||
db.create_collection(coll_name)
|
||||
|
||||
if scenario_def['data']:
|
||||
# Load data.
|
||||
coll.insert_many(scenario_def['data'])
|
||||
|
||||
def allowable_errors(self, op):
|
||||
"""Override expected error classes."""
|
||||
errors = super(TestSpec, self).allowable_errors(op)
|
||||
# An updateOne test expects encryption to error when no $ operator
|
||||
# appears but pymongo raises a client side ValueError in this case.
|
||||
if op['name'] == 'updateOne':
|
||||
errors += (ValueError,)
|
||||
return errors
|
||||
|
||||
|
||||
def create_test(scenario_def, test, name):
|
||||
@client_context.require_test_commands
|
||||
def run_scenario(self):
|
||||
self.run_scenario(scenario_def, test)
|
||||
|
||||
return run_scenario
|
||||
|
||||
|
||||
test_creator = TestCreator(create_test, TestSpec, SPEC_PATH)
|
||||
test_creator.create_tests()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -20,6 +20,7 @@ import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
from test import unittest, client_context, PyMongoTestCase
|
||||
from test.utils import TestCreator
|
||||
@ -67,6 +68,28 @@ class TestSpec(SpecRunner):
|
||||
raise unittest.SkipTest(
|
||||
'PyMongo does not support %s' % (name,))
|
||||
|
||||
def get_scenario_coll_name(self, scenario_def):
|
||||
"""Override a test's collection name to support GridFS tests."""
|
||||
if 'bucket_name' in scenario_def:
|
||||
return scenario_def['bucket_name']
|
||||
return super(TestSpec, self).get_scenario_coll_name(scenario_def)
|
||||
|
||||
def setup_scenario(self, scenario_def):
|
||||
"""Override a test's setup to support GridFS tests."""
|
||||
if 'bucket_name' in scenario_def:
|
||||
db_name = self.get_scenario_db_name(scenario_def)
|
||||
db = client_context.client.get_database(
|
||||
db_name, write_concern=WriteConcern(w='majority'))
|
||||
# Create a bucket for the retryable reads GridFS tests.
|
||||
client_context.client.drop_database(db_name)
|
||||
if scenario_def['data']:
|
||||
data = scenario_def['data']
|
||||
# Load data.
|
||||
db['fs.chunks'].insert_many(data['fs.chunks'])
|
||||
db['fs.files'].insert_many(data['fs.files'])
|
||||
else:
|
||||
super(TestSpec, self).setup_scenario(scenario_def)
|
||||
|
||||
|
||||
def create_test(scenario_def, test, name):
|
||||
@client_context.require_test_commands
|
||||
|
||||
@ -321,8 +321,12 @@ class TestCreator(object):
|
||||
|
||||
for filename in filenames:
|
||||
with open(os.path.join(dirpath, filename)) as scenario_stream:
|
||||
# Use tz_aware=False to match how CodecOptions decodes
|
||||
# dates.
|
||||
opts = json_util.JSONOptions(tz_aware=False)
|
||||
scenario_def = ScenarioDict(
|
||||
json_util.loads(scenario_stream.read()))
|
||||
json_util.loads(scenario_stream.read(),
|
||||
json_options=opts))
|
||||
|
||||
test_type = os.path.splitext(filename)[0]
|
||||
|
||||
|
||||
@ -15,10 +15,14 @@
|
||||
"""Utilities for testing driver specs."""
|
||||
|
||||
import copy
|
||||
import sys
|
||||
|
||||
|
||||
from bson.binary import Binary
|
||||
from bson.py3compat import iteritems
|
||||
from bson import decode, encode
|
||||
from bson.binary import Binary, STANDARD
|
||||
from bson.codec_options import CodecOptions
|
||||
from bson.int64 import Int64
|
||||
from bson.py3compat import iteritems, abc, text_type
|
||||
from bson.son import SON
|
||||
|
||||
from gridfs import GridFSBucket
|
||||
@ -65,6 +69,7 @@ class SpecRunner(IntegrationTest):
|
||||
def setUp(self):
|
||||
super(SpecRunner, self).setUp()
|
||||
self.listener = None
|
||||
self.maxDiff = None
|
||||
|
||||
def _set_fail_point(self, client, command_args):
|
||||
cmd = SON([('configureFailPoint', 'failCommand')])
|
||||
@ -309,12 +314,16 @@ class SpecRunner(IntegrationTest):
|
||||
|
||||
return result
|
||||
|
||||
def allowable_errors(self, op):
|
||||
"""Allow encryption spec to override expected error classes."""
|
||||
return (PyMongoError,)
|
||||
|
||||
def run_operations(self, sessions, collection, ops,
|
||||
in_with_transaction=False):
|
||||
for op in ops:
|
||||
expected_result = op.get('result')
|
||||
if expect_error(op):
|
||||
with self.assertRaises(PyMongoError,
|
||||
with self.assertRaises(self.allowable_errors(op),
|
||||
msg=op['name']) as context:
|
||||
self.run_operation(sessions, collection, op.copy())
|
||||
|
||||
@ -351,9 +360,10 @@ class SpecRunner(IntegrationTest):
|
||||
if not len(test['expectations']):
|
||||
return
|
||||
|
||||
cmd_names = [event.command_name for event in res['started']]
|
||||
# Give a nicer message when there are missing or extra events
|
||||
cmds = decode_raw([event.command for event in res['started']])
|
||||
self.assertEqual(
|
||||
len(res['started']), len(test['expectations']), cmd_names)
|
||||
len(res['started']), len(test['expectations']), cmds)
|
||||
for i, expectation in enumerate(test['expectations']):
|
||||
event_type = next(iter(expectation))
|
||||
event = res['started'][i]
|
||||
@ -361,9 +371,9 @@ class SpecRunner(IntegrationTest):
|
||||
# The tests substitute 42 for any number other than 0.
|
||||
if (event.command_name == 'getMore'
|
||||
and event.command['getMore']):
|
||||
event.command['getMore'] = 42
|
||||
event.command['getMore'] = Int64(42)
|
||||
elif event.command_name == 'killCursors':
|
||||
event.command['cursors'] = [42]
|
||||
event.command['cursors'] = [Int64(42)]
|
||||
elif event.command_name == 'update':
|
||||
# TODO: remove this once PYTHON-1744 is done.
|
||||
# Add upsert and multi fields back into expectations.
|
||||
@ -396,6 +406,7 @@ class SpecRunner(IntegrationTest):
|
||||
|
||||
for attr, expected in expectation[event_type].items():
|
||||
actual = getattr(event, attr)
|
||||
expected = wrap_types(expected)
|
||||
if isinstance(expected, dict):
|
||||
for key, val in expected.items():
|
||||
if val is None:
|
||||
@ -406,7 +417,7 @@ class SpecRunner(IntegrationTest):
|
||||
self.fail("Expected key [%s] in %r" % (
|
||||
key, actual))
|
||||
else:
|
||||
self.assertEqual(val, actual[key],
|
||||
self.assertEqual(val, decode_raw(actual[key]),
|
||||
"Key [%s] in %s" % (key, actual))
|
||||
else:
|
||||
self.assertEqual(actual, expected)
|
||||
@ -432,13 +443,30 @@ class SpecRunner(IntegrationTest):
|
||||
operation."""
|
||||
self.run_operations(sessions, collection, test['operations'])
|
||||
|
||||
def parse_client_options(self, opts):
|
||||
"""Allow encryption spec to override a clientOptions parsing."""
|
||||
# Convert test['clientOptions'] to dict to avoid a Jython bug using
|
||||
# "**" with ScenarioDict.
|
||||
return dict(opts)
|
||||
|
||||
def setup_scenario(self, scenario_def):
|
||||
"""Allow specs to override a test's setup."""
|
||||
db_name = self.get_scenario_db_name(scenario_def)
|
||||
coll_name = self.get_scenario_coll_name(scenario_def)
|
||||
db = client_context.client.get_database(
|
||||
db_name, write_concern=WriteConcern(w='majority'))
|
||||
coll = db[coll_name]
|
||||
coll.drop()
|
||||
db.create_collection(coll_name)
|
||||
if scenario_def['data']:
|
||||
# Load data.
|
||||
coll.insert_many(scenario_def['data'])
|
||||
|
||||
def run_scenario(self, scenario_def, test):
|
||||
self.maybe_skip_scenario(test)
|
||||
listener = OvertCommandListener()
|
||||
# Create a new client, to avoid interference from pooled sessions.
|
||||
# Convert test['clientOptions'] to dict to avoid a Jython bug using
|
||||
# "**" with ScenarioDict.
|
||||
client_options = dict(test['clientOptions'])
|
||||
client_options = self.parse_client_options(test['clientOptions'])
|
||||
use_multi_mongos = test['useMultipleMongoses']
|
||||
if client_context.is_mongos and use_multi_mongos:
|
||||
client = rs_client(client_context.mongos_seeds(),
|
||||
@ -456,25 +484,8 @@ class SpecRunner(IntegrationTest):
|
||||
self.addCleanup(self.kill_all_sessions)
|
||||
|
||||
database_name = self.get_scenario_db_name(scenario_def)
|
||||
write_concern_db = client_context.client.get_database(
|
||||
database_name, write_concern=WriteConcern(w='majority'))
|
||||
if 'bucket_name' in scenario_def:
|
||||
# Create a bucket for the retryable reads GridFS tests.
|
||||
collection_name = scenario_def['bucket_name']
|
||||
client_context.client.drop_database(database_name)
|
||||
if scenario_def['data']:
|
||||
data = scenario_def['data']
|
||||
# Load data.
|
||||
write_concern_db['fs.chunks'].insert_many(data['fs.chunks'])
|
||||
write_concern_db['fs.files'].insert_many(data['fs.files'])
|
||||
else:
|
||||
collection_name = self.get_scenario_coll_name(scenario_def)
|
||||
write_concern_coll = write_concern_db[collection_name]
|
||||
write_concern_coll.drop()
|
||||
write_concern_db.create_collection(collection_name)
|
||||
if scenario_def['data']:
|
||||
# Load data.
|
||||
write_concern_coll.insert_many(scenario_def['data'])
|
||||
collection_name = self.get_scenario_coll_name(scenario_def)
|
||||
self.setup_scenario(scenario_def)
|
||||
|
||||
# SPEC-1245 workaround StaleDbVersion on distinct
|
||||
for c in self.mongos_clients:
|
||||
@ -530,11 +541,16 @@ class SpecRunner(IntegrationTest):
|
||||
|
||||
# Read from the primary with local read concern to ensure causal
|
||||
# consistency.
|
||||
outcome_coll = collection.database.get_collection(
|
||||
outcome_coll = client_context.client[
|
||||
collection.database.name].get_collection(
|
||||
outcome_coll_name,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
read_concern=ReadConcern('local'))
|
||||
self.assertEqual(list(outcome_coll.find()), expected_c['data'])
|
||||
|
||||
# The expected data needs to be the left hand side here otherwise
|
||||
# CompareType(Binary) doesn't work.
|
||||
self.assertEqual(
|
||||
wrap_types(expected_c['data']), list(outcome_coll.find()))
|
||||
|
||||
|
||||
def expect_any_error(op):
|
||||
@ -546,7 +562,7 @@ def expect_any_error(op):
|
||||
|
||||
def expect_error_message(expected_result):
|
||||
if isinstance(expected_result, dict):
|
||||
return expected_result['errorContains']
|
||||
return isinstance(expected_result['errorContains'], text_type)
|
||||
|
||||
return False
|
||||
|
||||
@ -585,3 +601,38 @@ def end_sessions(sessions):
|
||||
for s in sessions.values():
|
||||
# Aborts the transaction if it's open.
|
||||
s.end_session()
|
||||
|
||||
|
||||
if sys.version_info[:2] >= (3, 6):
|
||||
DOC_CLASS = dict
|
||||
else:
|
||||
DOC_CLASS = SON
|
||||
OPTS = CodecOptions(document_class=DOC_CLASS, uuid_representation=STANDARD)
|
||||
|
||||
|
||||
def decode_raw(val):
|
||||
"""Decode RawBSONDocuments in the given container."""
|
||||
if isinstance(val, (list, abc.Mapping)):
|
||||
return decode(encode({'v': val}, codec_options=OPTS), OPTS)['v']
|
||||
return val
|
||||
|
||||
|
||||
TYPES = {
|
||||
'binData': Binary,
|
||||
'long': Int64,
|
||||
}
|
||||
|
||||
|
||||
def wrap_types(val):
|
||||
"""Support $$type assertion in command results."""
|
||||
if isinstance(val, list):
|
||||
return [wrap_types(v) for v in val]
|
||||
if isinstance(val, abc.Mapping):
|
||||
typ = val.get('$$type')
|
||||
if typ:
|
||||
return CompareType(TYPES[typ])
|
||||
d = {}
|
||||
for key in val:
|
||||
d[key] = wrap_types(val[key])
|
||||
return d
|
||||
return val
|
||||
|
||||
Loading…
Reference in New Issue
Block a user