mongo-python-driver/test/test_gridfs_spec.py
2022-02-09 06:44:28 -06:00

233 lines
8.1 KiB
Python

# Copyright 2015 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test GridFSBucket class."""
import copy
import datetime
import os
import re
import sys
from json import loads
sys.path[0:0] = [""]
from test import IntegrationTest, unittest
import gridfs
from bson import Binary
from bson.int64 import Int64
from bson.json_util import object_hook
from gridfs.errors import CorruptGridFile, NoFile
# Commands.
_COMMANDS = {
"delete": lambda coll, doc: [coll.delete_many(d["q"]) for d in doc["deletes"]],
"insert": lambda coll, doc: coll.insert_many(doc["documents"]),
"update": lambda coll, doc: [coll.update_many(u["q"], u["u"]) for u in doc["updates"]],
}
# Location of JSON test specifications.
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "gridfs")
def camel_to_snake(camel):
# Regex to convert CamelCase to snake_case. Special case for _id.
if camel == "id":
return "file_id"
snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower()
class TestAllScenarios(IntegrationTest):
fs: gridfs.GridFSBucket
str_to_cmd: dict
@classmethod
def setUpClass(cls):
super(TestAllScenarios, cls).setUpClass()
cls.fs = gridfs.GridFSBucket(cls.db)
cls.str_to_cmd = {
"upload": cls.fs.upload_from_stream,
"download": cls.fs.open_download_stream,
"delete": cls.fs.delete,
"download_by_name": cls.fs.open_download_stream_by_name,
}
def init_db(self, data, test):
self.cleanup_colls(
self.db.fs.files, self.db.fs.chunks, self.db.expected.files, self.db.expected.chunks
)
# Read in data.
if data["files"]:
self.db.fs.files.insert_many(data["files"])
self.db.expected.files.insert_many(data["files"])
if data["chunks"]:
self.db.fs.chunks.insert_many(data["chunks"])
self.db.expected.chunks.insert_many(data["chunks"])
# Make initial modifications.
if "arrange" in test:
for cmd in test["arrange"].get("data", []):
for key in cmd.keys():
if key in _COMMANDS:
coll = self.db.get_collection(cmd[key])
_COMMANDS[key](coll, cmd)
def init_expected_db(self, test, result):
# Modify outcome DB.
for cmd in test["assert"].get("data", []):
for key in cmd.keys():
if key in _COMMANDS:
# Replace wildcards in inserts.
for doc in cmd.get("documents", []):
keylist = doc.keys()
for dockey in copy.deepcopy(list(keylist)):
if "result" in str(doc[dockey]):
doc[dockey] = result
if "actual" in str(doc[dockey]): # Avoid duplicate
doc.pop(dockey)
# Move contentType to metadata.
if dockey == "contentType":
doc["metadata"] = {dockey: doc.pop(dockey)}
coll = self.db.get_collection(cmd[key])
_COMMANDS[key](coll, cmd)
if test["assert"].get("result") == "&result":
test["assert"]["result"] = result
def sorted_list(self, coll, ignore_id):
to_sort = []
for doc in coll.find():
docstr = "{"
if ignore_id: # Cannot compare _id in chunks collection.
doc.pop("_id")
for k in sorted(doc.keys()):
if k == "uploadDate": # Can't compare datetime.
self.assertTrue(isinstance(doc[k], datetime.datetime))
else:
docstr += "%s:%s " % (k, repr(doc[k]))
to_sort.append(docstr + "}")
return to_sort
def create_test(scenario_def):
def run_scenario(self):
# Run tests.
self.assertTrue(scenario_def["tests"], "tests cannot be empty")
for test in scenario_def["tests"]:
self.init_db(scenario_def["data"], test)
# Run GridFs Operation.
operation = self.str_to_cmd[test["act"]["operation"]]
args = test["act"]["arguments"]
extra_opts = args.pop("options", {})
if "contentType" in extra_opts:
extra_opts["metadata"] = {"contentType": extra_opts.pop("contentType")}
args.update(extra_opts)
converted_args = dict((camel_to_snake(c), v) for c, v in args.items())
expect_error = test["assert"].get("error", False)
result = None
error = None
try:
result = operation(**converted_args)
if "download" in test["act"]["operation"]:
result = Binary(result.read())
except Exception as exc:
if not expect_error:
raise
error = exc
self.init_expected_db(test, result)
# Asserts.
errors = {
"FileNotFound": NoFile,
"ChunkIsMissing": CorruptGridFile,
"ExtraChunk": CorruptGridFile,
"ChunkIsWrongSize": CorruptGridFile,
"RevisionNotFound": NoFile,
}
if expect_error:
self.assertIsNotNone(error)
self.assertIsInstance(error, errors[test["assert"]["error"]], test["description"])
else:
self.assertIsNone(error)
if "result" in test["assert"]:
if test["assert"]["result"] == "void":
test["assert"]["result"] = None
self.assertEqual(result, test["assert"].get("result"))
if "data" in test["assert"]:
# Create alphabetized list
self.assertEqual(
set(self.sorted_list(self.db.fs.chunks, True)),
set(self.sorted_list(self.db.expected.chunks, True)),
)
self.assertEqual(
set(self.sorted_list(self.db.fs.files, False)),
set(self.sorted_list(self.db.expected.files, False)),
)
return run_scenario
def _object_hook(dct):
if "length" in dct:
dct["length"] = Int64(dct["length"])
return object_hook(dct)
def create_tests():
for dirpath, _, filenames in os.walk(_TEST_PATH):
for filename in filenames:
with open(os.path.join(dirpath, filename)) as scenario_stream:
scenario_def = loads(scenario_stream.read(), object_hook=_object_hook)
# Because object_hook is already defined by bson.json_util,
# and everything is named 'data'
def str2hex(jsn):
for key, val in jsn.items():
if key in ("data", "source", "result"):
if "$hex" in val:
jsn[key] = Binary(bytes.fromhex(val["$hex"]))
if isinstance(jsn[key], dict):
str2hex(jsn[key])
if isinstance(jsn[key], list):
for k in jsn[key]:
str2hex(k)
str2hex(scenario_def)
# Construct test from scenario.
new_test = create_test(scenario_def)
test_name = "test_%s" % (os.path.splitext(filename)[0])
new_test.__name__ = test_name
setattr(TestAllScenarios, new_test.__name__, new_test)
create_tests()
if __name__ == "__main__":
unittest.main()