diff --git a/bson/__init__.py b/bson/__init__.py index 16573b3de..29b1af17e 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -388,6 +388,12 @@ def _element_to_dict(data, position, obj_end, opts): element_name) except KeyError: _raise_unknown_type(element_type, element_name) + + if opts.type_registry._decoder_map: + custom_decoder = opts.type_registry._decoder_map.get(type(value)) + if custom_decoder is not None: + value = custom_decoder(value) + return element_name, value, position if _USE_C: _element_to_dict = _cbson._element_to_dict @@ -748,6 +754,14 @@ if not PY3: def _name_value_to_bson(name, value, check_keys, opts): """Encode a single name, value pair.""" + # Custom encoder (if any) takes precedence over default encoders. + # Using 'if' instead of 'try...except' for performance since this will + # usually not be true. + # No support for auto-encoding subtypes of registered custom types. + if opts.type_registry._encoder_map: + custom_encoder = opts.type_registry._encoder_map.get(type(value)) + if custom_encoder is not None: + value = custom_encoder(value) # First see if the type is already cached. KeyError will only ever # happen once per subtype. diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index afe05dcd7..6d4f8241b 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -447,6 +447,38 @@ static long _type_marker(PyObject* object) { return type; } +/* Fill out a type_registry_t* from a TypeRegistry object. + * + * Return 1 on success. encoder_map, decoder_map and registry_obj are new references. + * Return 0 on failure. 
+ */ +int convert_type_registry(PyObject* registry_obj, type_registry_t* registry) { + registry->encoder_map = NULL; + registry->decoder_map = NULL; + registry->registry_obj = NULL; + + registry->encoder_map = PyObject_GetAttrString(registry_obj, "_encoder_map"); + if (registry->encoder_map == NULL) { + goto fail; + } + registry->is_encoder_empty = (PyDict_Size(registry->encoder_map) == 0); + + registry->decoder_map = PyObject_GetAttrString(registry_obj, "_decoder_map"); + if (registry->decoder_map == NULL) { + goto fail; + } + registry->is_decoder_empty = (PyDict_Size(registry->decoder_map) == 0); + + registry->registry_obj = registry_obj; + Py_INCREF(registry->registry_obj); + return 1; + +fail: + Py_XDECREF(registry->encoder_map); + Py_XDECREF(registry->decoder_map); + return 0; +} + /* Fill out a codec_options_t* from a CodecOptions object. Use with the "O&" * format spec in PyArg_ParseTuple. * @@ -455,25 +487,37 @@ static long _type_marker(PyObject* object) { */ int convert_codec_options(PyObject* options_obj, void* p) { codec_options_t* options = (codec_options_t*)p; + PyObject* type_registry_obj = NULL; long type_marker; + options->unicode_decode_error_handler = NULL; - if (!PyArg_ParseTuple(options_obj, "ObbzO", + + if (!PyArg_ParseTuple(options_obj, "ObbzOO", &options->document_class, &options->tz_aware, &options->uuid_rep, &options->unicode_decode_error_handler, - &options->tzinfo)) { + &options->tzinfo, + &type_registry_obj)) + return 0; + + type_marker = _type_marker(options->document_class); + if (type_marker < 0) { return 0; } - type_marker = _type_marker(options->document_class); - if (type_marker < 0) return 0; + if (!convert_type_registry(type_registry_obj, + &options->type_registry)) { + return 0; + } + options->is_raw_bson = (101 == type_marker); + options->options_obj = options_obj; + + Py_INCREF(options->options_obj); Py_INCREF(options->document_class); Py_INCREF(options->tzinfo); - options->options_obj = options_obj; - 
Py_INCREF(options->options_obj); - options->is_raw_bson = (101 == type_marker); + return 1; } @@ -501,17 +545,46 @@ void destroy_codec_options(codec_options_t* options) { Py_CLEAR(options->document_class); Py_CLEAR(options->tzinfo); Py_CLEAR(options->options_obj); + Py_CLEAR(options->type_registry.registry_obj); + Py_CLEAR(options->type_registry.encoder_map); + Py_CLEAR(options->type_registry.decoder_map); } static int write_element_to_buffer(PyObject* self, buffer_t buffer, int type_byte, PyObject* value, unsigned char check_keys, const codec_options_t* options) { - int result; - if(Py_EnterRecursiveCall(" while encoding an object to BSON ")) + int result = 0; + PyObject* value_type = NULL; + PyObject* converter = NULL; + PyObject* new_value = NULL; + + if(Py_EnterRecursiveCall(" while encoding an object to BSON ")) { return 0; + } + + if (!options->type_registry.is_encoder_empty) { + value_type = PyObject_Type(value); + if (value_type == NULL) { + goto fail; + } + converter = PyDict_GetItem(options->type_registry.encoder_map, value_type); + if (converter != NULL) { + /* Transform types that have a registered converter. + * A new reference is created upon transformation. 
*/ + new_value = PyObject_CallFunctionObjArgs(converter, value, NULL); + if (new_value == NULL) { + goto fail; + } + value = new_value; + } + } result = _write_element_to_buffer(self, buffer, type_byte, value, check_keys, options); + +fail: + Py_XDECREF(value_type); + Py_XDECREF(new_value); Py_LeaveRecursiveCall(); return result; } @@ -2483,6 +2556,24 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, } if (value) { + if (!options->type_registry.is_decoder_empty) { + PyObject* value_type = NULL; + PyObject* converter = NULL; + value_type = PyObject_Type(value); + if (value_type == NULL) { + goto invalid; + } + converter = PyDict_GetItem(options->type_registry.decoder_map, value_type); + if (converter != NULL) { + PyObject* new_value = PyObject_CallFunctionObjArgs(converter, value, NULL); + Py_DECREF(value_type); + Py_DECREF(value); + return new_value; + } else { + Py_DECREF(value_type); + return value; + } + } return value; } diff --git a/bson/_cbsonmodule.h b/bson/_cbsonmodule.h index f483a8b3c..10bb15fcf 100644 --- a/bson/_cbsonmodule.h +++ b/bson/_cbsonmodule.h @@ -51,12 +51,21 @@ #define BYTES_FORMAT_STRING "s#" #endif +typedef struct type_registry_t { + PyObject* encoder_map; + PyObject* decoder_map; + PyObject* registry_obj; + unsigned char is_encoder_empty; + unsigned char is_decoder_empty; +} type_registry_t; + typedef struct codec_options_t { PyObject* document_class; unsigned char tz_aware; unsigned char uuid_rep; char* unicode_decode_error_handler; PyObject* tzinfo; + type_registry_t type_registry; PyObject* options_obj; unsigned char is_raw_bson; } codec_options_t; diff --git a/bson/codec_options.py b/bson/codec_options.py index f54fab433..3194a0e7f 100644 --- a/bson/codec_options.py +++ b/bson/codec_options.py @@ -23,6 +23,7 @@ from bson.binary import (ALL_UUID_REPRESENTATIONS, PYTHON_LEGACY, UUID_REPRESENTATION_NAMES) + _RAW_BSON_DOCUMENT_MARKER = 101 @@ -32,10 +33,84 @@ def _raw_document_class(document_class): return 
marker == _RAW_BSON_DOCUMENT_MARKER +class TypeCodecBase(object): + """Base class for defining type codec classes which describe how a + custom type can be transformed to/from one of the types BSON already + understands, and can encode/decode. + + Codec classes must implement the ``python_type`` property, and the + ``transform_python`` method to support encoding, or the ``bson_type`` + property and ``transform_bson`` method to support decoding. Note that a + single codec class may support both encoding and decoding. + """ + @property + def python_type(self): + """The Python type to be converted into something serializable.""" + raise NotImplementedError + + @property + def bson_type(self): + """The BSON type to be converted into our own type.""" + raise NotImplementedError + + def transform_bson(self, value): + """Convert the given BSON value into our own type.""" + raise NotImplementedError + + def transform_python(self, value): + """Convert the given Python object into something serializable.""" + raise NotImplementedError + + +class TypeRegistry(object): + """Encapsulates type codecs used in encoding and / or decoding BSON. + + ``TypeRegistry`` can be initialized with an arbitrary number of type + codecs:: + + >>> from bson.codec_options import TypeRegistry + >>> type_registry = TypeRegistry(Codec1, Codec2, Codec3, ...) + + If multiple codecs try to transform a single python or BSON type, + the transformation described by the last type codec prevails. 
+ """ + def __init__(self, *type_codecs): + self.__args = type_codecs + self._encoder_map = {} + self._decoder_map = {} + for codec in type_codecs: + if not isinstance(codec, TypeCodecBase): + raise TypeError( + "Expected an instance of %s, got %r instead" % ( + TypeCodecBase.__name__, codec)) + try: + python_type = codec.python_type + except NotImplementedError: + pass + else: + self._encoder_map[python_type] = codec.transform_python + + try: + bson_type = codec.bson_type + except NotImplementedError: + pass + else: + self._decoder_map[bson_type] = codec.transform_bson + + def __repr__(self): + return '%s%r' % (self.__class__.__name__, self.__args) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + return ((self._decoder_map == other._decoder_map) and + (self._encoder_map == other._encoder_map)) + + _options_base = namedtuple( 'CodecOptions', ('document_class', 'tz_aware', 'uuid_representation', - 'unicode_decode_error_handler', 'tzinfo')) + 'unicode_decode_error_handler', 'tzinfo', 'type_registry')) class CodecOptions(_options_base): @@ -93,6 +168,8 @@ class CodecOptions(_options_base): - `tzinfo`: A :class:`~datetime.tzinfo` subclass that specifies the timezone to/from which :class:`~datetime.datetime` objects should be encoded/decoded. + - `type_registry`: Instance of :class:`TypeRegistry` used to customize + encoding and decoding behavior. .. warning:: Care must be taken when changing `unicode_decode_error_handler` from its default value ('strict'). 
@@ -104,7 +181,7 @@ class CodecOptions(_options_base): def __new__(cls, document_class=dict, tz_aware=False, uuid_representation=PYTHON_LEGACY, unicode_decode_error_handler="strict", - tzinfo=None): + tzinfo=None, type_registry=None): if not (issubclass(document_class, abc.MutableMapping) or _raw_document_class(document_class)): raise TypeError("document_class must be dict, bson.son.SON, " @@ -126,9 +203,14 @@ class CodecOptions(_options_base): raise ValueError( "cannot specify tzinfo without also setting tz_aware=True") + type_registry = type_registry or TypeRegistry() + + if not isinstance(type_registry, TypeRegistry): + raise TypeError("type_registry must be an instance of TypeRegistry") + return tuple.__new__( cls, (document_class, tz_aware, uuid_representation, - unicode_decode_error_handler, tzinfo)) + unicode_decode_error_handler, tzinfo, type_registry)) def _arguments_repr(self): """Representation of the arguments used to create this object.""" @@ -139,10 +221,12 @@ class CodecOptions(_options_base): uuid_rep_repr = UUID_REPRESENTATION_NAMES.get(self.uuid_representation, self.uuid_representation) - return ('document_class=%s, tz_aware=%r, uuid_representation=' - '%s, unicode_decode_error_handler=%r, tzinfo=%r' % + return ('document_class=%s, tz_aware=%r, uuid_representation=%s, ' + 'unicode_decode_error_handler=%r, tzinfo=%r, ' + 'type_registry=%r' % (document_class_repr, self.tz_aware, uuid_rep_repr, - self.unicode_decode_error_handler, self.tzinfo)) + self.unicode_decode_error_handler, self.tzinfo, + self.type_registry)) def __repr__(self): return '%s(%s)' % (self.__class__.__name__, self._arguments_repr()) @@ -165,7 +249,9 @@ class CodecOptions(_options_base): kwargs.get('uuid_representation', self.uuid_representation), kwargs.get('unicode_decode_error_handler', self.unicode_decode_error_handler), - kwargs.get('tzinfo', self.tzinfo)) + kwargs.get('tzinfo', self.tzinfo), + kwargs.get('type_registry', self.type_registry) + ) DEFAULT_CODEC_OPTIONS = 
CodecOptions() @@ -183,4 +269,6 @@ def _parse_codec_options(options): unicode_decode_error_handler=options.get( 'unicode_decode_error_handler', DEFAULT_CODEC_OPTIONS.unicode_decode_error_handler), - tzinfo=options.get('tzinfo', DEFAULT_CODEC_OPTIONS.tzinfo)) + tzinfo=options.get('tzinfo', DEFAULT_CODEC_OPTIONS.tzinfo), + type_registry=options.get( + 'type_registry', DEFAULT_CODEC_OPTIONS.type_registry)) diff --git a/doc/conf.py b/doc/conf.py index 1702f0fa5..ad4c42b9e 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -166,6 +166,8 @@ latex_documents = [ # If false, no module index is generated. #latex_use_modindex = True + intersphinx_mapping = { 'gevent': ('http://www.gevent.org/', None), + 'py': ('https://docs.python.org/3/', None), } diff --git a/doc/examples/custom_type.rst b/doc/examples/custom_type.rst new file mode 100644 index 000000000..cbdf35ee1 --- /dev/null +++ b/doc/examples/custom_type.rst @@ -0,0 +1,148 @@ +Custom Type Example +=================== + +This is an example of using a custom type with PyMongo. The example here shows +how to subclass :class:`~bson.codec_options.TypeCodecBase` to write a type +codec, which is used to populate a :class:`~bson.codec_options.TypeRegistry`. +The type registry can then be used to create a custom-type-aware +:class:`~pymongo.collection.Collection`. Read and write operations +issued against the resulting collection object transparently manipulate +documents as they are saved or retrieved from MongoDB. + + +Setup +----- + +We'll start by getting a clean database to use for the example: + +.. doctest:: + + >>> from pymongo import MongoClient + >>> client = MongoClient() + >>> client.drop_database('custom_type_example') + >>> db = client.custom_type_example + + +Since the purpose of the example is to demonstrate working with custom types, +we'll need a custom data type to use. For this example, we will be working with +the :py:class:`~decimal.Decimal` type from Python's standard library. 
Since the +BSON library has a :class:`~bson.decimal128.Decimal128` type (that implements +the IEEE 754 decimal128 decimal-based floating-point numbering format) which +is distinct from Python's standard-library :py:class:`~decimal.Decimal` type, when we +try to save an instance of ``Decimal`` with PyMongo, we get an +:exc:`~bson.errors.InvalidDocument` exception. + +.. doctest:: + + >>> from decimal import Decimal + >>> num = Decimal("45.321") + >>> db.test.insert_one({'num': num}) + Traceback (most recent call last): + ... + bson.errors.InvalidDocument: Cannot encode object: Decimal('45.321') + + +.. _custom-type-type-codec: + +The Type Codec +-------------- + +In order to encode custom types, we must first define a **type codec** for our +type. A type codec describes how an instance of a custom type can be +*transformed* to/from one of the types :mod:`~bson` already understands, and +can encode/decode. Type codecs must inherit from +:class:`~bson.codec_options.TypeCodecBase`. In order to encode a custom type, +a codec must implement the ``python_type`` property and the +``transform_python`` method. Similarly, in order to decode a custom type, +a codec must implement the ``bson_type`` property and the ``transform_bson`` +method. Note that a type codec need not support both encoding and decoding. + + +The type codec for our custom type simply needs to define how a +:py:class:`~decimal.Decimal` instance can be converted into a +:class:`~bson.decimal128.Decimal128` instance and vice-versa: + +.. doctest:: + + >>> from bson.decimal128 import Decimal128 + >>> from bson.codec_options import TypeCodecBase + >>> class DecimalCodec(TypeCodecBase): + ... @property + ... def python_type(self): + ... """The Python type acted upon by this type codec.""" + ... return Decimal + ... + ... def transform_python(self, value): + ... """Function that transforms a custom type value into a type + ... that BSON can encode.""" + ... return Decimal128(value) + ... + ... 
@property + ... def bson_type(self): + ... """The BSON type acted upon by this type codec.""" + ... return Decimal128 + ... + ... def transform_bson(self, value): + ... """Function that transforms a vanilla BSON type value into our + ... custom type.""" + ... return value.to_decimal() + >>> decimal_codec = DecimalCodec() + + +.. _custom-type-type-registry: + +The Type Registry +----------------- + +Before we can begin encoding and decoding our custom type objects, we must +first inform PyMongo about our type codec. This is done by creating a +:class:`~bson.codec_options.TypeRegistry` instance: + +.. doctest:: + + >>> from bson.codec_options import TypeRegistry + >>> type_registry = TypeRegistry(decimal_codec) + + +Note that type registries can be instantiated with any number of type codecs. +Once instantiated, registries are immutable and the only way to add codecs +to a registry is to create a new one. + + +Putting it together +------------------- + +Finally, we can define a :class:`~bson.codec_options.CodecOptions` instance +with our ``type_registry`` and use it to get a +:class:`~pymongo.collection.Collection` object that understands the +:py:class:`~decimal.Decimal` data type: + +.. doctest:: + + >>> from bson.codec_options import CodecOptions + >>> codec_options = CodecOptions(type_registry=type_registry) + >>> collection = db.get_collection('test', codec_options=codec_options) + + +Now, we can seamlessly encode and decode instances of +:py:class:`~decimal.Decimal`: + +.. doctest:: + + >>> collection.insert_one({'num': Decimal("45.321")}) + <pymongo.results.InsertOneResult object at ...> + >>> mydoc = collection.find_one() + >>> import pprint + >>> pprint.pprint(mydoc) + {u'_id': ObjectId('...'), u'num': Decimal('45.321')} + + +We can see what's actually being saved to the database by creating a fresh +collection object without the customized codec options and using that to query +MongoDB: + +.. 
doctest:: + + >>> vanilla_collection = db.get_collection('test') + >>> pprint.pprint(vanilla_collection.find_one()) + {u'_id': ObjectId('...'), u'num': Decimal128('45.321')} diff --git a/doc/examples/index.rst b/doc/examples/index.rst index ab3d3086a..7431acd9e 100644 --- a/doc/examples/index.rst +++ b/doc/examples/index.rst @@ -20,6 +20,7 @@ MongoDB, you can start it like so: authentication collations copydb + custom_type bulk datetimes geo diff --git a/test/__init__.py b/test/__init__.py index 4404add02..41e035084 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -689,7 +689,7 @@ class IntegrationTest(PyMongoTestCase): # Use assertRaisesRegex if available, otherwise use Python 2.7's # deprecated assertRaisesRegexp, with a 'p'. if not hasattr(unittest.TestCase, 'assertRaisesRegex'): - IntegrationTest.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp + unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp class MockClientTest(unittest.TestCase): diff --git a/test/test_bson.py b/test/test_bson.py index 86653c9c9..520dd7824 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -34,7 +34,7 @@ from bson import (BSON, Regex) from bson.binary import Binary, UUIDLegacy from bson.code import Code -from bson.codec_options import CodecOptions +from bson.codec_options import CodecOptions, TypeCodecBase, TypeRegistry from bson.int64 import Int64 from bson.objectid import ObjectId from bson.dbref import DBRef @@ -906,6 +906,92 @@ class TestBSON(unittest.TestCase): BSON.encode({"_id": {'$oid': "52d0b971b3ba219fdeb4170e"}}) +class TestTypeRegistry(unittest.TestCase): + @classmethod + def setUpClass(cls): + class MyIntType(object): + def __init__(self, x): + assert isinstance(x, int) + self.x = x + + class MyStrType(object): + def __init__(self, x): + assert isinstance(x, str) + self.x = x + + class MyIntCodec(TypeCodecBase): + @property + def python_type(self): + return MyIntType + + @property + def bson_type(self): + return int + + def 
transform_python(self, value): + return value.x + + def transform_bson(self, value): + return MyIntType(value) + + class MyStrCodec(TypeCodecBase): + @property + def python_type(self): + return MyStrType + + @property + def bson_type(self): + return str + + def transform_python(self, value): + return value.x + + def transform_bson(self, value): + return MyStrType(value) + + cls.types = (MyIntType, MyStrType) + cls.codecs = (MyIntCodec, MyStrCodec) + + def test_simple(self): + codec_instances = [codec() for codec in self.codecs] + type_registry = TypeRegistry(*codec_instances) + self.assertEqual(type_registry._encoder_map, { + self.types[0]: codec_instances[0].transform_python, + self.types[1]: codec_instances[1].transform_python}) + self.assertEqual(type_registry._decoder_map, { + int: codec_instances[0].transform_bson, + str: codec_instances[1].transform_bson}) + + def test_initialize_fail(self): + err_msg = "Expected an instance of TypeCodecBase, got .* instead" + with self.assertRaisesRegex(TypeError, err_msg): + TypeRegistry(*self.codecs) + + with self.assertRaisesRegex(TypeError, err_msg): + TypeRegistry(type('AnyType', (object,), {})()) + + def test_not_implemented(self): + type_registry = TypeRegistry(type("codec1", (TypeCodecBase, ), {})(), + type("codec2", (TypeCodecBase, ), {})()) + self.assertEqual(type_registry._encoder_map, {}) + self.assertEqual(type_registry._decoder_map, {}) + + def test_type_registry_repr(self): + codec_instances = [codec() for codec in self.codecs] + type_registry = TypeRegistry(*codec_instances) + r = ("TypeRegistry(%s, %s)" % tuple(codec_instances)) + self.assertEqual(r, repr(type_registry)) + + def test_type_registry_eq(self): + codec_instances = [codec() for codec in self.codecs] + self.assertEqual( + TypeRegistry(*codec_instances), TypeRegistry(*codec_instances)) + + codec_instances_2 = [codec() for codec in self.codecs] + self.assertNotEqual( + TypeRegistry(*codec_instances), TypeRegistry(*codec_instances_2)) + + class 
TestCodecOptions(unittest.TestCase): def test_document_class(self): self.assertRaises(TypeError, CodecOptions, document_class=object) @@ -931,7 +1017,7 @@ class TestCodecOptions(unittest.TestCase): r = ("CodecOptions(document_class=dict, tz_aware=False, " "uuid_representation=PYTHON_LEGACY, " "unicode_decode_error_handler='strict', " - "tzinfo=None)") + "tzinfo=None, type_registry=TypeRegistry())") self.assertEqual(r, repr(CodecOptions())) def test_decode_all_defaults(self): diff --git a/test/test_bson_custom_types.py b/test/test_bson_custom_types.py new file mode 100644 index 000000000..c3a8f78a3 --- /dev/null +++ b/test/test_bson_custom_types.py @@ -0,0 +1,119 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test support for callbacks to encode/decode custom types.""" + +import sys +import tempfile +from decimal import Decimal +from random import random + +sys.path[0:0] = [""] + +from bson import (BSON, + Decimal128, + decode_all, + decode_file_iter, + decode_iter, + _dict_to_bson, + _bson_to_dict) +from bson.codec_options import CodecOptions, TypeCodecBase, TypeRegistry + +from test import unittest + + +class DecimalCodec(TypeCodecBase): + @property + def bson_type(self): + return Decimal128 + + @property + def python_type(self): + return Decimal + + def transform_bson(self, value): + return value.to_decimal() + + def transform_python(self, value): + return Decimal128(value) + + +class TestCustomPythonTypeToBSON(unittest.TestCase): + @classmethod + def setUpClass(cls): + type_registry = TypeRegistry(DecimalCodec()) + codec_options = CodecOptions(type_registry=type_registry) + cls.codecopts = codec_options + + def test_encode_decode_roundtrip(self): + document = {'average': Decimal('56.47')} + bsonbytes = BSON().encode(document, codec_options=self.codecopts) + rt_document = BSON(bsonbytes).decode(codec_options=self.codecopts) + self.assertEqual(document, rt_document) + + def test_decode_all(self): + documents = [] + for dec in range(3): + documents.append({'average': Decimal('56.4%s' % (dec,))}) + + bsonstream = bytes() + for doc in documents: + bsonstream += BSON.encode(doc, codec_options=self.codecopts) + + self.assertEqual( + decode_all(bsonstream, self.codecopts), documents) + + def test__bson_to_dict(self): + document = {'average': Decimal('56.47')} + rawbytes = BSON.encode(document, codec_options=self.codecopts) + decoded_document = _bson_to_dict(rawbytes, self.codecopts) + self.assertEqual(document, decoded_document) + + def test__dict_to_bson(self): + document = {'average': Decimal('56.47')} + rawbytes = BSON.encode(document, codec_options=self.codecopts) + encoded_document = _dict_to_bson(document, False, self.codecopts) + 
self.assertEqual(encoded_document, rawbytes) + + def _generate_multidocument_bson_stream(self): + inp_num = [str(random() * 100)[:4] for _ in range(10)] + docs = [{'n': Decimal128(dec)} for dec in inp_num] + edocs = [{'n': Decimal(dec)} for dec in inp_num] + bsonstream = b"" + for doc in docs: + bsonstream += BSON.encode(doc) + return edocs, bsonstream + + def test_decode_iter(self): + expected, bson_data = self._generate_multidocument_bson_stream() + for expected_doc, decoded_doc in zip( + expected, decode_iter(bson_data, self.codecopts)): + self.assertEqual(expected_doc, decoded_doc) + + def test_decode_file_iter(self): + expected, bson_data = self._generate_multidocument_bson_stream() + fileobj = tempfile.TemporaryFile() + fileobj.write(bson_data) + fileobj.seek(0) + + for expected_doc, decoded_doc in zip( + expected, decode_file_iter(fileobj, self.codecopts)): + self.assertEqual(expected_doc, decoded_doc) + + fileobj.close() + + + +if __name__ == "__main__": + unittest.main()