PYTHON-1731 Implement callback for unencodable types
(cherry picked from commit e01efc7073)
This commit is contained in:
parent
3ef4aa982c
commit
9093ddf365
@ -752,7 +752,8 @@ if not PY3:
|
||||
_ENCODERS[long] = _encode_long
|
||||
|
||||
|
||||
def _name_value_to_bson(name, value, check_keys, opts):
|
||||
def _name_value_to_bson(name, value, check_keys, opts,
|
||||
in_fallback_call=False):
|
||||
"""Encode a single name, value pair."""
|
||||
# Custom encoder (if any) takes precedence over default encoders.
|
||||
# Using 'if' instead of 'try...except' for performance since this will
|
||||
@ -789,8 +790,15 @@ def _name_value_to_bson(name, value, check_keys, opts):
|
||||
_ENCODERS[type(value)] = func
|
||||
return func(name, value, check_keys, opts)
|
||||
|
||||
raise InvalidDocument("cannot convert value of type %s to bson" %
|
||||
type(value))
|
||||
# As a last resort, try using the fallback encoder, if the user has
|
||||
# provided one.
|
||||
fallback_encoder = opts.type_registry._fallback_encoder
|
||||
if not in_fallback_call and fallback_encoder is not None:
|
||||
return _name_value_to_bson(
|
||||
name, fallback_encoder(value), check_keys, opts, True)
|
||||
|
||||
raise InvalidDocument(
|
||||
"cannot convert value of type %s to bson" % type(value))
|
||||
|
||||
|
||||
def _element_to_bson(key, value, check_keys, opts):
|
||||
|
||||
@ -119,7 +119,8 @@ static PyObject* elements_to_dict(PyObject* self, const char* string,
|
||||
static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
|
||||
int type_byte, PyObject* value,
|
||||
unsigned char check_keys,
|
||||
const codec_options_t* options);
|
||||
const codec_options_t* options,
|
||||
unsigned char in_fallback_call);
|
||||
|
||||
/* Date stuff */
|
||||
static PyObject* datetime_from_millis(long long millis) {
|
||||
@ -455,6 +456,7 @@ static long _type_marker(PyObject* object) {
|
||||
int convert_type_registry(PyObject* registry_obj, type_registry_t* registry) {
|
||||
registry->encoder_map = NULL;
|
||||
registry->decoder_map = NULL;
|
||||
registry->fallback_encoder = NULL;
|
||||
registry->registry_obj = NULL;
|
||||
|
||||
registry->encoder_map = PyObject_GetAttrString(registry_obj, "_encoder_map");
|
||||
@ -469,6 +471,12 @@ int convert_type_registry(PyObject* registry_obj, type_registry_t* registry) {
|
||||
}
|
||||
registry->is_decoder_empty = (PyDict_Size(registry->decoder_map) == 0);
|
||||
|
||||
registry->fallback_encoder = PyObject_GetAttrString(registry_obj, "_fallback_encoder");
|
||||
if (registry->fallback_encoder == NULL) {
|
||||
goto fail;
|
||||
}
|
||||
registry->has_fallback_encoder = (registry->fallback_encoder != Py_None);
|
||||
|
||||
registry->registry_obj = registry_obj;
|
||||
Py_INCREF(registry->registry_obj);
|
||||
return 1;
|
||||
@ -476,6 +484,7 @@ int convert_type_registry(PyObject* registry_obj, type_registry_t* registry) {
|
||||
fail:
|
||||
Py_XDECREF(registry->encoder_map);
|
||||
Py_XDECREF(registry->decoder_map);
|
||||
Py_XDECREF(registry->fallback_encoder);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -548,6 +557,7 @@ void destroy_codec_options(codec_options_t* options) {
|
||||
Py_CLEAR(options->type_registry.registry_obj);
|
||||
Py_CLEAR(options->type_registry.encoder_map);
|
||||
Py_CLEAR(options->type_registry.decoder_map);
|
||||
Py_CLEAR(options->type_registry.fallback_encoder);
|
||||
}
|
||||
|
||||
static int write_element_to_buffer(PyObject* self, buffer_t buffer,
|
||||
@ -580,7 +590,7 @@ static int write_element_to_buffer(PyObject* self, buffer_t buffer,
|
||||
}
|
||||
}
|
||||
result = _write_element_to_buffer(self, buffer, type_byte,
|
||||
value, check_keys, options);
|
||||
value, check_keys, options, 0);
|
||||
|
||||
fail:
|
||||
Py_XDECREF(value_type);
|
||||
@ -770,9 +780,12 @@ static int _write_regex_to_buffer(
|
||||
static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
|
||||
int type_byte, PyObject* value,
|
||||
unsigned char check_keys,
|
||||
const codec_options_t* options) {
|
||||
const codec_options_t* options,
|
||||
unsigned char in_fallback_call) {
|
||||
struct module_state *state = GETSTATE(self);
|
||||
PyObject* mapping_type;
|
||||
PyObject* new_value = NULL;
|
||||
int retval;
|
||||
PyObject* uuid_type;
|
||||
/*
|
||||
* Don't use PyObject_IsInstance for our custom types. It causes
|
||||
@ -1363,6 +1376,23 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
|
||||
}
|
||||
Py_XDECREF(mapping_type);
|
||||
Py_XDECREF(uuid_type);
|
||||
|
||||
/* Try the fallback encoder if one is provided and we have not already
|
||||
* attempted to use the fallback encoder. */
|
||||
if (!in_fallback_call && options->type_registry.has_fallback_encoder) {
|
||||
new_value = PyObject_CallFunctionObjArgs(
|
||||
options->type_registry.fallback_encoder, value, NULL);
|
||||
if (new_value == NULL) {
|
||||
// propagate any exception raised by the callback
|
||||
return 0;
|
||||
}
|
||||
retval = _write_element_to_buffer(self, buffer, type_byte, new_value,
|
||||
check_keys, options, 1);
|
||||
Py_XDECREF(new_value);
|
||||
return retval;
|
||||
}
|
||||
Py_XDECREF(new_value);
|
||||
|
||||
/* We can't determine value's type. Fail. */
|
||||
_set_cannot_encode(value);
|
||||
return 0;
|
||||
|
||||
@ -54,9 +54,11 @@
|
||||
typedef struct type_registry_t {
|
||||
PyObject* encoder_map;
|
||||
PyObject* decoder_map;
|
||||
PyObject* fallback_encoder;
|
||||
PyObject* registry_obj;
|
||||
unsigned char is_encoder_empty;
|
||||
unsigned char is_decoder_empty;
|
||||
unsigned char has_fallback_encoder;
|
||||
} type_registry_t;
|
||||
|
||||
typedef struct codec_options_t {
|
||||
|
||||
@ -63,22 +63,38 @@ class TypeCodecBase(object):
|
||||
|
||||
|
||||
class TypeRegistry(object):
|
||||
"""Encapsulates type codecs used in encoding and / or decoding BSON.
|
||||
"""Encapsulates type codecs used in encoding and / or decoding BSON, as
|
||||
well as the fallback encoder. Type registries cannot be modified after
|
||||
instantiation.
|
||||
|
||||
``TypeRegistry`` can be initialized with an arbitrary number of type
|
||||
codecs::
|
||||
``TypeRegistry`` can be initialized with an iterable of type codecs, and
|
||||
a callable for the fallback encoder::
|
||||
|
||||
>>> from bson.codec_options import TypeRegistry
|
||||
>>> type_registry = TypeRegistry(Codec1, Codec2, Codec3, ...)
|
||||
>>> type_registry = TypeRegistry([Codec1, Codec2, Codec3, ...],
|
||||
... fallback_encoder)
|
||||
|
||||
If multiple codecs try to transform a single python or BSON type,
|
||||
the transformation described by the last type codec prevails.
|
||||
:Parameters:
|
||||
- `type_codecs` (optional): iterable of type codec instances. If
|
||||
``type_codecs`` contains multiple codecs that transform a single
|
||||
python or BSON type, the transformation specified by the type codec
|
||||
occurring last prevails.
|
||||
- `fallback_encoder` (optional): callable that accepts a single,
|
||||
unencodable python value and transforms it into a type that BSON can
|
||||
encode.
|
||||
"""
|
||||
def __init__(self, *type_codecs):
|
||||
self.__args = type_codecs
|
||||
def __init__(self, type_codecs=None, fallback_encoder=None):
|
||||
self.__type_codecs = list(type_codecs or [])
|
||||
self._fallback_encoder = fallback_encoder
|
||||
self._encoder_map = {}
|
||||
self._decoder_map = {}
|
||||
for codec in type_codecs:
|
||||
|
||||
if self._fallback_encoder is not None:
|
||||
if not callable(fallback_encoder):
|
||||
raise TypeError("fallback_encoder %r is not a callable" % (
|
||||
fallback_encoder))
|
||||
|
||||
for codec in self.__type_codecs:
|
||||
if not isinstance(codec, TypeCodecBase):
|
||||
raise TypeError(
|
||||
"Expected an instance of %s, got %r instead" % (
|
||||
@ -98,13 +114,16 @@ class TypeRegistry(object):
|
||||
self._decoder_map[bson_type] = codec.transform_bson
|
||||
|
||||
def __repr__(self):
|
||||
return '%s%r' % (self.__class__.__name__, self.__args)
|
||||
return ('%s(type_codecs=%r, fallback_encoder=%r)' % (
|
||||
self.__class__.__name__, self.__type_codecs,
|
||||
self._fallback_encoder))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
return ((self._decoder_map == other._decoder_map) and
|
||||
(self._encoder_map == other._encoder_map))
|
||||
(self._encoder_map == other._encoder_map) and
|
||||
(self._fallback_encoder == other._fallback_encoder))
|
||||
|
||||
|
||||
_options_base = namedtuple(
|
||||
|
||||
@ -47,6 +47,8 @@ try to save an instance of ``Decimal`` with PyMongo, we get an
|
||||
The Type Codec
|
||||
--------------
|
||||
|
||||
.. versionadded:: 3.8
|
||||
|
||||
In order to encode custom types, we must first define a **type codec** for our
|
||||
type. A type codec describes how an instance of a custom type can be
|
||||
*transformed* to/from one of the types :mod:`~bson` already understands, and
|
||||
@ -94,6 +96,8 @@ The type codec for our custom type simply needs to define how a
|
||||
The Type Registry
|
||||
-----------------
|
||||
|
||||
.. versionadded:: 3.8
|
||||
|
||||
Before we can begin encoding and decoding our custom type objects, we must
|
||||
first inform PyMongo about our type codec. This is done by creating a
|
||||
:class:`~bson.codec_options.TypeRegistry` instance:
|
||||
@ -101,7 +105,7 @@ first inform PyMongo about our type codec. This is done by creating a
|
||||
.. doctest::
|
||||
|
||||
>>> from bson.codec_options import TypeRegistry
|
||||
>>> type_registry = TypeRegistry(decimal_codec)
|
||||
>>> type_registry = TypeRegistry([decimal_codec])
|
||||
|
||||
|
||||
Note that type registries can be instantiated with any number of type codecs.
|
||||
@ -146,3 +150,145 @@ MongoDB:
|
||||
>>> vanilla_collection = db.get_collection('test')
|
||||
>>> pprint.pprint(vanilla_collection.find_one())
|
||||
{u'_id': ObjectId('...'), u'num': Decimal128('45.321')}
|
||||
|
||||
|
||||
Encoding Subtypes
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
Consider the situation where, in addition to encoding
|
||||
:py:class:`~decimal.Decimal`, we also need to encode a type that subclasses
|
||||
``Decimal``. PyMongo does this automatically for types that inherit from
|
||||
Python types that are BSON-encodable by default, but the type codec system
|
||||
described above does not offer the same flexibility.
|
||||
|
||||
Consider this subtype of ``Decimal`` that has a method to return its value as
|
||||
an integer:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> class DecimalInt(Decimal):
|
||||
... def my_method(self):
|
||||
... """Method implementing some custom logic."""
|
||||
... return int(self)
|
||||
|
||||
If we try to save an instance of this type without first registering a type
|
||||
codec for it, we get an error:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> collection.insert_one({'num': DecimalInt("45.321")})
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
bson.errors.InvalidDocument: Cannot encode object: Decimal('45.321')
|
||||
|
||||
In order to proceed further, we must define a type codec for ``DecimalInt``.
|
||||
This is trivial to do since the same transformation as the one used for
|
||||
``Decimal`` is adequate for encoding ``DecimalInt`` as well:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> class DecimalIntCodec(DecimalCodec):
|
||||
... @property
|
||||
... def python_type(self):
|
||||
... """The Python type acted upon by this type codec."""
|
||||
... return DecimalInt
|
||||
>>> decimalint_codec = DecimalIntCodec()
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
No attempt is made to modify decoding behavior because without additional
|
||||
information, it is impossible to discern which incoming
|
||||
:class:`~bson.decimal128.Decimal128` value needs to be decoded as ``Decimal``
|
||||
and which needs to be decoded as ``DecimalInt``. This example only considers
|
||||
the situation where a user wants to *encode* documents containing one or both
|
||||
of these types.
|
||||
|
||||
Now, we can create a new codec options object and use it to get a collection
|
||||
object:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> type_registry = TypeRegistry([decimal_codec, decimalint_codec])
|
||||
>>> codec_options = CodecOptions(type_registry=type_registry)
|
||||
>>> collection = db.get_collection('test', codec_options=codec_options)
|
||||
>>> collection.drop()
|
||||
|
||||
|
||||
We can now seamlessly encode instances of ``DecimalInt``. Note that the
|
||||
``transform_bson`` method of the base codec class results in these values
|
||||
being decoded as ``Decimal`` (and not ``DecimalInt``):
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> collection.insert_one({'num': DecimalInt("45.321")})
|
||||
<pymongo.results.InsertOneResult object at ...>
|
||||
>>> mydoc = collection.find_one()
|
||||
>>> pprint.pprint(mydoc)
|
||||
{u'_id': ObjectId('...'), u'num': Decimal('45.321')}
|
||||
|
||||
|
||||
The Fallback Encoder
|
||||
--------------------
|
||||
|
||||
.. versionadded:: 3.8
|
||||
|
||||
|
||||
In addition to type codecs, users can also register a callable to encode types
|
||||
that BSON doesn't recognize and for which no type codec has been registered.
|
||||
This callable is the **fallback encoder** and like the ``transform_python``
|
||||
method, it accepts an unencodable value as a parameter and returns a
|
||||
BSON-encodable value. The following fallback encoder encodes python's
|
||||
:py:class:`~decimal.Decimal` type to a :class:`~bson.decimal128.Decimal128`:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> def fallback_encoder(value):
|
||||
... if isinstance(value, Decimal):
|
||||
... return Decimal128(value)
|
||||
... return value
|
||||
|
||||
After declaring the callback, we must create a type registry and codec options
|
||||
with this fallback encoder before it can be used for initializing a collection:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> type_registry = TypeRegistry(fallback_encoder=fallback_encoder)
|
||||
>>> codec_options = CodecOptions(type_registry=type_registry)
|
||||
>>> collection = db.get_collection('test', codec_options=codec_options)
|
||||
>>> collection.drop()
|
||||
|
||||
We can now seamlessly encode instances of :py:class:`~decimal.Decimal`:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> collection.insert_one({'num': Decimal("45.321")})
|
||||
<pymongo.results.InsertOneResult object at ...>
|
||||
>>> mydoc = collection.find_one()
|
||||
>>> pprint.pprint(mydoc)
|
||||
{u'_id': ObjectId('...'), u'num': Decimal128('45.321')}
|
||||
|
||||
As you can tell, fallback encoders are a compelling alternative to type codecs
|
||||
when we only want to encode custom types due to their much simpler API.
|
||||
Users should note however, that fallback encoders cannot be used to modify the
|
||||
encoding of types that PyMongo already understands, as illustrated by the
|
||||
following example:
|
||||
|
||||
>>> def fallback_encoder(value):
|
||||
... """Encoder that converts floats to int."""
|
||||
... if isinstance(value, float):
|
||||
... return int(value)
|
||||
... return value
|
||||
>>> type_registry = TypeRegistry(fallback_encoder=fallback_encoder)
|
||||
>>> codec_options = CodecOptions(type_registry=type_registry)
|
||||
>>> collection = db.get_collection('test', codec_options=codec_options)
|
||||
>>> collection.drop()
|
||||
>>> collection.insert_one({'num': 45.321})
|
||||
<pymongo.results.InsertOneResult object at ...>
|
||||
>>> mydoc = collection.find_one()
|
||||
>>> pprint.pprint(mydoc)
|
||||
{u'_id': ObjectId('...'), u'num': 45.321}
|
||||
|
||||
This is due to the fact that fallback encoders are invoked only after
|
||||
an attempt to encode the value with type codecs and standard BSON encoding
|
||||
routines has been unsuccessful.
|
||||
@ -248,7 +248,8 @@ collection, configured to use :class:`~bson.son.SON` instead of dict:
|
||||
tz_aware=False,
|
||||
uuid_representation=PYTHON_LEGACY,
|
||||
unicode_decode_error_handler='strict',
|
||||
tzinfo=None)
|
||||
tzinfo=None, type_registry=TypeRegistry(type_codecs=[],
|
||||
fallback_encoder=None))
|
||||
>>> collection_son = collection.with_options(codec_options=opts)
|
||||
|
||||
Now, documents and subdocuments in query results are represented with
|
||||
|
||||
@ -949,47 +949,76 @@ class TestTypeRegistry(unittest.TestCase):
|
||||
def transform_bson(self, value):
|
||||
return MyStrType(value)
|
||||
|
||||
def fallback_encoder(value):
|
||||
return value
|
||||
|
||||
cls.types = (MyIntType, MyStrType)
|
||||
cls.codecs = (MyIntCodec, MyStrCodec)
|
||||
cls.fallback_encoder = fallback_encoder
|
||||
|
||||
def test_simple(self):
|
||||
codec_instances = [codec() for codec in self.codecs]
|
||||
type_registry = TypeRegistry(*codec_instances)
|
||||
self.assertEqual(type_registry._encoder_map, {
|
||||
self.types[0]: codec_instances[0].transform_python,
|
||||
self.types[1]: codec_instances[1].transform_python})
|
||||
self.assertEqual(type_registry._decoder_map, {
|
||||
int: codec_instances[0].transform_bson,
|
||||
str: codec_instances[1].transform_bson})
|
||||
def assert_proper_initialization(type_registry, codec_instances):
|
||||
self.assertEqual(type_registry._encoder_map, {
|
||||
self.types[0]: codec_instances[0].transform_python,
|
||||
self.types[1]: codec_instances[1].transform_python})
|
||||
self.assertEqual(type_registry._decoder_map, {
|
||||
int: codec_instances[0].transform_bson,
|
||||
str: codec_instances[1].transform_bson})
|
||||
self.assertEqual(
|
||||
type_registry._fallback_encoder, self.fallback_encoder)
|
||||
|
||||
type_registry = TypeRegistry(codec_instances, self.fallback_encoder)
|
||||
assert_proper_initialization(type_registry, codec_instances)
|
||||
|
||||
type_registry = TypeRegistry(
|
||||
fallback_encoder=self.fallback_encoder, type_codecs=codec_instances)
|
||||
assert_proper_initialization(type_registry, codec_instances)
|
||||
|
||||
# Ensure codec list held by the type registry doesn't change if we
|
||||
# mutate the initial list.
|
||||
codec_instances_copy = list(codec_instances)
|
||||
codec_instances.pop(0)
|
||||
self.assertListEqual(
|
||||
type_registry._TypeRegistry__type_codecs, codec_instances_copy)
|
||||
|
||||
def test_initialize_fail(self):
|
||||
err_msg = "Expected an instance of TypeCodecBase, got .* instead"
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
TypeRegistry(*self.codecs)
|
||||
TypeRegistry(self.codecs)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
TypeRegistry(type('AnyType', (object,), {})())
|
||||
TypeRegistry([type('AnyType', (object,), {})()])
|
||||
|
||||
err_msg = "fallback_encoder %r is not a callable" % (True,)
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
TypeRegistry([], True)
|
||||
|
||||
err_msg = "fallback_encoder %r is not a callable" % ('hello',)
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
TypeRegistry(fallback_encoder='hello')
|
||||
|
||||
def test_not_implemented(self):
|
||||
type_registry = TypeRegistry(type("codec1", (TypeCodecBase, ), {})(),
|
||||
type("codec2", (TypeCodecBase, ), {})())
|
||||
type_registry = TypeRegistry([type("codec1", (TypeCodecBase, ), {})(),
|
||||
type("codec2", (TypeCodecBase, ), {})()])
|
||||
self.assertEqual(type_registry._encoder_map, {})
|
||||
self.assertEqual(type_registry._decoder_map, {})
|
||||
|
||||
def test_type_registry_repr(self):
|
||||
codec_instances = [codec() for codec in self.codecs]
|
||||
type_registry = TypeRegistry(*codec_instances)
|
||||
r = ("TypeRegistry(%s, %s)" % tuple(codec_instances))
|
||||
type_registry = TypeRegistry(codec_instances)
|
||||
r = ("TypeRegistry(type_codecs=%r, fallback_encoder=%r)" % (
|
||||
codec_instances, None))
|
||||
self.assertEqual(r, repr(type_registry))
|
||||
|
||||
def test_type_registry_eq(self):
|
||||
codec_instances = [codec() for codec in self.codecs]
|
||||
self.assertEqual(
|
||||
TypeRegistry(*codec_instances), TypeRegistry(*codec_instances))
|
||||
TypeRegistry(codec_instances), TypeRegistry(codec_instances))
|
||||
|
||||
codec_instances_2 = [codec() for codec in self.codecs]
|
||||
self.assertNotEqual(
|
||||
TypeRegistry(*codec_instances), TypeRegistry(*codec_instances_2))
|
||||
TypeRegistry(codec_instances), TypeRegistry(codec_instances_2))
|
||||
|
||||
|
||||
class TestCodecOptions(unittest.TestCase):
|
||||
@ -1017,7 +1046,8 @@ class TestCodecOptions(unittest.TestCase):
|
||||
r = ("CodecOptions(document_class=dict, tz_aware=False, "
|
||||
"uuid_representation=PYTHON_LEGACY, "
|
||||
"unicode_decode_error_handler='strict', "
|
||||
"tzinfo=None, type_registry=TypeRegistry())")
|
||||
"tzinfo=None, type_registry=TypeRegistry(type_codecs=[], "
|
||||
"fallback_encoder=None))")
|
||||
self.assertEqual(r, repr(CodecOptions()))
|
||||
|
||||
def test_decode_all_defaults(self):
|
||||
|
||||
@ -29,6 +29,7 @@ from bson import (BSON,
|
||||
_dict_to_bson,
|
||||
_bson_to_dict)
|
||||
from bson.codec_options import CodecOptions, TypeCodecBase, TypeRegistry
|
||||
from bson.errors import InvalidDocument
|
||||
|
||||
from test import unittest
|
||||
|
||||
@ -52,7 +53,7 @@ class DecimalCodec(TypeCodecBase):
|
||||
class TestCustomPythonTypeToBSON(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
type_registry = TypeRegistry(DecimalCodec())
|
||||
type_registry = TypeRegistry((DecimalCodec(),))
|
||||
codec_options = CodecOptions(type_registry=type_registry)
|
||||
cls.codecopts = codec_options
|
||||
|
||||
@ -114,6 +115,50 @@ class TestCustomPythonTypeToBSON(unittest.TestCase):
|
||||
fileobj.close()
|
||||
|
||||
|
||||
class TestFallbackEncoder(unittest.TestCase):
|
||||
def _get_codec_options(self, fallback_encoder):
|
||||
type_registry = TypeRegistry(fallback_encoder=fallback_encoder)
|
||||
return CodecOptions(type_registry=type_registry)
|
||||
|
||||
def test_simple(self):
|
||||
codecopts = self._get_codec_options(lambda x: Decimal128(x))
|
||||
document = {'average': Decimal('56.47')}
|
||||
bsonbytes = BSON().encode(document, codec_options=codecopts)
|
||||
|
||||
exp_document = {'average': Decimal128('56.47')}
|
||||
exp_bsonbytes = BSON().encode(exp_document)
|
||||
self.assertEqual(bsonbytes, exp_bsonbytes)
|
||||
|
||||
def test_erroring_fallback_encoder(self):
|
||||
codecopts = self._get_codec_options(lambda _: 1/0)
|
||||
|
||||
# fallback converter should not be invoked when encoding known types.
|
||||
BSON().encode(
|
||||
{'a': 1, 'b': Decimal128('1.01'), 'c': {'arr': ['abc', 3.678]}},
|
||||
codec_options=codecopts)
|
||||
|
||||
# expect an error when encoding a custom type.
|
||||
document = {'average': Decimal('56.47')}
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
BSON().encode(document, codec_options=codecopts)
|
||||
|
||||
def test_noop_fallback_encoder(self):
|
||||
codecopts = self._get_codec_options(lambda x: x)
|
||||
document = {'average': Decimal('56.47')}
|
||||
with self.assertRaises(InvalidDocument):
|
||||
BSON().encode(document, codec_options=codecopts)
|
||||
|
||||
def test_type_unencodable_by_fallback_encoder(self):
|
||||
def fallback_encoder(value):
|
||||
try:
|
||||
return Decimal128(value)
|
||||
except:
|
||||
raise TypeError("cannot encode type %s" % (type(value)))
|
||||
codecopts = self._get_codec_options(fallback_encoder)
|
||||
document = {'average': Decimal}
|
||||
with self.assertRaises(TypeError):
|
||||
BSON().encode(document, codec_options=codecopts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user