diff --git a/bson/__init__.py b/bson/__init__.py index 66517fc07..a7c9ddc50 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -886,18 +886,11 @@ _ENCODERS = { _abc.Mapping: _encode_mapping, } - -_MARKERS = { - 5: _encode_binary, - 7: _encode_objectid, - 11: _encode_regex, - 13: _encode_code, - 17: _encode_timestamp, - 18: _encode_long, - 100: _encode_dbref, - 127: _encode_maxkey, - 255: _encode_minkey, -} +# Map each _type_marker to its encoder for faster lookup. +_MARKERS = {} +for _typ in _ENCODERS: + if hasattr(_typ, "_type_marker"): + _MARKERS[_typ._type_marker] = _ENCODERS[_typ] _BUILT_IN_TYPES = tuple(t for t in _ENCODERS) diff --git a/bson/json_util.py b/bson/json_util.py index c261703d8..699f5d8f4 100644 --- a/bson/json_util.py +++ b/bson/json_util.py @@ -110,6 +110,7 @@ import uuid from typing import ( TYPE_CHECKING, Any, + Callable, Mapping, MutableMapping, Optional, @@ -835,7 +836,7 @@ def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict: json_options.datetime_representation == DatetimeRepresentation.ISO8601 and 0 <= int(obj) <= _max_datetime_ms() ): - return default(obj.as_datetime(), json_options) + return _encode_datetime(obj.as_datetime(), json_options) elif json_options.datetime_representation == DatetimeRepresentation.LEGACY: return {"$date": str(int(obj))} return {"$date": {"$numberLong": str(int(obj))}} @@ -855,100 +856,180 @@ def _encode_int64(obj: Int64, json_options: JSONOptions) -> Any: return int(obj) +def _encode_noop(obj: Any, dummy0: Any) -> Any: + return obj + + +def _encode_regex(obj: Any, json_options: JSONOptions) -> dict: + flags = "" + if obj.flags & re.IGNORECASE: + flags += "i" + if obj.flags & re.LOCALE: + flags += "l" + if obj.flags & re.MULTILINE: + flags += "m" + if obj.flags & re.DOTALL: + flags += "s" + if obj.flags & re.UNICODE: + flags += "u" + if obj.flags & re.VERBOSE: + flags += "x" + if isinstance(obj.pattern, str): + pattern = obj.pattern + else: + pattern = obj.pattern.decode("utf-8") + if json_options.json_mode == JSONMode.LEGACY: + return {"$regex": pattern, "$options": flags} + return {"$regularExpression": {"pattern": pattern, "options": flags}} + + +def _encode_int(obj: int, json_options: JSONOptions) -> Any: + if json_options.json_mode == JSONMode.CANONICAL: + if -(2**31) <= obj < 2**31: + return {"$numberInt": str(obj)} + return {"$numberLong": str(obj)} + return obj + + +def _encode_float(obj: float, json_options: JSONOptions) -> Any: + if json_options.json_mode != JSONMode.LEGACY: + if math.isnan(obj): + return {"$numberDouble": "NaN"} + elif math.isinf(obj): + representation = "Infinity" if obj > 0 else "-Infinity" + return {"$numberDouble": representation} + elif json_options.json_mode == JSONMode.CANONICAL: + # repr() will return the shortest string guaranteed to produce the + # original value, when float() is called on it. + return {"$numberDouble": str(repr(obj))} + return obj + + +def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict: + if json_options.datetime_representation == DatetimeRepresentation.ISO8601: + if not obj.tzinfo: + obj = obj.replace(tzinfo=utc) + assert obj.tzinfo is not None + if obj >= EPOCH_AWARE: + off = obj.tzinfo.utcoffset(obj) + if (off.days, off.seconds, off.microseconds) == (0, 0, 0): # type: ignore + tz_string = "Z" + else: + tz_string = obj.strftime("%z") + millis = int(obj.microsecond / 1000) + fracsecs = ".%03d" % (millis,) if millis else "" + return { + "$date": "{}{}{}".format(obj.strftime("%Y-%m-%dT%H:%M:%S"), fracsecs, tz_string) + } + + millis = _datetime_to_millis(obj) + if json_options.datetime_representation == DatetimeRepresentation.LEGACY: + return {"$date": millis} + return {"$date": {"$numberLong": str(millis)}} + + +def _encode_bytes(obj: bytes, json_options: JSONOptions) -> dict: + return _encode_binary(obj, 0, json_options) + + +def _encode_binary_obj(obj: Binary, json_options: JSONOptions) -> dict: + return _encode_binary(obj, obj.subtype, json_options) + + +def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict: + if json_options.strict_uuid: + binval = Binary.from_uuid(obj, uuid_representation=json_options.uuid_representation) + return _encode_binary(binval, binval.subtype, json_options) + else: + return {"$uuid": obj.hex} + + +def _encode_objectid(obj: ObjectId, dummy0: Any) -> dict: + return {"$oid": str(obj)} + + +def _encode_timestamp(obj: Timestamp, dummy0: Any) -> dict: + return {"$timestamp": {"t": obj.time, "i": obj.inc}} + + +def _encode_decimal128(obj: Timestamp, dummy0: Any) -> dict: + return {"$numberDecimal": str(obj)} + + +def _encode_dbref(obj: DBRef, json_options: JSONOptions) -> dict: + return _json_convert(obj.as_doc(), json_options=json_options) + + +def _encode_minkey(dummy0: Any, dummy1: Any) -> dict: + return {"$minKey": 1} + + +def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict: + return {"$maxKey": 1} + + # Encoders for BSON types -_encoders = { - 5: lambda obj, json_options: _encode_binary(obj, obj.subtype, json_options), # Binary - 7: lambda obj, json_options: {"$oid": str(obj)}, # noqa: ARG005 ObjectId - 9: _encode_datetimems, # DatetimeMS - 13: _encode_code, # Code - 17: lambda obj, json_options: {"$timestamp": {"t": obj.time, "i": obj.inc}}, # noqa: ARG005 Timestamp - 18: _encode_int64, # Int64 - 19: lambda obj, json_options: {"$numberDecimal": str(obj)}, # noqa: ARG005 Decimal128 - 100: lambda obj, json_options: _json_convert(obj.as_doc(), json_options=json_options), # DBRef - 127: lambda obj, json_options: {"$maxKey": 1}, # noqa: ARG005 MaxKey - 255: lambda obj, json_options: {"$minKey": 1}, # noqa: ARG005 MinKey +# Each encoder function's signature is: +# - obj: a Python data type, e.g. a Python int for _encode_int +# - json_options: a JSONOptions +_ENCODERS: dict[Type, Callable[[Any, JSONOptions], Any]] = { + bool: _encode_noop, + bytes: _encode_bytes, + datetime.datetime: _encode_datetime, + DatetimeMS: _encode_datetimems, + float: _encode_float, + int: _encode_int, + str: _encode_noop, + type(None): _encode_noop, + uuid.UUID: _encode_uuid, + Binary: _encode_binary_obj, + Int64: _encode_int64, + Code: _encode_code, + DBRef: _encode_dbref, + MaxKey: _encode_maxkey, + MinKey: _encode_minkey, + ObjectId: _encode_objectid, + Regex: _encode_regex, + RE_TYPE: _encode_regex, + Timestamp: _encode_timestamp, + Decimal128: _encode_decimal128, } +# Map each _type_marker to its encoder for faster lookup. +_MARKERS: dict[int, Callable[[Any, JSONOptions], Any]] = {} +for _typ in _ENCODERS: + if hasattr(_typ, "_type_marker"): + _MARKERS[_typ._type_marker] = _ENCODERS[_typ] + +_BUILT_IN_TYPES = tuple(t for t in _ENCODERS) + def default(obj: Any, json_options: JSONOptions = DEFAULT_JSON_OPTIONS) -> Any: - # We preserve key order when rendering SON, DBRef, etc. as JSON by - # returning a SON for those types instead of a dict. - if isinstance(obj, bool): - return obj - elif isinstance(obj, (RE_TYPE, Regex)): - flags = "" - if obj.flags & re.IGNORECASE: - flags += "i" - if obj.flags & re.LOCALE: - flags += "l" - if obj.flags & re.MULTILINE: - flags += "m" - if obj.flags & re.DOTALL: - flags += "s" - if obj.flags & re.UNICODE: - flags += "u" - if obj.flags & re.VERBOSE: - flags += "x" - if isinstance(obj.pattern, str): - pattern = obj.pattern - else: - pattern = obj.pattern.decode("utf-8") - if json_options.json_mode == JSONMode.LEGACY: - return {"$regex": pattern, "$options": flags} - return {"$regularExpression": {"pattern": pattern, "options": flags}} - elif hasattr(obj, "_type_marker"): - type_marker = obj._type_marker - try: - return _encoders[type_marker](obj, json_options) # type: ignore[no-untyped-call] - except KeyError: - raise TypeError("%r is not JSON serializable" % obj) from None - elif isinstance(obj, int): - if json_options.json_mode == JSONMode.CANONICAL: - if -(2**31) <= obj < 2**31: - return {"$numberInt": str(obj)} - return {"$numberLong": str(obj)} - return obj - elif isinstance(obj, float): - if json_options.json_mode != JSONMode.LEGACY: - if math.isnan(obj): - return {"$numberDouble": "NaN"} - elif math.isinf(obj): - representation = "Infinity" if obj > 0 else "-Infinity" - return {"$numberDouble": representation} - elif json_options.json_mode == JSONMode.CANONICAL: - # repr() will return the shortest string guaranteed to produce the - # original value, when float() is called on it. - return {"$numberDouble": str(repr(obj))} - return obj - elif isinstance(obj, str): - return obj - elif isinstance(obj, datetime.datetime): - if json_options.datetime_representation == DatetimeRepresentation.ISO8601: - if not obj.tzinfo: - obj = obj.replace(tzinfo=utc) - assert obj.tzinfo is not None - if obj >= EPOCH_AWARE: - off = obj.tzinfo.utcoffset(obj) - if (off.days, off.seconds, off.microseconds) == (0, 0, 0): # type: ignore - tz_string = "Z" - else: - tz_string = obj.strftime("%z") - millis = int(obj.microsecond / 1000) - fracsecs = ".%03d" % (millis,) if millis else "" - return { - "$date": "{}{}{}".format(obj.strftime("%Y-%m-%dT%H:%M:%S"), fracsecs, tz_string) - } + # First see if the type is already cached. KeyError will only ever + # happen once per subtype. + try: + return _ENCODERS[type(obj)](obj, json_options) + except KeyError: + pass + + # Second, fall back to trying _type_marker. This has to be done + # before the loop below since users could subclass one of our + # custom types that subclasses a python built-in (e.g. Binary) + if hasattr(obj, "_type_marker"): + marker = obj._type_marker + if marker in _MARKERS: + func = _MARKERS[marker] + # Cache this type for faster subsequent lookup. + _ENCODERS[type(obj)] = func + return func(obj, json_options) + + # Third, test each base type. This will only happen once for + # a subtype of a supported base type. + for base in _BUILT_IN_TYPES: + if isinstance(obj, base): + func = _ENCODERS[base] + # Cache this type for faster subsequent lookup. + _ENCODERS[type(obj)] = func + return func(obj, json_options) - millis = _datetime_to_millis(obj) - if json_options.datetime_representation == DatetimeRepresentation.LEGACY: - return {"$date": millis} - return {"$date": {"$numberLong": str(millis)}} - elif isinstance(obj, bytes): - return _encode_binary(obj, 0, json_options) - elif isinstance(obj, uuid.UUID): - if json_options.strict_uuid: - binval = Binary.from_uuid(obj, uuid_representation=json_options.uuid_representation) - return _encode_binary(binval, binval.subtype, json_options) - else: - return {"$uuid": obj.hex} raise TypeError("%r is not JSON serializable" % obj) diff --git a/test/test_bson.py b/test/test_bson.py index 763885e5f..89c0983ca 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -48,7 +48,7 @@ from bson import ( encode, is_valid, ) -from bson.binary import Binary, UuidRepresentation +from bson.binary import USER_DEFINED_SUBTYPE, Binary, UuidRepresentation from bson.code import Code from bson.codec_options import CodecOptions, DatetimeConversion from bson.datetime_ms import _DATETIME_ERROR_SUGGESTION @@ -772,6 +772,21 @@ class TestBSON(unittest.TestCase): self.assertEqual(type(value), orig_type) self.assertEqual(value, orig_type(value)) + def test_encode_type_marker(self): + # Assert that a custom subclass can be BSON encoded based on the _type_marker attribute. + class MyMaxKey: + _type_marker = 127 + + expected_bson = encode({"a": MaxKey()}) + self.assertEqual(encode({"a": MyMaxKey()}), expected_bson) + + # Test a class that inherits from two built in types + class MyBinary(Binary): + pass + + expected_bson = encode({"a": Binary(b"bin", USER_DEFINED_SUBTYPE)}) + self.assertEqual(encode({"a": MyBinary(b"bin", USER_DEFINED_SUBTYPE)}), expected_bson) + def test_ordered_dict(self): d = OrderedDict([("one", 1), ("two", 2), ("three", 3), ("four", 4)]) self.assertEqual(d, decode(encode(d), CodecOptions(document_class=OrderedDict))) # type: ignore[type-var] diff --git a/test/test_json_util.py b/test/test_json_util.py index 43007808c..74cf12f33 100644 --- a/test/test_json_util.py +++ b/test/test_json_util.py @@ -20,7 +20,7 @@ import json import re import sys import uuid -from typing import Any, List, MutableMapping +from typing import Any, List, MutableMapping, Tuple, Type from bson.codec_options import CodecOptions, DatetimeConversion @@ -40,9 +40,12 @@ from bson.binary import ( from bson.code import Code from bson.datetime_ms import _max_datetime_ms from bson.dbref import DBRef +from bson.decimal128 import Decimal128 from bson.int64 import Int64 from bson.json_util import ( + CANONICAL_JSON_OPTIONS, LEGACY_JSON_OPTIONS, + RELAXED_JSON_OPTIONS, DatetimeRepresentation, JSONMode, JSONOptions, @@ -564,6 +567,56 @@ class TestJsonUtil(unittest.TestCase): json_util.loads('{"foo": "bar", "b": 1}', json_options=JSONOptions(document_class=SON)), ) + def test_encode_subclass(self): + cases: list[Tuple[Type, Any]] = [ + (int, (1,)), + (int, (2 << 60,)), + (float, (1.1,)), + (Int64, (64,)), + (Int64, (2 << 60,)), + (str, ("str",)), + (bytes, (b"bytes",)), + (datetime.datetime, (2024, 1, 16)), + (DatetimeMS, (1,)), + (uuid.UUID, ("f47ac10b-58cc-4372-a567-0e02b2c3d479",)), + (Binary, (b"1", USER_DEFINED_SUBTYPE)), + (Code, ("code",)), + (DBRef, ("coll", ObjectId())), + (ObjectId, ("65a6dab5f98bc03906ee3597",)), + (MaxKey, ()), + (MinKey, ()), + (Regex, ("pat",)), + (Timestamp, (1, 1)), + (Decimal128, ("0.5",)), + ] + allopts = [ + CANONICAL_JSON_OPTIONS.with_options(uuid_representation=STANDARD), + RELAXED_JSON_OPTIONS.with_options(uuid_representation=STANDARD), + LEGACY_JSON_OPTIONS.with_options(uuid_representation=STANDARD), + ] + for cls, args in cases: + basic_obj = cls(*args) + my_cls = type(f"My{cls.__name__}", (cls,), {}) + my_obj = my_cls(*args) + for opts in allopts: + expected_json = json_util.dumps(basic_obj, json_options=opts) + self.assertEqual(json_util.dumps(my_obj, json_options=opts), expected_json) + + def test_encode_type_marker(self): + # Assert that a custom subclass can be JSON encoded based on the _type_marker attribute. + class MyMaxKey: + _type_marker = 127 + + expected_json = json_util.dumps(MaxKey()) + self.assertEqual(json_util.dumps(MyMaxKey()), expected_json) + + # Test a class that inherits from two built in types + class MyBinary(Binary): + pass + + expected_json = json_util.dumps(Binary(b"bin", USER_DEFINED_SUBTYPE)) + self.assertEqual(json_util.dumps(MyBinary(b"bin", USER_DEFINED_SUBTYPE)), expected_json) + class TestJsonUtilRoundtrip(IntegrationTest): def test_cursor(self):