PYTHON-1814 Support custom type decoder with distinct
Fix pure python custom type decoding of bson arrays.
This commit is contained in:
parent
2f06e8a441
commit
2cb34e4efc
@ -219,6 +219,7 @@ def _get_array(data, position, obj_end, opts, element_name):
|
||||
append = result.append
|
||||
index = data.index
|
||||
getter = _ELEMENT_GETTER
|
||||
decoder_map = opts.type_registry._decoder_map
|
||||
|
||||
while position < end:
|
||||
element_type = data[position:position + 1]
|
||||
@ -229,6 +230,12 @@ def _get_array(data, position, obj_end, opts, element_name):
|
||||
data, position, obj_end, opts, element_name)
|
||||
except KeyError:
|
||||
_raise_unknown_type(element_type, element_name)
|
||||
|
||||
if decoder_map:
|
||||
custom_decoder = decoder_map.get(type(value))
|
||||
if custom_decoder is not None:
|
||||
value = custom_decoder(value)
|
||||
|
||||
append(value)
|
||||
|
||||
if position != end + 1:
|
||||
@ -941,17 +948,15 @@ if _USE_C:
|
||||
|
||||
|
||||
def _decode_selective(rawdoc, fields, codec_options):
|
||||
doc = codec_options.document_class()
|
||||
doc = {}
|
||||
for key, value in iteritems(rawdoc):
|
||||
if key in fields:
|
||||
if fields[key] == list:
|
||||
doc[key] = [_bson_to_dict(r.raw, codec_options) for r in value]
|
||||
elif fields[key] == dict:
|
||||
doc[key] = _bson_to_dict(value.raw, codec_options)
|
||||
if fields[key] == 1:
|
||||
doc[key] = _bson_to_dict(rawdoc.raw, codec_options)[key]
|
||||
else:
|
||||
doc[key] = _decode_selective(value, fields[key], codec_options)
|
||||
continue
|
||||
doc[key] = value
|
||||
else:
|
||||
doc[key] = value
|
||||
return doc
|
||||
|
||||
|
||||
@ -970,9 +975,8 @@ def _decode_all_selective(data, codec_options, fields):
|
||||
- `fields`: Map of document namespaces where data that needs
|
||||
to be custom decoded lives or None. For example, to custom decode a
|
||||
list of objects in 'field1.subfield1', the specified value should be
|
||||
``{'field1': {'subfield1': list}}``. Use ``dict`` instead of ``list``
|
||||
if the field contains a single object to custom decode. If ``fields``
|
||||
is an empty map or None, this method is the same as ``decode_all``.
|
||||
``{'field1': {'subfield1': 1}}``. If ``fields`` is an empty map or
|
||||
None, this method is the same as ``decode_all``.
|
||||
|
||||
:Returns:
|
||||
- `document_list`: Single-member list containing the decoded document.
|
||||
|
||||
@ -53,7 +53,7 @@ from pymongo.write_concern import WriteConcern
|
||||
|
||||
_NO_OBJ_ERROR = "No matching object found"
|
||||
_UJOIN = u"%s.%s"
|
||||
_FIND_AND_MODIFY_DOC_FIELDS = {'value': dict}
|
||||
_FIND_AND_MODIFY_DOC_FIELDS = {'value': 1}
|
||||
|
||||
|
||||
class ReturnDocument(object):
|
||||
@ -225,6 +225,11 @@ class Collection(common.BaseObject):
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
- `session` (optional): a
|
||||
:class:`~pymongo.client_session.ClientSession`.
|
||||
- `retryable_write` (optional): True if this command is a retryable
|
||||
write.
|
||||
- `user_fields` (optional): Response fields that should be decoded
|
||||
using the TypeDecoders from codec_options, passed to
|
||||
bson._decode_all_selective.
|
||||
|
||||
:Returns:
|
||||
The result document.
|
||||
@ -2313,7 +2318,7 @@ class Collection(common.BaseObject):
|
||||
collation=collation,
|
||||
session=session,
|
||||
client=self.__database.client,
|
||||
user_fields={'cursor': {'firstBatch': list}})
|
||||
user_fields={'cursor': {'firstBatch': 1}})
|
||||
|
||||
if "cursor" in result:
|
||||
cursor = result["cursor"]
|
||||
@ -2575,7 +2580,7 @@ class Collection(common.BaseObject):
|
||||
with self._socket_for_reads(session=None) as (sock_info, slave_ok):
|
||||
return self._command(sock_info, cmd, slave_ok,
|
||||
collation=collation,
|
||||
user_fields={'retval': list})["retval"]
|
||||
user_fields={'retval': 1})["retval"]
|
||||
|
||||
def rename(self, new_name, session=None, **kwargs):
|
||||
"""Rename this collection.
|
||||
@ -2680,7 +2685,8 @@ class Collection(common.BaseObject):
|
||||
return self._command(sock_info, cmd, slave_ok,
|
||||
read_concern=self.read_concern,
|
||||
collation=collation,
|
||||
session=session)["values"]
|
||||
session=session,
|
||||
user_fields={"values": 1})["values"]
|
||||
|
||||
def map_reduce(self, map, reduce, out, full_response=False, session=None,
|
||||
**kwargs):
|
||||
@ -2761,7 +2767,7 @@ class Collection(common.BaseObject):
|
||||
else:
|
||||
write_concern = None
|
||||
if inline:
|
||||
user_fields = {'results': list}
|
||||
user_fields = {'results': 1}
|
||||
else:
|
||||
user_fields = None
|
||||
|
||||
@ -2820,7 +2826,7 @@ class Collection(common.BaseObject):
|
||||
("map", map),
|
||||
("reduce", reduce),
|
||||
("out", {"inline": 1})])
|
||||
user_fields = {'results': list}
|
||||
user_fields = {'results': 1}
|
||||
collation = validate_collation_or_none(kwargs.pop('collation', None))
|
||||
cmd.update(kwargs)
|
||||
with self._socket_for_reads(session) as (sock_info, slave_ok):
|
||||
|
||||
@ -153,7 +153,7 @@ class CommandCursor(object):
|
||||
user_fields = None
|
||||
legacy_response = True
|
||||
if from_command:
|
||||
user_fields = {'cursor': {'nextBatch': list}}
|
||||
user_fields = {'cursor': {'nextBatch': 1}}
|
||||
legacy_response = False
|
||||
docs = self._unpack_response(
|
||||
reply, self.__id, self.__collection.codec_options,
|
||||
|
||||
@ -50,7 +50,7 @@ _QUERY_OPTIONS = {
|
||||
"await_data": 32,
|
||||
"exhaust": 64,
|
||||
"partial": 128}
|
||||
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': list, 'nextBatch': list}}
|
||||
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}}
|
||||
|
||||
|
||||
class CursorType(object):
|
||||
|
||||
@ -82,6 +82,12 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
|
||||
- `parse_write_concern_error`: Whether to parse the ``writeConcernError``
|
||||
field in the command response.
|
||||
- `collation`: The collation for this command.
|
||||
- `compression_ctx`: optional compression Context.
|
||||
- `use_op_msg`: True if we should use OP_MSG.
|
||||
- `unacknowledged`: True if this is an unacknowledged command.
|
||||
- `user_fields` (optional): Response fields that should be decoded
|
||||
using the TypeDecoders from codec_options, passed to
|
||||
bson._decode_all_selective.
|
||||
"""
|
||||
name = next(iter(spec))
|
||||
ns = dbname + '.$cmd'
|
||||
|
||||
@ -534,6 +534,9 @@ class SocketInfo(object):
|
||||
- `client`: optional MongoClient for gossipping $clusterTime.
|
||||
- `retryable_write`: True if this command is a retryable write.
|
||||
- `publish_events`: Should we publish events for this command?
|
||||
- `user_fields` (optional): Response fields that should be decoded
|
||||
using the TypeDecoders from codec_options, passed to
|
||||
bson._decode_all_selective.
|
||||
"""
|
||||
self.validate_session(client, session)
|
||||
session = _validate_session_write_concern(session, write_concern)
|
||||
|
||||
@ -81,7 +81,8 @@ class UndecipherableInt64Type(object):
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, type(self)):
|
||||
return self.value == other.value
|
||||
return self.value == other
|
||||
# Does not compare equal to integers.
|
||||
return False
|
||||
|
||||
|
||||
class UndecipherableIntDecoder(TypeDecoder):
|
||||
@ -115,11 +116,17 @@ UPPERSTR_DECODER_CODECOPTS = CodecOptions(type_registry=TypeRegistry(
|
||||
|
||||
|
||||
class CustomBSONTypeTests(object):
|
||||
def test_encode_decode_roundtrip(self):
|
||||
document = {'average': Decimal('56.47')}
|
||||
bsonbytes = BSON().encode(document, codec_options=self.codecopts)
|
||||
def roundtrip(self, doc):
|
||||
bsonbytes = BSON().encode(doc, codec_options=self.codecopts)
|
||||
rt_document = BSON(bsonbytes).decode(codec_options=self.codecopts)
|
||||
self.assertEqual(document, rt_document)
|
||||
self.assertEqual(doc, rt_document)
|
||||
|
||||
def test_encode_decode_roundtrip(self):
|
||||
self.roundtrip({'average': Decimal('56.47')})
|
||||
self.roundtrip({'average': {'b': Decimal('56.47')}})
|
||||
self.roundtrip({'average': [Decimal('56.47')]})
|
||||
self.roundtrip({'average': [[Decimal('56.47')]]})
|
||||
self.roundtrip({'average': [{'b': Decimal('56.47')}]})
|
||||
|
||||
def test_decode_all(self):
|
||||
documents = []
|
||||
@ -585,24 +592,18 @@ class TestCollectionWCustomType(IntegrationTest):
|
||||
self.assertIsInstance(res['total_qty'], UndecipherableInt64Type)
|
||||
self.assertEqual(res['total_qty'].value, 20)
|
||||
|
||||
# collection.distinct does not support custom type decoding
|
||||
def test_distinct_w_custom_type(self):
|
||||
self.db.drop_collection("test")
|
||||
|
||||
test = self.db.get_collection('test', codec_options=UNINT_CODECOPTS)
|
||||
test.insert_many([
|
||||
{"a": UndecipherableInt64Type(1)},
|
||||
{"a": UndecipherableInt64Type(2)},
|
||||
{"a": UndecipherableInt64Type(2)},
|
||||
{"a": UndecipherableInt64Type(2)},
|
||||
{"a": UndecipherableInt64Type(3)}])
|
||||
values = [
|
||||
UndecipherableInt64Type(1),
|
||||
UndecipherableInt64Type(2),
|
||||
UndecipherableInt64Type(3),
|
||||
{"b": UndecipherableInt64Type(3)}]
|
||||
test.insert_many({"a": val} for val in values)
|
||||
|
||||
distinct = test.distinct("a")
|
||||
distinct.sort()
|
||||
|
||||
self.assertEqual([
|
||||
UndecipherableInt64Type(1), UndecipherableInt64Type(2),
|
||||
UndecipherableInt64Type(3)], distinct)
|
||||
self.assertEqual(values, test.distinct("a"))
|
||||
|
||||
def test_map_reduce_w_custom_type(self):
|
||||
test = self.db.get_collection(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user