mongo-python-driver/test/test_bson_corpus.py

234 lines
9.2 KiB
Python

# Copyright 2016-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run the BSON corpus specification tests."""
from __future__ import annotations
import binascii
import functools
import glob
import json
import os
import sys
from decimal import DecimalException
sys.path[0:0] = [""]
from test import unittest
from bson import decode, encode, json_util
from bson.binary import STANDARD
from bson.codec_options import CodecOptions
from bson.dbref import DBRef
from bson.decimal128 import Decimal128
from bson.errors import InvalidBSON, InvalidDocument, InvalidId
from bson.json_util import JSONMode
from bson.son import SON
# Directory holding the corpus JSON fixtures; it sits next to this file.
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bson_corpus")

# Descriptions of "valid" cases that are skipped outright.
_TESTS_TO_SKIP = {
    # Dates after year 9999 cannot be decoded by Python.
    "Y10K",
}

# Descriptions of "parseErrors" cases that we deliberately parse without error.
_NON_PARSE_ERRORS = {
    # The legacy {"$date": <number>} format still has to be parseable.
    "Bad $date (number, not string or hash)",
    # An old version of mongoexport may have generated this $numberLong
    # variant.
    "Bad $numberLong (number, not string)",
    # Python's UUID constructor is very permissive.
    "$uuid invalid value--misplaced hyphens",
    # Regex flags with extra characters, including nulls, are parsed.
    "Null byte in $regularExpression options",
}

# Descriptions of "valid" cases treated as lossy even if the fixture does not
# mark them so.
_IMPLCIT_LOSSY_TESTS = {
    # JSON decodes top-level $ref+$id as a DBRef but BSON doesn't.
    "Document with key names similar to those of a DBRef",
}

# Python type a deprecated BSON type is expected to decode to, keyed by the
# corpus file's "bson_type" code.
_DEPRECATED_BSON_TYPES = {
    "0x0E": str,  # Symbol
    "0x06": type(None),  # Undefined
    "0x0C": DBRef,  # DBPointer
}
# "Strict" dates in extended JSON can only be used with tz_aware=True.
codec_options: CodecOptions = CodecOptions(tz_aware=True, document_class=SON)
codec_options_no_tzaware: CodecOptions = CodecOptions(document_class=SON)

# UUIDs are normally encoded as binary subtype 0x03, but one of the tests
# needs them encoded as subtype 0x04.
codec_options_uuid_04 = codec_options._replace(uuid_representation=STANDARD)
json_options_uuid_04 = json_util.JSONOptions(
    json_mode=JSONMode.CANONICAL,
    uuid_representation=STANDARD,
)
json_options_iso8601 = json_util.JSONOptions(
    datetime_representation=json_util.DatetimeRepresentation.ISO8601,
    json_mode=JSONMode.LEGACY,
)

# Extended JSON encoders: json_util.dumps bound to one set of JSON options each.
to_extjson = functools.partial(
    json_util.dumps, json_options=json_util.CANONICAL_JSON_OPTIONS
)
to_extjson_uuid_04 = functools.partial(
    json_util.dumps, json_options=json_options_uuid_04
)
to_extjson_iso8601 = functools.partial(
    json_util.dumps, json_options=json_options_iso8601
)
to_relaxed_extjson = functools.partial(
    json_util.dumps, json_options=json_util.RELAXED_JSON_OPTIONS
)

# BSON encoders (subtype 0x04 and default), then the BSON decoder, which uses
# the options without tz_aware.
to_bson_uuid_04 = functools.partial(encode, codec_options=codec_options_uuid_04)
to_bson = functools.partial(encode, codec_options=codec_options)
decode_bson = functools.partial(decode, codec_options=codec_options_no_tzaware)

# Extended JSON decoder producing SON documents in canonical mode.
decode_extjson = functools.partial(
    json_util.loads,
    json_options=json_util.JSONOptions(
        json_mode=JSONMode.CANONICAL,
        document_class=SON,
    ),
)

# Built-in json parser that builds SON objects; used to normalize JSON strings
# before they are compared.
loads = functools.partial(json.loads, object_pairs_hook=SON)
class TestBSONCorpus(unittest.TestCase):
    """Test case whose corpus test methods are attached by create_tests()."""

    def assertJsonEqual(self, first, second, msg=None):
        """Assert that two JSON strings describe equal data.

        Both strings are parsed with the built-in json library before they
        are compared, so discrepancies in spacing do not matter.
        """
        parsed_first = loads(first)
        parsed_second = loads(second)
        self.assertEqual(parsed_first, parsed_second, msg=msg)
def create_test(case_spec):
    """Build a test method that runs every case of one BSON corpus file.

    Args:
        case_spec: The parsed JSON of a corpus file. ``bson_type`` is
            required; ``test_key``, ``deprecated``, ``valid``,
            ``decodeErrors`` and ``parseErrors`` are read when present.

    Returns:
        A ``run_test(self)`` function meant to be attached to
        ``TestBSONCorpus`` as a test method.

    Raises:
        KeyError: If ``case_spec`` has no ``bson_type`` entry.
    """
    bson_type = case_spec["bson_type"]
    # Test key is absent when testing top-level documents.
    test_key = case_spec.get("test_key")
    deprecated = case_spec.get("deprecated")

    def run_test(self):
        """Run the valid, decodeErrors and parseErrors cases of the file."""
        for valid_case in case_spec.get("valid", []):
            description = valid_case["description"]
            if description in _TESTS_TO_SKIP:
                continue

            # Special case for testing encoding UUID as binary subtype 0x04.
            if description.startswith("subtype 0x04"):
                encode_extjson = to_extjson_uuid_04
                encode_bson = to_bson_uuid_04
            else:
                encode_extjson = to_extjson
                encode_bson = to_bson

            cB = binascii.unhexlify(valid_case["canonical_bson"].encode("utf8"))
            cEJ = valid_case["canonical_extjson"]
            rEJ = valid_case.get("relaxed_extjson")
            dEJ = valid_case.get("degenerate_extjson")
            # Some cases are lossy for us even though the fixture does not
            # flag them; an explicit "lossy" value in the fixture wins.
            if description in _IMPLCIT_LOSSY_TESTS:
                valid_case.setdefault("lossy", True)
            lossy = valid_case.get("lossy")

            # BSON double, use lowercase 'e+' to match Python's encoding
            if bson_type == "0x01":
                cEJ = cEJ.replace("E+", "e+")

            decoded_bson = decode_bson(cB)
            if not lossy:
                # Make sure we can parse the legacy (default) JSON format.
                legacy_json = json_util.dumps(
                    decoded_bson, json_options=json_util.LEGACY_JSON_OPTIONS
                )
                self.assertEqual(decode_extjson(legacy_json), decoded_bson, description)

            if deprecated:
                if "converted_bson" in valid_case:
                    converted_bson = binascii.unhexlify(
                        valid_case["converted_bson"].encode("utf8")
                    )
                    self.assertEqual(encode_bson(decoded_bson), converted_bson)
                    self.assertJsonEqual(
                        encode_extjson(decode_bson(converted_bson)),
                        valid_case["converted_extjson"],
                    )

                # Make sure we can decode the type.
                self.assertEqual(decoded_bson, decode_extjson(cEJ))
                if test_key is not None:
                    self.assertIsInstance(
                        decoded_bson[test_key], _DEPRECATED_BSON_TYPES[bson_type]
                    )
                # Deprecated types are not expected to round-trip, so the
                # remaining checks are skipped for them.
                continue

            # Test round-tripping canonical bson: re-encoding the decoded
            # document must reproduce the canonical bytes, and it must
            # serialize to the canonical extended json.
            self.assertEqual(encode_bson(decoded_bson), cB, description)
            self.assertJsonEqual(encode_extjson(decoded_bson), cEJ)

            # Test round-tripping canonical extended json.
            decoded_json = decode_extjson(cEJ)
            self.assertJsonEqual(encode_extjson(decoded_json), cEJ)
            if not lossy:
                self.assertEqual(encode_bson(decoded_json), cB)

            # Test round-tripping degenerate bson.
            if "degenerate_bson" in valid_case:
                dB = binascii.unhexlify(valid_case["degenerate_bson"].encode("utf8"))
                self.assertEqual(encode_bson(decode_bson(dB)), cB)

            # Test round-tripping degenerate extended json.
            if dEJ is not None:
                decoded_json = decode_extjson(dEJ)
                self.assertJsonEqual(encode_extjson(decoded_json), cEJ)
                if not lossy:
                    self.assertEqual(encode_bson(decoded_json), cB)

            # Test round-tripping relaxed extended json.
            if rEJ is not None:
                self.assertJsonEqual(to_relaxed_extjson(decoded_bson), rEJ)
                decoded_json = decode_extjson(rEJ)
                self.assertJsonEqual(to_relaxed_extjson(decoded_json), rEJ)

        for decode_error_case in case_spec.get("decodeErrors", []):
            with self.assertRaises(InvalidBSON):
                decode_bson(binascii.unhexlify(decode_error_case["bson"].encode("utf8")))

        for parse_error_case in case_spec.get("parseErrors", []):
            description = parse_error_case["description"]
            not_raised = "exception not raised for test case: " + description

            if description in _NON_PARSE_ERRORS:
                # Inputs we accept on purpose: decoding just has to succeed.
                decode_extjson(parse_error_case["string"])
                continue
            if bson_type == "0x13":
                self.assertRaises(DecimalException, Decimal128, parse_error_case["string"])
            elif bson_type == "0x00":
                with self.assertRaises(
                    (ValueError, KeyError, TypeError, InvalidId, InvalidDocument),
                    msg=not_raised,
                ):
                    doc = decode_extjson(parse_error_case["string"])
                    # Null bytes are validated when encoding to BSON.
                    if "Null" in description:
                        to_bson(doc)
            elif bson_type == "0x05":
                with self.assertRaises((TypeError, ValueError), msg=not_raised):
                    decode_extjson(parse_error_case["string"])
            else:
                raise AssertionError("cannot test parseErrors for type " + bson_type)

    return run_test
def create_tests():
    """Attach one generated test method to TestBSONCorpus per corpus file.

    Every ``*.json`` file found in ``_TEST_PATH`` is loaded and handed to
    create_test; the resulting method is stored under ``test_`` followed by
    the file's base name without its extension.
    """
    pattern = os.path.join(_TEST_PATH, "*.json")
    for path in glob.glob(pattern):
        stem = os.path.splitext(os.path.basename(path))[0]
        with open(path, encoding="utf-8") as corpus_file:
            case_spec = json.load(corpus_file)
        setattr(TestBSONCorpus, f"test_{stem}", create_test(case_spec))
# Populate TestBSONCorpus with its generated test methods as soon as the
# module is imported.
create_tests()

# Allow running this file directly as a script.
if __name__ == "__main__":
    unittest.main()