PYTHON-1750 Support callbacks for simple types (#405)
This commit is contained in:
parent
fd34c1da2a
commit
83755b8739
@ -388,6 +388,12 @@ def _element_to_dict(data, position, obj_end, opts):
|
||||
element_name)
|
||||
except KeyError:
|
||||
_raise_unknown_type(element_type, element_name)
|
||||
|
||||
if opts.type_registry._decoder_map:
|
||||
custom_decoder = opts.type_registry._decoder_map.get(type(value))
|
||||
if custom_decoder is not None:
|
||||
value = custom_decoder(value)
|
||||
|
||||
return element_name, value, position
|
||||
if _USE_C:
|
||||
_element_to_dict = _cbson._element_to_dict
|
||||
@ -748,6 +754,14 @@ if not PY3:
|
||||
|
||||
def _name_value_to_bson(name, value, check_keys, opts):
|
||||
"""Encode a single name, value pair."""
|
||||
# Custom encoder (if any) takes precedence over default encoders.
|
||||
# Using 'if' instead of 'try...except' for performance since this will
|
||||
# usually not be true.
|
||||
# No support for auto-encoding subtypes of registered custom types.
|
||||
if opts.type_registry._encoder_map:
|
||||
custom_encoder = opts.type_registry._encoder_map.get(type(value))
|
||||
if custom_encoder is not None:
|
||||
value = custom_encoder(value)
|
||||
|
||||
# First see if the type is already cached. KeyError will only ever
|
||||
# happen once per subtype.
|
||||
|
||||
@ -447,6 +447,38 @@ static long _type_marker(PyObject* object) {
|
||||
return type;
|
||||
}
|
||||
|
||||
/* Fill out a type_registry_t* from a TypeRegistry object.
|
||||
*
|
||||
* Return 1 on success. registry->encoder_map, registry->decoder_map and registry->registry_obj are new references.
|
||||
* Return 0 on failure.
|
||||
*/
|
||||
int convert_type_registry(PyObject* registry_obj, type_registry_t* registry) {
|
||||
registry->encoder_map = NULL;
|
||||
registry->decoder_map = NULL;
|
||||
registry->registry_obj = NULL;
|
||||
|
||||
registry->encoder_map = PyObject_GetAttrString(registry_obj, "_encoder_map");
|
||||
if (registry->encoder_map == NULL) {
|
||||
goto fail;
|
||||
}
|
||||
registry->is_encoder_empty = (PyDict_Size(registry->encoder_map) == 0);
|
||||
|
||||
registry->decoder_map = PyObject_GetAttrString(registry_obj, "_decoder_map");
|
||||
if (registry->decoder_map == NULL) {
|
||||
goto fail;
|
||||
}
|
||||
registry->is_decoder_empty = (PyDict_Size(registry->decoder_map) == 0);
|
||||
|
||||
registry->registry_obj = registry_obj;
|
||||
Py_INCREF(registry->registry_obj);
|
||||
return 1;
|
||||
|
||||
fail:
|
||||
Py_XDECREF(registry->encoder_map);
|
||||
Py_XDECREF(registry->decoder_map);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* Fill out a codec_options_t* from a CodecOptions object. Use with the "O&"
|
||||
* format spec in PyArg_ParseTuple.
|
||||
*
|
||||
@ -455,25 +487,37 @@ static long _type_marker(PyObject* object) {
|
||||
*/
|
||||
int convert_codec_options(PyObject* options_obj, void* p) {
|
||||
codec_options_t* options = (codec_options_t*)p;
|
||||
PyObject* type_registry_obj = NULL;
|
||||
long type_marker;
|
||||
|
||||
options->unicode_decode_error_handler = NULL;
|
||||
if (!PyArg_ParseTuple(options_obj, "ObbzO",
|
||||
|
||||
if (!PyArg_ParseTuple(options_obj, "ObbzOO",
|
||||
&options->document_class,
|
||||
&options->tz_aware,
|
||||
&options->uuid_rep,
|
||||
&options->unicode_decode_error_handler,
|
||||
&options->tzinfo)) {
|
||||
&options->tzinfo,
|
||||
&type_registry_obj))
|
||||
return 0;
|
||||
|
||||
type_marker = _type_marker(options->document_class);
|
||||
if (type_marker < 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
type_marker = _type_marker(options->document_class);
|
||||
if (type_marker < 0) return 0;
|
||||
if (!convert_type_registry(type_registry_obj,
|
||||
&options->type_registry)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
options->is_raw_bson = (101 == type_marker);
|
||||
options->options_obj = options_obj;
|
||||
|
||||
Py_INCREF(options->options_obj);
|
||||
Py_INCREF(options->document_class);
|
||||
Py_INCREF(options->tzinfo);
|
||||
options->options_obj = options_obj;
|
||||
Py_INCREF(options->options_obj);
|
||||
options->is_raw_bson = (101 == type_marker);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
@ -501,17 +545,46 @@ void destroy_codec_options(codec_options_t* options) {
|
||||
Py_CLEAR(options->document_class);
|
||||
Py_CLEAR(options->tzinfo);
|
||||
Py_CLEAR(options->options_obj);
|
||||
Py_CLEAR(options->type_registry.registry_obj);
|
||||
Py_CLEAR(options->type_registry.encoder_map);
|
||||
Py_CLEAR(options->type_registry.decoder_map);
|
||||
}
|
||||
|
||||
static int write_element_to_buffer(PyObject* self, buffer_t buffer,
|
||||
int type_byte, PyObject* value,
|
||||
unsigned char check_keys,
|
||||
const codec_options_t* options) {
|
||||
int result;
|
||||
if(Py_EnterRecursiveCall(" while encoding an object to BSON "))
|
||||
int result = 0;
|
||||
PyObject* value_type = NULL;
|
||||
PyObject* converter = NULL;
|
||||
PyObject* new_value = NULL;
|
||||
|
||||
if(Py_EnterRecursiveCall(" while encoding an object to BSON ")) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!options->type_registry.is_encoder_empty) {
|
||||
value_type = PyObject_Type(value);
|
||||
if (value_type == NULL) {
|
||||
goto fail;
|
||||
}
|
||||
converter = PyDict_GetItem(options->type_registry.encoder_map, value_type);
|
||||
if (converter != NULL) {
|
||||
/* Transform types that have a registered converter.
|
||||
* A new reference is created upon transformation. */
|
||||
new_value = PyObject_CallFunctionObjArgs(converter, value, NULL);
|
||||
if (new_value == NULL) {
|
||||
goto fail;
|
||||
}
|
||||
value = new_value;
|
||||
}
|
||||
}
|
||||
result = _write_element_to_buffer(self, buffer, type_byte,
|
||||
value, check_keys, options);
|
||||
|
||||
fail:
|
||||
Py_XDECREF(value_type);
|
||||
Py_XDECREF(new_value);
|
||||
Py_LeaveRecursiveCall();
|
||||
return result;
|
||||
}
|
||||
@ -2483,6 +2556,24 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
|
||||
}
|
||||
|
||||
if (value) {
|
||||
if (!options->type_registry.is_decoder_empty) {
|
||||
PyObject* value_type = NULL;
|
||||
PyObject* converter = NULL;
|
||||
value_type = PyObject_Type(value);
|
||||
if (value_type == NULL) {
|
||||
goto invalid;
|
||||
}
|
||||
converter = PyDict_GetItem(options->type_registry.decoder_map, value_type);
|
||||
if (converter != NULL) {
|
||||
PyObject* new_value = PyObject_CallFunctionObjArgs(converter, value, NULL);
|
||||
Py_DECREF(value_type);
|
||||
Py_DECREF(value);
|
||||
return new_value;
|
||||
} else {
|
||||
Py_DECREF(value_type);
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
|
||||
@ -51,12 +51,21 @@
|
||||
#define BYTES_FORMAT_STRING "s#"
|
||||
#endif
|
||||
|
||||
typedef struct type_registry_t {
|
||||
PyObject* encoder_map;
|
||||
PyObject* decoder_map;
|
||||
PyObject* registry_obj;
|
||||
unsigned char is_encoder_empty;
|
||||
unsigned char is_decoder_empty;
|
||||
} type_registry_t;
|
||||
|
||||
typedef struct codec_options_t {
|
||||
PyObject* document_class;
|
||||
unsigned char tz_aware;
|
||||
unsigned char uuid_rep;
|
||||
char* unicode_decode_error_handler;
|
||||
PyObject* tzinfo;
|
||||
type_registry_t type_registry;
|
||||
PyObject* options_obj;
|
||||
unsigned char is_raw_bson;
|
||||
} codec_options_t;
|
||||
|
||||
@ -23,6 +23,7 @@ from bson.binary import (ALL_UUID_REPRESENTATIONS,
|
||||
PYTHON_LEGACY,
|
||||
UUID_REPRESENTATION_NAMES)
|
||||
|
||||
|
||||
_RAW_BSON_DOCUMENT_MARKER = 101
|
||||
|
||||
|
||||
@ -32,10 +33,84 @@ def _raw_document_class(document_class):
|
||||
return marker == _RAW_BSON_DOCUMENT_MARKER
|
||||
|
||||
|
||||
class TypeCodecBase(object):
|
||||
"""Base class for defining type codec classes which describe how a
|
||||
custom type can be transformed to/from one of the types BSON already
|
||||
understands, and can encode/decode.
|
||||
|
||||
Codec classes must implement the ``python_type`` property, and the
|
||||
``transform_python`` method to support encoding, or the ``bson_type``
|
||||
property and ``transform_bson`` method to support decoding. Note that a
|
||||
single codec class may support both encoding and decoding.
|
||||
"""
|
||||
@property
|
||||
def python_type(self):
|
||||
"""The Python type to be converted into something serializable."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def bson_type(self):
|
||||
"""The BSON type to be converted into our own type."""
|
||||
raise NotImplementedError
|
||||
|
||||
def transform_bson(self, value):
|
||||
"""Convert the given BSON value into our own type."""
|
||||
raise NotImplementedError
|
||||
|
||||
def transform_python(self, value):
|
||||
"""Convert the given Python object into something serializable."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TypeRegistry(object):
|
||||
"""Encapsulates type codecs used in encoding and / or decoding BSON.
|
||||
|
||||
``TypeRegistry`` can be initialized with an arbitrary number of type
|
||||
codec instances::
|
||||
|
||||
>>> from bson.codec_options import TypeRegistry
|
||||
>>> type_registry = TypeRegistry(Codec1, Codec2, Codec3, ...)
|
||||
|
||||
If multiple codecs try to transform a single Python or BSON type,
|
||||
the transformation described by the last type codec prevails.
|
||||
"""
|
||||
def __init__(self, *type_codecs):
|
||||
self.__args = type_codecs
|
||||
self._encoder_map = {}
|
||||
self._decoder_map = {}
|
||||
for codec in type_codecs:
|
||||
if not isinstance(codec, TypeCodecBase):
|
||||
raise TypeError(
|
||||
"Expected an instance of %s, got %r instead" % (
|
||||
TypeCodecBase.__name__, codec))
|
||||
try:
|
||||
python_type = codec.python_type
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
self._encoder_map[python_type] = codec.transform_python
|
||||
|
||||
try:
|
||||
bson_type = codec.bson_type
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
self._decoder_map[bson_type] = codec.transform_bson
|
||||
|
||||
def __repr__(self):
|
||||
return '%s%r' % (self.__class__.__name__, self.__args)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
return ((self._decoder_map == other._decoder_map) and
|
||||
(self._encoder_map == other._encoder_map))
|
||||
|
||||
|
||||
_options_base = namedtuple(
|
||||
'CodecOptions',
|
||||
('document_class', 'tz_aware', 'uuid_representation',
|
||||
'unicode_decode_error_handler', 'tzinfo'))
|
||||
'unicode_decode_error_handler', 'tzinfo', 'type_registry'))
|
||||
|
||||
|
||||
class CodecOptions(_options_base):
|
||||
@ -93,6 +168,8 @@ class CodecOptions(_options_base):
|
||||
- `tzinfo`: A :class:`~datetime.tzinfo` subclass that specifies the
|
||||
timezone to/from which :class:`~datetime.datetime` objects should be
|
||||
encoded/decoded.
|
||||
- `type_registry`: Instance of :class:`TypeRegistry` used to customize
|
||||
encoding and decoding behavior.
|
||||
|
||||
.. warning:: Care must be taken when changing
|
||||
`unicode_decode_error_handler` from its default value ('strict').
|
||||
@ -104,7 +181,7 @@ class CodecOptions(_options_base):
|
||||
def __new__(cls, document_class=dict,
|
||||
tz_aware=False, uuid_representation=PYTHON_LEGACY,
|
||||
unicode_decode_error_handler="strict",
|
||||
tzinfo=None):
|
||||
tzinfo=None, type_registry=None):
|
||||
if not (issubclass(document_class, abc.MutableMapping) or
|
||||
_raw_document_class(document_class)):
|
||||
raise TypeError("document_class must be dict, bson.son.SON, "
|
||||
@ -126,9 +203,14 @@ class CodecOptions(_options_base):
|
||||
raise ValueError(
|
||||
"cannot specify tzinfo without also setting tz_aware=True")
|
||||
|
||||
type_registry = type_registry or TypeRegistry()
|
||||
|
||||
if not isinstance(type_registry, TypeRegistry):
|
||||
raise TypeError("type_registry must be an instance of TypeRegistry")
|
||||
|
||||
return tuple.__new__(
|
||||
cls, (document_class, tz_aware, uuid_representation,
|
||||
unicode_decode_error_handler, tzinfo))
|
||||
unicode_decode_error_handler, tzinfo, type_registry))
|
||||
|
||||
def _arguments_repr(self):
|
||||
"""Representation of the arguments used to create this object."""
|
||||
@ -139,10 +221,12 @@ class CodecOptions(_options_base):
|
||||
uuid_rep_repr = UUID_REPRESENTATION_NAMES.get(self.uuid_representation,
|
||||
self.uuid_representation)
|
||||
|
||||
return ('document_class=%s, tz_aware=%r, uuid_representation='
|
||||
'%s, unicode_decode_error_handler=%r, tzinfo=%r' %
|
||||
return ('document_class=%s, tz_aware=%r, uuid_representation=%s, '
|
||||
'unicode_decode_error_handler=%r, tzinfo=%r, '
|
||||
'type_registry=%r' %
|
||||
(document_class_repr, self.tz_aware, uuid_rep_repr,
|
||||
self.unicode_decode_error_handler, self.tzinfo))
|
||||
self.unicode_decode_error_handler, self.tzinfo,
|
||||
self.type_registry))
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%s)' % (self.__class__.__name__, self._arguments_repr())
|
||||
@ -165,7 +249,9 @@ class CodecOptions(_options_base):
|
||||
kwargs.get('uuid_representation', self.uuid_representation),
|
||||
kwargs.get('unicode_decode_error_handler',
|
||||
self.unicode_decode_error_handler),
|
||||
kwargs.get('tzinfo', self.tzinfo))
|
||||
kwargs.get('tzinfo', self.tzinfo),
|
||||
kwargs.get('type_registry', self.type_registry)
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_CODEC_OPTIONS = CodecOptions()
|
||||
@ -183,4 +269,6 @@ def _parse_codec_options(options):
|
||||
unicode_decode_error_handler=options.get(
|
||||
'unicode_decode_error_handler',
|
||||
DEFAULT_CODEC_OPTIONS.unicode_decode_error_handler),
|
||||
tzinfo=options.get('tzinfo', DEFAULT_CODEC_OPTIONS.tzinfo))
|
||||
tzinfo=options.get('tzinfo', DEFAULT_CODEC_OPTIONS.tzinfo),
|
||||
type_registry=options.get(
|
||||
'type_registry', DEFAULT_CODEC_OPTIONS.type_registry))
|
||||
|
||||
@ -166,6 +166,8 @@ latex_documents = [
|
||||
# If false, no module index is generated.
|
||||
#latex_use_modindex = True
|
||||
|
||||
|
||||
intersphinx_mapping = {
|
||||
'gevent': ('http://www.gevent.org/', None),
|
||||
'py': ('https://docs.python.org/3/', None),
|
||||
}
|
||||
|
||||
148
doc/examples/custom_type.rst
Normal file
148
doc/examples/custom_type.rst
Normal file
@ -0,0 +1,148 @@
|
||||
Custom Type Example
|
||||
===================
|
||||
|
||||
This is an example of using a custom type with PyMongo. The example here shows
|
||||
how to subclass :class:`~bson.codec_options.TypeCodecBase` to write a type
|
||||
codec, which is used to populate a :class:`~bson.codec_options.TypeRegistry`.
|
||||
The type registry can then be used to create a custom-type-aware
|
||||
:class:`~pymongo.collection.Collection`. Read and write operations
|
||||
issued against the resulting collection object transparently manipulate
|
||||
documents as they are saved or retrieved from MongoDB.
|
||||
|
||||
|
||||
Setup
|
||||
-----
|
||||
|
||||
We'll start by getting a clean database to use for the example:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> from pymongo import MongoClient
|
||||
>>> client = MongoClient()
|
||||
>>> client.drop_database('custom_type_example')
|
||||
>>> db = client.custom_type_example
|
||||
|
||||
|
||||
Since the purpose of the example is to demonstrate working with custom types,
|
||||
we'll need a custom data type to use. For this example, we will be working with
|
||||
the :py:class:`~decimal.Decimal` type from Python's standard library. Since the
|
||||
BSON library has a :class:`~bson.decimal128.Decimal128` type (that implements
|
||||
the IEEE 754 decimal128 decimal-based floating-point numbering format) which
|
||||
is distinct from Python's standard-library :py:class:`~decimal.Decimal` type, when we
|
||||
try to save an instance of ``Decimal`` with PyMongo, we get an
|
||||
:exc:`~bson.errors.InvalidDocument` exception.
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> from decimal import Decimal
|
||||
>>> num = Decimal("45.321")
|
||||
>>> db.test.insert_one({'num': num})
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
bson.errors.InvalidDocument: Cannot encode object: Decimal('45.321')
|
||||
|
||||
|
||||
.. _custom-type-type-codec:
|
||||
|
||||
The Type Codec
|
||||
--------------
|
||||
|
||||
In order to encode custom types, we must first define a **type codec** for our
|
||||
type. A type codec describes how an instance of a custom type can be
|
||||
*transformed* to/from one of the types :mod:`~bson` already understands, and
|
||||
can encode/decode. Type codecs must inherit from
|
||||
:class:`~bson.codec_options.TypeCodecBase`. In order to encode a custom type,
|
||||
a codec must implement the ``python_type`` property and the
|
||||
``transform_python`` method. Similarly, in order to decode a custom type,
|
||||
a codec must implement the ``bson_type`` property and the ``transform_bson``
|
||||
method. Note that a type codec need not support both encoding and decoding.
|
||||
|
||||
|
||||
The type codec for our custom type simply needs to define how a
|
||||
:py:class:`~decimal.Decimal` instance can be converted into a
|
||||
:class:`~bson.decimal128.Decimal128` instance and vice-versa:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> from bson.decimal128 import Decimal128
|
||||
>>> from bson.codec_options import TypeCodecBase
|
||||
>>> class DecimalCodec(TypeCodecBase):
|
||||
... @property
|
||||
... def python_type(self):
|
||||
... """The Python type acted upon by this type codec."""
|
||||
... return Decimal
|
||||
...
|
||||
... def transform_python(self, value):
|
||||
... """Function that transforms a custom type value into a type
|
||||
... that BSON can encode."""
|
||||
... return Decimal128(value)
|
||||
...
|
||||
... @property
|
||||
... def bson_type(self):
|
||||
... """The BSON type acted upon by this type codec."""
|
||||
... return Decimal128
|
||||
...
|
||||
... def transform_bson(self, value):
|
||||
... """Function that transforms a vanilla BSON type value into our
|
||||
... custom type."""
|
||||
... return value.to_decimal()
|
||||
>>> decimal_codec = DecimalCodec()
|
||||
|
||||
|
||||
.. _custom-type-type-registry:
|
||||
|
||||
The Type Registry
|
||||
-----------------
|
||||
|
||||
Before we can begin encoding and decoding our custom type objects, we must
|
||||
first inform PyMongo about our type codec. This is done by creating a
|
||||
:class:`~bson.codec_options.TypeRegistry` instance:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> from bson.codec_options import TypeRegistry
|
||||
>>> type_registry = TypeRegistry(decimal_codec)
|
||||
|
||||
|
||||
Note that type registries can be instantiated with any number of type codecs.
|
||||
Once instantiated, registries are immutable and the only way to add codecs
|
||||
to a registry is to create a new one.
|
||||
|
||||
|
||||
Putting it together
|
||||
-------------------
|
||||
|
||||
Finally, we can define a :class:`~bson.codec_options.CodecOptions` instance
|
||||
with our ``type_registry`` and use it to get a
|
||||
:class:`~pymongo.collection.Collection` object that understands the
|
||||
:py:class:`~decimal.Decimal` data type:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> from bson.codec_options import CodecOptions
|
||||
>>> codec_options = CodecOptions(type_registry=type_registry)
|
||||
>>> collection = db.get_collection('test', codec_options=codec_options)
|
||||
|
||||
|
||||
Now, we can seamlessly encode and decode instances of
|
||||
:py:class:`~decimal.Decimal`:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> collection.insert_one({'num': Decimal("45.321")})
|
||||
<pymongo.results.InsertOneResult object at ...>
|
||||
>>> mydoc = collection.find_one()
|
||||
>>> import pprint
|
||||
>>> pprint.pprint(mydoc)
|
||||
{u'_id': ObjectId('...'), u'num': Decimal('45.321')}
|
||||
|
||||
|
||||
We can see what's actually being saved to the database by creating a fresh
|
||||
collection object without the customized codec options and using that to query
|
||||
MongoDB:
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> vanilla_collection = db.get_collection('test')
|
||||
>>> pprint.pprint(vanilla_collection.find_one())
|
||||
{u'_id': ObjectId('...'), u'num': Decimal128('45.321')}
|
||||
@ -20,6 +20,7 @@ MongoDB, you can start it like so:
|
||||
authentication
|
||||
collations
|
||||
copydb
|
||||
custom_type
|
||||
bulk
|
||||
datetimes
|
||||
geo
|
||||
|
||||
@ -689,7 +689,7 @@ class IntegrationTest(PyMongoTestCase):
|
||||
# Use assertRaisesRegex if available, otherwise use Python 2.7's
|
||||
# deprecated assertRaisesRegexp, with a 'p'.
|
||||
if not hasattr(unittest.TestCase, 'assertRaisesRegex'):
|
||||
IntegrationTest.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
|
||||
unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
|
||||
|
||||
|
||||
class MockClientTest(unittest.TestCase):
|
||||
|
||||
@ -34,7 +34,7 @@ from bson import (BSON,
|
||||
Regex)
|
||||
from bson.binary import Binary, UUIDLegacy
|
||||
from bson.code import Code
|
||||
from bson.codec_options import CodecOptions
|
||||
from bson.codec_options import CodecOptions, TypeCodecBase, TypeRegistry
|
||||
from bson.int64 import Int64
|
||||
from bson.objectid import ObjectId
|
||||
from bson.dbref import DBRef
|
||||
@ -906,6 +906,92 @@ class TestBSON(unittest.TestCase):
|
||||
BSON.encode({"_id": {'$oid': "52d0b971b3ba219fdeb4170e"}})
|
||||
|
||||
|
||||
class TestTypeRegistry(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
class MyIntType(object):
|
||||
def __init__(self, x):
|
||||
assert isinstance(x, int)
|
||||
self.x = x
|
||||
|
||||
class MyStrType(object):
|
||||
def __init__(self, x):
|
||||
assert isinstance(x, str)
|
||||
self.x = x
|
||||
|
||||
class MyIntCodec(TypeCodecBase):
|
||||
@property
|
||||
def python_type(self):
|
||||
return MyIntType
|
||||
|
||||
@property
|
||||
def bson_type(self):
|
||||
return int
|
||||
|
||||
def transform_python(self, value):
|
||||
return value.x
|
||||
|
||||
def transform_bson(self, value):
|
||||
return MyIntType(value)
|
||||
|
||||
class MyStrCodec(TypeCodecBase):
|
||||
@property
|
||||
def python_type(self):
|
||||
return MyStrType
|
||||
|
||||
@property
|
||||
def bson_type(self):
|
||||
return str
|
||||
|
||||
def transform_python(self, value):
|
||||
return value.x
|
||||
|
||||
def transform_bson(self, value):
|
||||
return MyStrType(value)
|
||||
|
||||
cls.types = (MyIntType, MyStrType)
|
||||
cls.codecs = (MyIntCodec, MyStrCodec)
|
||||
|
||||
def test_simple(self):
|
||||
codec_instances = [codec() for codec in self.codecs]
|
||||
type_registry = TypeRegistry(*codec_instances)
|
||||
self.assertEqual(type_registry._encoder_map, {
|
||||
self.types[0]: codec_instances[0].transform_python,
|
||||
self.types[1]: codec_instances[1].transform_python})
|
||||
self.assertEqual(type_registry._decoder_map, {
|
||||
int: codec_instances[0].transform_bson,
|
||||
str: codec_instances[1].transform_bson})
|
||||
|
||||
def test_initialize_fail(self):
|
||||
err_msg = "Expected an instance of TypeCodecBase, got .* instead"
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
TypeRegistry(*self.codecs)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
TypeRegistry(type('AnyType', (object,), {})())
|
||||
|
||||
def test_not_implemented(self):
|
||||
type_registry = TypeRegistry(type("codec1", (TypeCodecBase, ), {})(),
|
||||
type("codec2", (TypeCodecBase, ), {})())
|
||||
self.assertEqual(type_registry._encoder_map, {})
|
||||
self.assertEqual(type_registry._decoder_map, {})
|
||||
|
||||
def test_type_registry_repr(self):
|
||||
codec_instances = [codec() for codec in self.codecs]
|
||||
type_registry = TypeRegistry(*codec_instances)
|
||||
r = ("TypeRegistry(%s, %s)" % tuple(codec_instances))
|
||||
self.assertEqual(r, repr(type_registry))
|
||||
|
||||
def test_type_registry_eq(self):
|
||||
codec_instances = [codec() for codec in self.codecs]
|
||||
self.assertEqual(
|
||||
TypeRegistry(*codec_instances), TypeRegistry(*codec_instances))
|
||||
|
||||
codec_instances_2 = [codec() for codec in self.codecs]
|
||||
self.assertNotEqual(
|
||||
TypeRegistry(*codec_instances), TypeRegistry(*codec_instances_2))
|
||||
|
||||
|
||||
class TestCodecOptions(unittest.TestCase):
|
||||
def test_document_class(self):
|
||||
self.assertRaises(TypeError, CodecOptions, document_class=object)
|
||||
@ -931,7 +1017,7 @@ class TestCodecOptions(unittest.TestCase):
|
||||
r = ("CodecOptions(document_class=dict, tz_aware=False, "
|
||||
"uuid_representation=PYTHON_LEGACY, "
|
||||
"unicode_decode_error_handler='strict', "
|
||||
"tzinfo=None)")
|
||||
"tzinfo=None, type_registry=TypeRegistry())")
|
||||
self.assertEqual(r, repr(CodecOptions()))
|
||||
|
||||
def test_decode_all_defaults(self):
|
||||
|
||||
119
test/test_bson_custom_types.py
Normal file
119
test/test_bson_custom_types.py
Normal file
@ -0,0 +1,119 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test support for callbacks to encode/decode custom types."""
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
from decimal import Decimal
|
||||
from random import random
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from bson import (BSON,
|
||||
Decimal128,
|
||||
decode_all,
|
||||
decode_file_iter,
|
||||
decode_iter,
|
||||
_dict_to_bson,
|
||||
_bson_to_dict)
|
||||
from bson.codec_options import CodecOptions, TypeCodecBase, TypeRegistry
|
||||
|
||||
from test import unittest
|
||||
|
||||
|
||||
class DecimalCodec(TypeCodecBase):
|
||||
@property
|
||||
def bson_type(self):
|
||||
return Decimal128
|
||||
|
||||
@property
|
||||
def python_type(self):
|
||||
return Decimal
|
||||
|
||||
def transform_bson(self, value):
|
||||
return value.to_decimal()
|
||||
|
||||
def transform_python(self, value):
|
||||
return Decimal128(value)
|
||||
|
||||
|
||||
class TestCustomPythonTypeToBSON(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
type_registry = TypeRegistry(DecimalCodec())
|
||||
codec_options = CodecOptions(type_registry=type_registry)
|
||||
cls.codecopts = codec_options
|
||||
|
||||
def test_encode_decode_roundtrip(self):
|
||||
document = {'average': Decimal('56.47')}
|
||||
bsonbytes = BSON().encode(document, codec_options=self.codecopts)
|
||||
rt_document = BSON(bsonbytes).decode(codec_options=self.codecopts)
|
||||
self.assertEqual(document, rt_document)
|
||||
|
||||
def test_decode_all(self):
|
||||
documents = []
|
||||
for dec in range(3):
|
||||
documents.append({'average': Decimal('56.4%s' % (dec,))})
|
||||
|
||||
bsonstream = bytes()
|
||||
for doc in documents:
|
||||
bsonstream += BSON.encode(doc, codec_options=self.codecopts)
|
||||
|
||||
self.assertEqual(
|
||||
decode_all(bsonstream, self.codecopts), documents)
|
||||
|
||||
def test__bson_to_dict(self):
|
||||
document = {'average': Decimal('56.47')}
|
||||
rawbytes = BSON.encode(document, codec_options=self.codecopts)
|
||||
decoded_document = _bson_to_dict(rawbytes, self.codecopts)
|
||||
self.assertEqual(document, decoded_document)
|
||||
|
||||
def test__dict_to_bson(self):
|
||||
document = {'average': Decimal('56.47')}
|
||||
rawbytes = BSON.encode(document, codec_options=self.codecopts)
|
||||
encoded_document = _dict_to_bson(document, False, self.codecopts)
|
||||
self.assertEqual(encoded_document, rawbytes)
|
||||
|
||||
def _generate_multidocument_bson_stream(self):
|
||||
inp_num = [str(random() * 100)[:4] for _ in range(10)]
|
||||
docs = [{'n': Decimal128(dec)} for dec in inp_num]
|
||||
edocs = [{'n': Decimal(dec)} for dec in inp_num]
|
||||
bsonstream = b""
|
||||
for doc in docs:
|
||||
bsonstream += BSON.encode(doc)
|
||||
return edocs, bsonstream
|
||||
|
||||
def test_decode_iter(self):
|
||||
expected, bson_data = self._generate_multidocument_bson_stream()
|
||||
for expected_doc, decoded_doc in zip(
|
||||
expected, decode_iter(bson_data, self.codecopts)):
|
||||
self.assertEqual(expected_doc, decoded_doc)
|
||||
|
||||
def test_decode_file_iter(self):
|
||||
expected, bson_data = self._generate_multidocument_bson_stream()
|
||||
fileobj = tempfile.TemporaryFile()
|
||||
fileobj.write(bson_data)
|
||||
fileobj.seek(0)
|
||||
|
||||
for expected_doc, decoded_doc in zip(
|
||||
expected, decode_file_iter(fileobj, self.codecopts)):
|
||||
self.assertEqual(expected_doc, decoded_doc)
|
||||
|
||||
fileobj.close()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue
Block a user