From 2cb34e4efc84445abb18c45572b2a35c4a81d3e5 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Wed, 17 Apr 2019 12:26:56 -0700 Subject: [PATCH] PYTHON-1814 Support custom type decoder with distinct Fix pure python custom type decoding of bson arrays. --- bson/__init__.py | 24 ++++++++++++++---------- pymongo/collection.py | 18 ++++++++++++------ pymongo/command_cursor.py | 2 +- pymongo/cursor.py | 2 +- pymongo/network.py | 6 ++++++ pymongo/pool.py | 3 +++ test/test_custom_types.py | 37 +++++++++++++++++++------------------ 7 files changed, 56 insertions(+), 36 deletions(-) diff --git a/bson/__init__.py b/bson/__init__.py index 2fcd668c4..b56e8ed48 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -219,6 +219,7 @@ def _get_array(data, position, obj_end, opts, element_name): append = result.append index = data.index getter = _ELEMENT_GETTER + decoder_map = opts.type_registry._decoder_map while position < end: element_type = data[position:position + 1] @@ -229,6 +230,12 @@ def _get_array(data, position, obj_end, opts, element_name): data, position, obj_end, opts, element_name) except KeyError: _raise_unknown_type(element_type, element_name) + + if decoder_map: + custom_decoder = decoder_map.get(type(value)) + if custom_decoder is not None: + value = custom_decoder(value) + append(value) if position != end + 1: @@ -941,17 +948,15 @@ if _USE_C: def _decode_selective(rawdoc, fields, codec_options): - doc = codec_options.document_class() + doc = {} for key, value in iteritems(rawdoc): if key in fields: - if fields[key] == list: - doc[key] = [_bson_to_dict(r.raw, codec_options) for r in value] - elif fields[key] == dict: - doc[key] = _bson_to_dict(value.raw, codec_options) + if fields[key] == 1: + doc[key] = _bson_to_dict(rawdoc.raw, codec_options)[key] else: doc[key] = _decode_selective(value, fields[key], codec_options) - continue - doc[key] = value + else: + doc[key] = value return doc @@ -970,9 +975,8 @@ def _decode_all_selective(data, codec_options, fields): - `fields`: Map of document namespaces where data that needs to be custom decoded lives or None. For example, to custom decode a list of objects in 'field1.subfield1', the specified value should be - ``{'field1': {'subfield1': list}}``. Use ``dict`` instead of ``list`` - if the field contains a single object to custom decode. If ``fields`` - is an empty map or None, this method is the same as ``decode_all``. + ``{'field1': {'subfield1': 1}}``. If ``fields`` is an empty map or + None, this method is the same as ``decode_all``. :Returns: - `document_list`: Single-member list containing the decoded document. diff --git a/pymongo/collection.py b/pymongo/collection.py index ed948dd23..f10e66033 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -53,7 +53,7 @@ from pymongo.write_concern import WriteConcern _NO_OBJ_ERROR = "No matching object found" _UJOIN = u"%s.%s" -_FIND_AND_MODIFY_DOC_FIELDS = {'value': dict} +_FIND_AND_MODIFY_DOC_FIELDS = {'value': 1} class ReturnDocument(object): @@ -225,6 +225,11 @@ class Collection(common.BaseObject): :class:`~pymongo.collation.Collation`. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. + - `retryable_write` (optional): True if this command is a retryable + write. + - `user_fields` (optional): Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. :Returns: The result document. @@ -2313,7 +2318,7 @@ class Collection(common.BaseObject): collation=collation, session=session, client=self.__database.client, - user_fields={'cursor': {'firstBatch': list}}) + user_fields={'cursor': {'firstBatch': 1}}) if "cursor" in result: cursor = result["cursor"] @@ -2575,7 +2580,7 @@ class Collection(common.BaseObject): with self._socket_for_reads(session=None) as (sock_info, slave_ok): return self._command(sock_info, cmd, slave_ok, collation=collation, - user_fields={'retval': list})["retval"] + user_fields={'retval': 1})["retval"] def rename(self, new_name, session=None, **kwargs): """Rename this collection. @@ -2680,7 +2685,8 @@ class Collection(common.BaseObject): return self._command(sock_info, cmd, slave_ok, read_concern=self.read_concern, collation=collation, - session=session)["values"] + session=session, + user_fields={"values": 1})["values"] def map_reduce(self, map, reduce, out, full_response=False, session=None, **kwargs): @@ -2761,7 +2767,7 @@ class Collection(common.BaseObject): else: write_concern = None if inline: - user_fields = {'results': list} + user_fields = {'results': 1} else: user_fields = None @@ -2820,7 +2826,7 @@ class Collection(common.BaseObject): ("map", map), ("reduce", reduce), ("out", {"inline": 1})]) - user_fields = {'results': list} + user_fields = {'results': 1} collation = validate_collation_or_none(kwargs.pop('collation', None)) cmd.update(kwargs) with self._socket_for_reads(session) as (sock_info, slave_ok): diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index 3b5281d9b..9bb749ed9 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -153,7 +153,7 @@ class CommandCursor(object): user_fields = None legacy_response = True if from_command: - user_fields = {'cursor': {'nextBatch': list}} + user_fields = {'cursor': {'nextBatch': 1}} legacy_response = False docs = self._unpack_response( reply, self.__id, self.__collection.codec_options, diff --git a/pymongo/cursor.py b/pymongo/cursor.py index 481a132b4..2450c2bcc 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -50,7 +50,7 @@ _QUERY_OPTIONS = { "await_data": 32, "exhaust": 64, "partial": 128} -_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': list, 'nextBatch': list}} +_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}} class CursorType(object): diff --git a/pymongo/network.py b/pymongo/network.py index 20c120c12..2e05b2d6a 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -82,6 +82,12 @@ def command(sock, dbname, spec, slave_ok, is_mongos, - `parse_write_concern_error`: Whether to parse the ``writeConcernError`` field in the command response. - `collation`: The collation for this command. + - `compression_ctx`: optional compression Context. + - `use_op_msg`: True if we should use OP_MSG. + - `unacknowledged`: True if this is an unacknowledged command. + - `user_fields` (optional): Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. """ name = next(iter(spec)) ns = dbname + '.$cmd' diff --git a/pymongo/pool.py b/pymongo/pool.py index 8fe8f81f2..14d9f7348 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -534,6 +534,9 @@ class SocketInfo(object): - `client`: optional MongoClient for gossipping $clusterTime. - `retryable_write`: True if this command is a retryable write. - `publish_events`: Should we publish events for this command? + - `user_fields` (optional): Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. """ self.validate_session(client, session) session = _validate_session_write_concern(session, write_concern) diff --git a/test/test_custom_types.py b/test/test_custom_types.py index 93275d542..c830258c9 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -81,7 +81,8 @@ class UndecipherableInt64Type(object): def __eq__(self, other): if isinstance(other, type(self)): return self.value == other.value - return self.value == other + # Does not compare equal to integers. + return False class UndecipherableIntDecoder(TypeDecoder): @@ -115,11 +116,17 @@ UPPERSTR_DECODER_CODECOPTS = CodecOptions(type_registry=TypeRegistry( class CustomBSONTypeTests(object): - def test_encode_decode_roundtrip(self): - document = {'average': Decimal('56.47')} - bsonbytes = BSON().encode(document, codec_options=self.codecopts) + def roundtrip(self, doc): + bsonbytes = BSON().encode(doc, codec_options=self.codecopts) rt_document = BSON(bsonbytes).decode(codec_options=self.codecopts) - self.assertEqual(document, rt_document) + self.assertEqual(doc, rt_document) + + def test_encode_decode_roundtrip(self): + self.roundtrip({'average': Decimal('56.47')}) + self.roundtrip({'average': {'b': Decimal('56.47')}}) + self.roundtrip({'average': [Decimal('56.47')]}) + self.roundtrip({'average': [[Decimal('56.47')]]}) + self.roundtrip({'average': [{'b': Decimal('56.47')}]}) def test_decode_all(self): documents = [] @@ -585,24 +592,18 @@ class TestCollectionWCustomType(IntegrationTest): self.assertIsInstance(res['total_qty'], UndecipherableInt64Type) self.assertEqual(res['total_qty'].value, 20) - # collection.distinct does not support custom type decoding def test_distinct_w_custom_type(self): self.db.drop_collection("test") test = self.db.get_collection('test', codec_options=UNINT_CODECOPTS) - test.insert_many([ - {"a": UndecipherableInt64Type(1)}, - {"a": UndecipherableInt64Type(2)}, - {"a": UndecipherableInt64Type(2)}, - {"a": UndecipherableInt64Type(2)}, - {"a": UndecipherableInt64Type(3)}]) + values = [ + UndecipherableInt64Type(1), + UndecipherableInt64Type(2), + UndecipherableInt64Type(3), + {"b": UndecipherableInt64Type(3)}] + test.insert_many({"a": val} for val in values) - distinct = test.distinct("a") - distinct.sort() - - self.assertEqual([ - UndecipherableInt64Type(1), UndecipherableInt64Type(2), - UndecipherableInt64Type(3)], distinct) + self.assertEqual(values, test.distinct("a")) def test_map_reduce_w_custom_type(self): test = self.db.get_collection(