PYTHON-1814 Support custom type decoder with distinct

Fix pure-Python custom type decoding of BSON arrays.
This commit is contained in:
Shane Harvey 2019-04-17 12:26:56 -07:00
parent 2f06e8a441
commit 2cb34e4efc
7 changed files with 56 additions and 36 deletions

View File

@ -219,6 +219,7 @@ def _get_array(data, position, obj_end, opts, element_name):
append = result.append
index = data.index
getter = _ELEMENT_GETTER
decoder_map = opts.type_registry._decoder_map
while position < end:
element_type = data[position:position + 1]
@ -229,6 +230,12 @@ def _get_array(data, position, obj_end, opts, element_name):
data, position, obj_end, opts, element_name)
except KeyError:
_raise_unknown_type(element_type, element_name)
if decoder_map:
custom_decoder = decoder_map.get(type(value))
if custom_decoder is not None:
value = custom_decoder(value)
append(value)
if position != end + 1:
@ -941,17 +948,15 @@ if _USE_C:
def _decode_selective(rawdoc, fields, codec_options):
doc = codec_options.document_class()
doc = {}
for key, value in iteritems(rawdoc):
if key in fields:
if fields[key] == list:
doc[key] = [_bson_to_dict(r.raw, codec_options) for r in value]
elif fields[key] == dict:
doc[key] = _bson_to_dict(value.raw, codec_options)
if fields[key] == 1:
doc[key] = _bson_to_dict(rawdoc.raw, codec_options)[key]
else:
doc[key] = _decode_selective(value, fields[key], codec_options)
continue
doc[key] = value
else:
doc[key] = value
return doc
@ -970,9 +975,8 @@ def _decode_all_selective(data, codec_options, fields):
- `fields`: Map of document namespaces where data that needs
to be custom decoded lives or None. For example, to custom decode a
list of objects in 'field1.subfield1', the specified value should be
``{'field1': {'subfield1': list}}``. Use ``dict`` instead of ``list``
if the field contains a single object to custom decode. If ``fields``
is an empty map or None, this method is the same as ``decode_all``.
``{'field1': {'subfield1': 1}}``. If ``fields`` is an empty map or
None, this method is the same as ``decode_all``.
:Returns:
- `document_list`: Single-member list containing the decoded document.

View File

@ -53,7 +53,7 @@ from pymongo.write_concern import WriteConcern
_NO_OBJ_ERROR = "No matching object found"
_UJOIN = u"%s.%s"
_FIND_AND_MODIFY_DOC_FIELDS = {'value': dict}
_FIND_AND_MODIFY_DOC_FIELDS = {'value': 1}
class ReturnDocument(object):
@ -225,6 +225,11 @@ class Collection(common.BaseObject):
:class:`~pymongo.collation.Collation`.
- `session` (optional): a
:class:`~pymongo.client_session.ClientSession`.
- `retryable_write` (optional): True if this command is a retryable
write.
- `user_fields` (optional): Response fields that should be decoded
using the TypeDecoders from codec_options, passed to
bson._decode_all_selective.
:Returns:
The result document.
@ -2313,7 +2318,7 @@ class Collection(common.BaseObject):
collation=collation,
session=session,
client=self.__database.client,
user_fields={'cursor': {'firstBatch': list}})
user_fields={'cursor': {'firstBatch': 1}})
if "cursor" in result:
cursor = result["cursor"]
@ -2575,7 +2580,7 @@ class Collection(common.BaseObject):
with self._socket_for_reads(session=None) as (sock_info, slave_ok):
return self._command(sock_info, cmd, slave_ok,
collation=collation,
user_fields={'retval': list})["retval"]
user_fields={'retval': 1})["retval"]
def rename(self, new_name, session=None, **kwargs):
"""Rename this collection.
@ -2680,7 +2685,8 @@ class Collection(common.BaseObject):
return self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern,
collation=collation,
session=session)["values"]
session=session,
user_fields={"values": 1})["values"]
def map_reduce(self, map, reduce, out, full_response=False, session=None,
**kwargs):
@ -2761,7 +2767,7 @@ class Collection(common.BaseObject):
else:
write_concern = None
if inline:
user_fields = {'results': list}
user_fields = {'results': 1}
else:
user_fields = None
@ -2820,7 +2826,7 @@ class Collection(common.BaseObject):
("map", map),
("reduce", reduce),
("out", {"inline": 1})])
user_fields = {'results': list}
user_fields = {'results': 1}
collation = validate_collation_or_none(kwargs.pop('collation', None))
cmd.update(kwargs)
with self._socket_for_reads(session) as (sock_info, slave_ok):

View File

@ -153,7 +153,7 @@ class CommandCursor(object):
user_fields = None
legacy_response = True
if from_command:
user_fields = {'cursor': {'nextBatch': list}}
user_fields = {'cursor': {'nextBatch': 1}}
legacy_response = False
docs = self._unpack_response(
reply, self.__id, self.__collection.codec_options,

View File

@ -50,7 +50,7 @@ _QUERY_OPTIONS = {
"await_data": 32,
"exhaust": 64,
"partial": 128}
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': list, 'nextBatch': list}}
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}}
class CursorType(object):

View File

@ -82,6 +82,12 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
- `parse_write_concern_error`: Whether to parse the ``writeConcernError``
field in the command response.
- `collation`: The collation for this command.
- `compression_ctx`: optional compression Context.
- `use_op_msg`: True if we should use OP_MSG.
- `unacknowledged`: True if this is an unacknowledged command.
- `user_fields` (optional): Response fields that should be decoded
using the TypeDecoders from codec_options, passed to
bson._decode_all_selective.
"""
name = next(iter(spec))
ns = dbname + '.$cmd'

View File

@ -534,6 +534,9 @@ class SocketInfo(object):
- `client`: optional MongoClient for gossipping $clusterTime.
- `retryable_write`: True if this command is a retryable write.
- `publish_events`: Should we publish events for this command?
- `user_fields` (optional): Response fields that should be decoded
using the TypeDecoders from codec_options, passed to
bson._decode_all_selective.
"""
self.validate_session(client, session)
session = _validate_session_write_concern(session, write_concern)

View File

@ -81,7 +81,8 @@ class UndecipherableInt64Type(object):
def __eq__(self, other):
if isinstance(other, type(self)):
return self.value == other.value
return self.value == other
# Does not compare equal to integers.
return False
class UndecipherableIntDecoder(TypeDecoder):
@ -115,11 +116,17 @@ UPPERSTR_DECODER_CODECOPTS = CodecOptions(type_registry=TypeRegistry(
class CustomBSONTypeTests(object):
def test_encode_decode_roundtrip(self):
document = {'average': Decimal('56.47')}
bsonbytes = BSON().encode(document, codec_options=self.codecopts)
def roundtrip(self, doc):
bsonbytes = BSON().encode(doc, codec_options=self.codecopts)
rt_document = BSON(bsonbytes).decode(codec_options=self.codecopts)
self.assertEqual(document, rt_document)
self.assertEqual(doc, rt_document)
def test_encode_decode_roundtrip(self):
self.roundtrip({'average': Decimal('56.47')})
self.roundtrip({'average': {'b': Decimal('56.47')}})
self.roundtrip({'average': [Decimal('56.47')]})
self.roundtrip({'average': [[Decimal('56.47')]]})
self.roundtrip({'average': [{'b': Decimal('56.47')}]})
def test_decode_all(self):
documents = []
@ -585,24 +592,18 @@ class TestCollectionWCustomType(IntegrationTest):
self.assertIsInstance(res['total_qty'], UndecipherableInt64Type)
self.assertEqual(res['total_qty'].value, 20)
# collection.distinct does not support custom type decoding
def test_distinct_w_custom_type(self):
self.db.drop_collection("test")
test = self.db.get_collection('test', codec_options=UNINT_CODECOPTS)
test.insert_many([
{"a": UndecipherableInt64Type(1)},
{"a": UndecipherableInt64Type(2)},
{"a": UndecipherableInt64Type(2)},
{"a": UndecipherableInt64Type(2)},
{"a": UndecipherableInt64Type(3)}])
values = [
UndecipherableInt64Type(1),
UndecipherableInt64Type(2),
UndecipherableInt64Type(3),
{"b": UndecipherableInt64Type(3)}]
test.insert_many({"a": val} for val in values)
distinct = test.distinct("a")
distinct.sort()
self.assertEqual([
UndecipherableInt64Type(1), UndecipherableInt64Type(2),
UndecipherableInt64Type(3)], distinct)
self.assertEqual(values, test.distinct("a"))
def test_map_reduce_w_custom_type(self):
test = self.db.get_collection(