From e01efc707376efc43cb3711d370581957833afb1 Mon Sep 17 00:00:00 2001 From: Prashant Mital Date: Wed, 13 Mar 2019 10:33:18 -0700 Subject: [PATCH] PYTHON-1731 Implement callback for unencodable types --- bson/__init__.py | 14 +++- bson/_cbsonmodule.c | 36 +++++++- bson/_cbsonmodule.h | 2 + bson/codec_options.py | 41 ++++++--- doc/examples/custom_type.rst | 148 ++++++++++++++++++++++++++++++++- doc/faq.rst | 3 +- test/test_bson.py | 62 ++++++++++---- test/test_bson_custom_types.py | 47 ++++++++++- 8 files changed, 317 insertions(+), 36 deletions(-) diff --git a/bson/__init__.py b/bson/__init__.py index 29b1af17e..c5675710e 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -752,7 +752,8 @@ if not PY3: _ENCODERS[long] = _encode_long -def _name_value_to_bson(name, value, check_keys, opts): +def _name_value_to_bson(name, value, check_keys, opts, + in_fallback_call=False): """Encode a single name, value pair.""" # Custom encoder (if any) takes precedence over default encoders. # Using 'if' instead of 'try...except' for performance since this will @@ -789,8 +790,15 @@ def _name_value_to_bson(name, value, check_keys, opts): _ENCODERS[type(value)] = func return func(name, value, check_keys, opts) - raise InvalidDocument("cannot convert value of type %s to bson" % - type(value)) + # As a last resort, try using the fallback encoder, if the user has + # provided one. + fallback_encoder = opts.type_registry._fallback_encoder + if not in_fallback_call and fallback_encoder is not None: + return _name_value_to_bson( + name, fallback_encoder(value), check_keys, opts, True) + + raise InvalidDocument( + "cannot convert value of type %s to bson" % type(value)) def _element_to_bson(key, value, check_keys, opts): diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index 6d4f8241b..ca4b13108 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -119,7 +119,8 @@ static PyObject* elements_to_dict(PyObject* self, const char* string, static int _write_element_to_buffer(PyObject* self, buffer_t buffer, int type_byte, PyObject* value, unsigned char check_keys, - const codec_options_t* options); + const codec_options_t* options, + unsigned char in_fallback_call); /* Date stuff */ static PyObject* datetime_from_millis(long long millis) { @@ -455,6 +456,7 @@ static long _type_marker(PyObject* object) { int convert_type_registry(PyObject* registry_obj, type_registry_t* registry) { registry->encoder_map = NULL; registry->decoder_map = NULL; + registry->fallback_encoder = NULL; registry->registry_obj = NULL; registry->encoder_map = PyObject_GetAttrString(registry_obj, "_encoder_map"); @@ -469,6 +471,12 @@ int convert_type_registry(PyObject* registry_obj, type_registry_t* registry) { } registry->is_decoder_empty = (PyDict_Size(registry->decoder_map) == 0); + registry->fallback_encoder = PyObject_GetAttrString(registry_obj, "_fallback_encoder"); + if (registry->fallback_encoder == NULL) { + goto fail; + } + registry->has_fallback_encoder = (registry->fallback_encoder != Py_None); + registry->registry_obj = registry_obj; Py_INCREF(registry->registry_obj); return 1; @@ -476,6 +484,7 @@ int convert_type_registry(PyObject* registry_obj, type_registry_t* registry) { fail: Py_XDECREF(registry->encoder_map); Py_XDECREF(registry->decoder_map); + Py_XDECREF(registry->fallback_encoder); return 0; } @@ -548,6 +557,7 @@ void destroy_codec_options(codec_options_t* options) { Py_CLEAR(options->type_registry.registry_obj); Py_CLEAR(options->type_registry.encoder_map); Py_CLEAR(options->type_registry.decoder_map); + Py_CLEAR(options->type_registry.fallback_encoder); } static int write_element_to_buffer(PyObject* self, buffer_t buffer, @@ -580,7 +590,7 @@ static int write_element_to_buffer(PyObject* self, buffer_t buffer, } } result = _write_element_to_buffer(self, buffer, type_byte, - value, check_keys, options); + value, check_keys, options, 0); fail: Py_XDECREF(value_type); @@ -770,9 +780,12 @@ static int _write_regex_to_buffer( static int _write_element_to_buffer(PyObject* self, buffer_t buffer, int type_byte, PyObject* value, unsigned char check_keys, - const codec_options_t* options) { + const codec_options_t* options, + unsigned char in_fallback_call) { struct module_state *state = GETSTATE(self); PyObject* mapping_type; + PyObject* new_value = NULL; + int retval; PyObject* uuid_type; /* * Don't use PyObject_IsInstance for our custom types. It causes @@ -1363,6 +1376,23 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer, } Py_XDECREF(mapping_type); Py_XDECREF(uuid_type); + + /* Try the fallback encoder if one is provided and we have not already + * attempted to use the fallback encoder. */ + if (!in_fallback_call && options->type_registry.has_fallback_encoder) { + new_value = PyObject_CallFunctionObjArgs( + options->type_registry.fallback_encoder, value, NULL); + if (new_value == NULL) { + // propagate any exception raised by the callback + return 0; + } + retval = _write_element_to_buffer(self, buffer, type_byte, new_value, + check_keys, options, 1); + Py_XDECREF(new_value); + return retval; + } + Py_XDECREF(new_value); + /* We can't determine value's type. Fail. */ _set_cannot_encode(value); return 0; diff --git a/bson/_cbsonmodule.h b/bson/_cbsonmodule.h index 10bb15fcf..237eea037 100644 --- a/bson/_cbsonmodule.h +++ b/bson/_cbsonmodule.h @@ -54,9 +54,11 @@ typedef struct type_registry_t { PyObject* encoder_map; PyObject* decoder_map; + PyObject* fallback_encoder; PyObject* registry_obj; unsigned char is_encoder_empty; unsigned char is_decoder_empty; + unsigned char has_fallback_encoder; } type_registry_t; typedef struct codec_options_t { diff --git a/bson/codec_options.py b/bson/codec_options.py index 3194a0e7f..53331832c 100644 --- a/bson/codec_options.py +++ b/bson/codec_options.py @@ -63,22 +63,38 @@ class TypeCodecBase(object): class TypeRegistry(object): - """Encapsulates type codecs used in encoding and / or decoding BSON. + """Encapsulates type codecs used in encoding and / or decoding BSON, as + well as the fallback encoder. Type registries cannot be modified after + instantiation. - ``TypeRegistry`` can be initialized with an arbitrary number of type - codecs:: + ``TypeRegistry`` can be initialized with an iterable of type codecs, and + a callable for the fallback encoder:: >>> from bson.codec_options import TypeRegistry - >>> type_registry = TypeRegistry(Codec1, Codec2, Codec3, ...) + >>> type_registry = TypeRegistry([Codec1, Codec2, Codec3, ...], + ... fallback_encoder) - If multiple codecs try to transform a single python or BSON type, - the transformation described by the last type codec prevails. + :Parameters: + - `type_codecs` (optional): iterable of type codec instances. If + ``type_codecs`` contains multiple codecs that transform a single + python or BSON type, the transformation specified by the type codec + occurring last prevails. + - `fallback_encoder` (optional): callable that accepts a single, + unencodable python value and transforms it into a type that BSON can + encode. """ - def __init__(self, *type_codecs): - self.__args = type_codecs + def __init__(self, type_codecs=None, fallback_encoder=None): + self.__type_codecs = list(type_codecs or []) + self._fallback_encoder = fallback_encoder self._encoder_map = {} self._decoder_map = {} - for codec in type_codecs: + + if self._fallback_encoder is not None: + if not callable(fallback_encoder): + raise TypeError("fallback_encoder %r is not a callable" % ( + fallback_encoder)) + + for codec in self.__type_codecs: if not isinstance(codec, TypeCodecBase): raise TypeError( "Expected an instance of %s, got %r instead" % ( @@ -98,13 +114,16 @@ class TypeRegistry(object): self._decoder_map[bson_type] = codec.transform_bson def __repr__(self): - return '%s%r' % (self.__class__.__name__, self.__args) + return ('%s(type_codecs=%r, fallback_encoder=%r)' % ( + self.__class__.__name__, self.__type_codecs, + self._fallback_encoder)) def __eq__(self, other): if not isinstance(other, type(self)): return NotImplemented return ((self._decoder_map == other._decoder_map) and - (self._encoder_map == other._encoder_map)) + (self._encoder_map == other._encoder_map) and + (self._fallback_encoder == other._fallback_encoder)) _options_base = namedtuple( diff --git a/doc/examples/custom_type.rst b/doc/examples/custom_type.rst index cbdf35ee1..2311baca0 100644 --- a/doc/examples/custom_type.rst +++ b/doc/examples/custom_type.rst @@ -47,6 +47,8 @@ try to save an instance of ``Decimal`` with PyMongo, we get an The Type Codec -------------- +.. versionadded:: 3.8 + In order to encode custom types, we must first define a **type codec** for our type. A type codec describes how an instance of a custom type can be *transformed* to/from one of the types :mod:`~bson` already understands, and @@ -94,6 +96,8 @@ The type codec for our custom type simply needs to define how a The Type Registry ----------------- +.. versionadded:: 3.8 + Before we can begin encoding and decoding our custom type objects, we must first inform PyMongo about our type codec. This is done by creating a :class:`~bson.codec_options.TypeRegistry` instance: @@ -101,7 +105,7 @@ first inform PyMongo about our type codec. This is done by creating a .. doctest:: >>> from bson.codec_options import TypeRegistry - >>> type_registry = TypeRegistry(decimal_codec) + >>> type_registry = TypeRegistry([decimal_codec]) Note that type registries can be instantiated with any number of type codecs. @@ -146,3 +150,145 @@ MongoDB: >>> vanilla_collection = db.get_collection('test') >>> pprint.pprint(vanilla_collection.find_one()) {u'_id': ObjectId('...'), u'num': Decimal128('45.321')} + + +Encoding Subtypes +^^^^^^^^^^^^^^^^^ + +Consider the situation where, in addition to encoding +:py:class:`~decimal.Decimal`, we also need to encode a type that subclasses +``Decimal``. PyMongo does this automatically for types that inherit from +Python types that are BSON-encodable by default, but the type codec system +described above does not offer the same flexibility. + +Consider this subtype of ``Decimal`` that has a method to return its value as +an integer: + +.. doctest:: + + >>> class DecimalInt(Decimal): + ... def my_method(self): + ... """Method implementing some custom logic.""" + ... return int(self) + +If we try to save an instance of this type without first registering a type +codec for it, we get an error: + +.. doctest:: + + >>> collection.insert_one({'num': DecimalInt("45.321")}) + Traceback (most recent call last): + ... + bson.errors.InvalidDocument: Cannot encode object: Decimal('45.321') + +In order to proceed further, we must define a type codec for ``DecimalInt``. +This is trivial to do since the same transformation as the one used for +``Decimal`` is adequate for encoding ``DecimalInt`` as well: + +.. doctest:: + + >>> class DecimalIntCodec(DecimalCodec): + ... @property + ... def python_type(self): + ... """The Python type acted upon by this type codec.""" + ... return DecimalInt + >>> decimalint_codec = DecimalIntCodec() + + +.. note:: + + No attempt is made to modify decoding behavior because without additional + information, it is impossible to discern which incoming + :class:`~bson.decimal128.Decimal128` value needs to be decoded as ``Decimal`` + and which needs to be decoded as ``DecimalInt``. This example only considers + the situation where a user wants to *encode* documents containing one or both + of these types. + +Now, we can create a new codec options object and use it to get a collection +object: + +.. doctest:: + + >>> type_registry = TypeRegistry([decimal_codec, decimalint_codec]) + >>> codec_options = CodecOptions(type_registry=type_registry) + >>> collection = db.get_collection('test', codec_options=codec_options) + >>> collection.drop() + + +We can now seamlessly encode instances of ``DecimalInt``. Note that the +``transform_bson`` method of the base codec class results in these values +being decoded as ``Decimal`` (and not ``DecimalInt``): + +.. doctest:: + + >>> collection.insert_one({'num': DecimalInt("45.321")}) + + >>> mydoc = collection.find_one() + >>> pprint.pprint(mydoc) + {u'_id': ObjectId('...'), u'num': Decimal('45.321')} + + +The Fallback Encoder +-------------------- + +.. versionadded:: 3.8 + + +In addition to type codecs, users can also register a callable to encode types +that BSON doesn't recognize and for which no type codec has been registered. +This callable is the **fallback encoder** and like the ``transform_python`` +method, it accepts an unencodable value as a parameter and returns a +BSON-encodable value. The following fallback encoder encodes python's +:py:class:`~decimal.Decimal` type to a :class:`~bson.decimal128.Decimal128`: + +.. doctest:: + + >>> def fallback_encoder(value): + ... if isinstance(value, Decimal): + ... return Decimal128(value) + ... return value + +After declaring the callback, we must create a type registry and codec options +with this fallback encoder before it can be used for initializing a collection: + +.. doctest:: + + >>> type_registry = TypeRegistry(fallback_encoder=fallback_encoder) + >>> codec_options = CodecOptions(type_registry=type_registry) + >>> collection = db.get_collection('test', codec_options=codec_options) + >>> collection.drop() + +We can now seamlessly encode instances of :py:class:`~decimal.Decimal`: + +.. doctest:: + + >>> collection.insert_one({'num': Decimal("45.321")}) + + >>> mydoc = collection.find_one() + >>> pprint.pprint(mydoc) + {u'_id': ObjectId('...'), u'num': Decimal128('45.321')} + +As you can tell, fallback encoders are a compelling alternative to type codecs +when we only want to encode custom types due to their much simpler API. +Users should note however, that fallback encoders cannot be used to modify the +encoding of types that PyMongo already understands, as illustrated by the +following example: + + >>> def fallback_encoder(value): + ... """Encoder that converts floats to int.""" + ... if isinstance(value, float): + ... return int(value) + ... return value + >>> type_registry = TypeRegistry(fallback_encoder=fallback_encoder) + >>> codec_options = CodecOptions(type_registry=type_registry) + >>> collection = db.get_collection('test', codec_options=codec_options) + >>> collection.drop() + >>> collection.insert_one({'num': 45.321}) + + >>> mydoc = collection.find_one() + >>> pprint.pprint(mydoc) + {u'_id': ObjectId('...'), u'num': 45.321} + +This is due to the fact that fallback encoders are invoked only after +an attempt to encode the value with type codecs and standard BSON encoding +routines has been unsuccessful. \ No newline at end of file diff --git a/doc/faq.rst b/doc/faq.rst index b27cc1427..c425c438a 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -248,7 +248,8 @@ collection, configured to use :class:`~bson.son.SON` instead of dict: tz_aware=False, uuid_representation=PYTHON_LEGACY, unicode_decode_error_handler='strict', - tzinfo=None) + tzinfo=None, type_registry=TypeRegistry(type_codecs=[], + fallback_encoder=None)) >>> collection_son = collection.with_options(codec_options=opts) Now, documents and subdocuments in query results are represented with diff --git a/test/test_bson.py b/test/test_bson.py index 520dd7824..9684eec08 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -949,47 +949,76 @@ class TestTypeRegistry(unittest.TestCase): def transform_bson(self, value): return MyStrType(value) + def fallback_encoder(value): + return value + cls.types = (MyIntType, MyStrType) cls.codecs = (MyIntCodec, MyStrCodec) + cls.fallback_encoder = fallback_encoder def test_simple(self): codec_instances = [codec() for codec in self.codecs] - type_registry = TypeRegistry(*codec_instances) - self.assertEqual(type_registry._encoder_map, { - self.types[0]: codec_instances[0].transform_python, - self.types[1]: codec_instances[1].transform_python}) - self.assertEqual(type_registry._decoder_map, { - int: codec_instances[0].transform_bson, - str: codec_instances[1].transform_bson}) + def assert_proper_initialization(type_registry, codec_instances): + self.assertEqual(type_registry._encoder_map, { + self.types[0]: codec_instances[0].transform_python, + self.types[1]: codec_instances[1].transform_python}) + self.assertEqual(type_registry._decoder_map, { + int: codec_instances[0].transform_bson, + str: codec_instances[1].transform_bson}) + self.assertEqual( + type_registry._fallback_encoder, self.fallback_encoder) + + type_registry = TypeRegistry(codec_instances, self.fallback_encoder) + assert_proper_initialization(type_registry, codec_instances) + + type_registry = TypeRegistry( + fallback_encoder=self.fallback_encoder, type_codecs=codec_instances) + assert_proper_initialization(type_registry, codec_instances) + + # Ensure codec list held by the type registry doesn't change if we + # mutate the initial list. + codec_instances_copy = list(codec_instances) + codec_instances.pop(0) + self.assertListEqual( + type_registry._TypeRegistry__type_codecs, codec_instances_copy) def test_initialize_fail(self): err_msg = "Expected an instance of TypeCodecBase, got .* instead" with self.assertRaisesRegex(TypeError, err_msg): - TypeRegistry(*self.codecs) + TypeRegistry(self.codecs) with self.assertRaisesRegex(TypeError, err_msg): - TypeRegistry(type('AnyType', (object,), {})()) + TypeRegistry([type('AnyType', (object,), {})()]) + + err_msg = "fallback_encoder %r is not a callable" % (True,) + with self.assertRaisesRegex(TypeError, err_msg): + TypeRegistry([], True) + + err_msg = "fallback_encoder %r is not a callable" % ('hello',) + with self.assertRaisesRegex(TypeError, err_msg): + TypeRegistry(fallback_encoder='hello') def test_not_implemented(self): - type_registry = TypeRegistry(type("codec1", (TypeCodecBase, ), {})(), - type("codec2", (TypeCodecBase, ), {})()) + type_registry = TypeRegistry([type("codec1", (TypeCodecBase, ), {})(), + type("codec2", (TypeCodecBase, ), {})()]) self.assertEqual(type_registry._encoder_map, {}) self.assertEqual(type_registry._decoder_map, {}) def test_type_registry_repr(self): codec_instances = [codec() for codec in self.codecs] - type_registry = TypeRegistry(*codec_instances) - r = ("TypeRegistry(%s, %s)" % tuple(codec_instances)) + type_registry = TypeRegistry(codec_instances) + r = ("TypeRegistry(type_codecs=%r, fallback_encoder=%r)" % ( + codec_instances, None)) self.assertEqual(r, repr(type_registry)) def test_type_registry_eq(self): codec_instances = [codec() for codec in self.codecs] self.assertEqual( - TypeRegistry(*codec_instances), TypeRegistry(*codec_instances)) + TypeRegistry(codec_instances), TypeRegistry(codec_instances)) codec_instances_2 = [codec() for codec in self.codecs] self.assertNotEqual( - TypeRegistry(*codec_instances), TypeRegistry(*codec_instances_2)) + TypeRegistry(codec_instances), TypeRegistry(codec_instances_2)) class TestCodecOptions(unittest.TestCase): @@ -1017,7 +1046,8 @@ class TestCodecOptions(unittest.TestCase): r = ("CodecOptions(document_class=dict, tz_aware=False, " "uuid_representation=PYTHON_LEGACY, " "unicode_decode_error_handler='strict', " - "tzinfo=None, type_registry=TypeRegistry())") + "tzinfo=None, type_registry=TypeRegistry(type_codecs=[], " + "fallback_encoder=None))") self.assertEqual(r, repr(CodecOptions())) def test_decode_all_defaults(self): diff --git a/test/test_bson_custom_types.py b/test/test_bson_custom_types.py index c3a8f78a3..16d07031a 100644 --- a/test/test_bson_custom_types.py +++ b/test/test_bson_custom_types.py @@ -29,6 +29,7 @@ from bson import (BSON, _dict_to_bson, _bson_to_dict) from bson.codec_options import CodecOptions, TypeCodecBase, TypeRegistry +from bson.errors import InvalidDocument from test import unittest @@ -52,7 +53,7 @@ class DecimalCodec(TypeCodecBase): class TestCustomPythonTypeToBSON(unittest.TestCase): @classmethod def setUpClass(cls): - type_registry = TypeRegistry(DecimalCodec()) + type_registry = TypeRegistry((DecimalCodec(),)) codec_options = CodecOptions(type_registry=type_registry) cls.codecopts = codec_options @@ -114,6 +115,50 @@ class TestCustomPythonTypeToBSON(unittest.TestCase): fileobj.close() +class TestFallbackEncoder(unittest.TestCase): + def _get_codec_options(self, fallback_encoder): + type_registry = TypeRegistry(fallback_encoder=fallback_encoder) + return CodecOptions(type_registry=type_registry) + + def test_simple(self): + codecopts = self._get_codec_options(lambda x: Decimal128(x)) + document = {'average': Decimal('56.47')} + bsonbytes = BSON().encode(document, codec_options=codecopts) + + exp_document = {'average': Decimal128('56.47')} + exp_bsonbytes = BSON().encode(exp_document) + self.assertEqual(bsonbytes, exp_bsonbytes) + + def test_erroring_fallback_encoder(self): + codecopts = self._get_codec_options(lambda _: 1/0) + + # fallback converter should not be invoked when encoding known types. + BSON().encode( + {'a': 1, 'b': Decimal128('1.01'), 'c': {'arr': ['abc', 3.678]}}, + codec_options=codecopts) + + # expect an error when encoding a custom type. + document = {'average': Decimal('56.47')} + with self.assertRaises(ZeroDivisionError): + BSON().encode(document, codec_options=codecopts) + + def test_noop_fallback_encoder(self): + codecopts = self._get_codec_options(lambda x: x) + document = {'average': Decimal('56.47')} + with self.assertRaises(InvalidDocument): + BSON().encode(document, codec_options=codecopts) + + def test_type_unencodable_by_fallback_encoder(self): + def fallback_encoder(value): + try: + return Decimal128(value) + except: + raise TypeError("cannot encode type %s" % (type(value))) + codecopts = self._get_codec_options(fallback_encoder) + document = {'average': Decimal} + with self.assertRaises(TypeError): + BSON().encode(document, codec_options=codecopts) + if __name__ == "__main__": unittest.main()