PYTHON-1783 Decode user-facing documents but not internal driver-server
communications.
(cherry picked from commit 749116287a)
This commit is contained in:
parent
3f7481fe3b
commit
b14dede0fb
@ -215,7 +215,7 @@ def _get_array(data, position, obj_end, opts, element_name):
|
||||
end -= 1
|
||||
result = []
|
||||
|
||||
# Avoid doing global and attibute lookups in the loop.
|
||||
# Avoid doing global and attribute lookups in the loop.
|
||||
append = result.append
|
||||
index = data.index
|
||||
getter = _ELEMENT_GETTER
|
||||
@ -940,6 +940,59 @@ if _USE_C:
|
||||
decode_all = _cbson.decode_all
|
||||
|
||||
|
||||
def _decode_selective(rawdoc, fields, codec_options):
    """Copy `rawdoc` into a new document, decoding only the marked fields.

    :Parameters:
      - `rawdoc`: Mapping to copy. Values selected through `fields` are
        expected to expose their BSON bytes through a ``raw`` attribute.
      - `fields`: Map of key to ``dict`` (decode the value as a single
        document), ``list`` (decode every element of the value) or a nested
        map of the same shape (recurse into the value).
      - `codec_options`: Options passed to ``_bson_to_dict`` for the marked
        fields; its ``document_class`` is also used for the result.

    :Returns:
      - A new ``codec_options.document_class`` instance. Keys that are not
        in `fields` keep the value found in `rawdoc`.
    """
    doc = codec_options.document_class()
    for key, value in iteritems(rawdoc):
        # A null value has nothing to decode (for example a findAndModify
        # reply whose ``value`` is null); pass it through instead of failing
        # on ``value.raw`` or on iterating over None.
        if key not in fields or value is None:
            doc[key] = value
            continue

        marker = fields[key]
        if marker == list:
            doc[key] = [_bson_to_dict(r.raw, codec_options) for r in value]
        elif marker == dict:
            doc[key] = _bson_to_dict(value.raw, codec_options)
        else:
            doc[key] = _decode_selective(value, marker, codec_options)
    return doc
|
||||
|
||||
|
||||
def _decode_all_selective(data, codec_options, fields):
    """Decode one BSON document, applying custom decoders only to `fields`.

    `data` must be a string representing a valid, BSON-encoded document.

    :Parameters:
      - `data`: BSON data
      - `codec_options`: An instance of
        :class:`~bson.codec_options.CodecOptions` with user-specified type
        decoders. When its type registry holds no decoders this function
        simply returns ``decode_all(data, codec_options)``.
      - `fields`: Map of document namespaces holding the data that must be
        custom decoded, or None. For example, to custom decode a list of
        objects in 'field1.subfield1' pass
        ``{'field1': {'subfield1': list}}``; use ``dict`` instead of
        ``list`` when the field holds a single object. When `fields` is
        empty or None the data is decoded with ``decode_all`` using
        `codec_options` with its type registry cleared.

    :Returns:
      - `document_list`: Single-member list containing the decoded document.

    .. versionadded:: 3.8
    """
    has_decoders = bool(codec_options.type_registry._decoder_map)
    if not has_decoders:
        return decode_all(data, codec_options)

    if not fields:
        plain_options = codec_options.with_options(type_registry=None)
        return decode_all(data, plain_options)

    # First pass: decode for internal use, keeping sub-documents raw and
    # applying no custom decoders.
    from bson.raw_bson import RawBSONDocument
    raw_options = codec_options.with_options(
        document_class=RawBSONDocument, type_registry=None)
    raw_document = _bson_to_dict(data, raw_options)

    # Second pass: decode only the selected fields with the caller's options.
    decoded = _decode_selective(raw_document, fields, codec_options)
    return [decoded]
|
||||
|
||||
|
||||
def decode_iter(data, codec_options=DEFAULT_CODEC_OPTIONS):
|
||||
"""Decode BSON data to multiple documents as a generator.
|
||||
|
||||
|
||||
@ -25,7 +25,8 @@ from gridfs.errors import NoFile
|
||||
from gridfs.grid_file import (GridIn,
|
||||
GridOut,
|
||||
GridOutCursor,
|
||||
DEFAULT_CHUNK_SIZE)
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
_clear_entity_type_registry)
|
||||
from pymongo import (ASCENDING,
|
||||
DESCENDING)
|
||||
from pymongo.common import UNAUTHORIZED_CODES, validate_string
|
||||
@ -61,6 +62,8 @@ class GridFS(object):
|
||||
if not isinstance(database, Database):
|
||||
raise TypeError("database must be an instance of Database")
|
||||
|
||||
database = _clear_entity_type_registry(database)
|
||||
|
||||
if not database.write_concern.acknowledged:
|
||||
raise ConfigurationError('database must use '
|
||||
'acknowledged write_concern')
|
||||
@ -443,6 +446,8 @@ class GridFSBucket(object):
|
||||
if not isinstance(db, Database):
|
||||
raise TypeError("database must be an instance of Database")
|
||||
|
||||
db = _clear_entity_type_registry(db)
|
||||
|
||||
wtc = write_concern if write_concern is not None else db.write_concern
|
||||
if not wtc.acknowledged:
|
||||
raise ConfigurationError('write concern must be acknowledged')
|
||||
|
||||
@ -98,6 +98,12 @@ def _grid_out_property(field_name, docstring):
|
||||
return property(getter, doc=docstring)
|
||||
|
||||
|
||||
def _clear_entity_type_registry(entity, **kwargs):
|
||||
"""Clear the given database/collection object's type registry."""
|
||||
codecopts = entity.codec_options.with_options(type_registry=None)
|
||||
return entity.with_options(codec_options=codecopts, **kwargs)
|
||||
|
||||
|
||||
class GridIn(object):
|
||||
"""Class to write data to GridFS.
|
||||
"""
|
||||
@ -168,8 +174,8 @@ class GridIn(object):
|
||||
if "chunk_size" in kwargs:
|
||||
kwargs["chunkSize"] = kwargs.pop("chunk_size")
|
||||
|
||||
coll = root_collection.with_options(
|
||||
read_preference=ReadPreference.PRIMARY)
|
||||
coll = _clear_entity_type_registry(
|
||||
root_collection, read_preference=ReadPreference.PRIMARY)
|
||||
|
||||
if not disable_md5:
|
||||
kwargs["md5"] = hashlib.md5()
|
||||
@ -449,6 +455,8 @@ class GridOut(object):
|
||||
raise TypeError("root_collection must be an "
|
||||
"instance of Collection")
|
||||
|
||||
root_collection = _clear_entity_type_registry(root_collection)
|
||||
|
||||
self.__chunks = root_collection.chunks
|
||||
self.__files = root_collection.files
|
||||
self.__file_id = file_id
|
||||
@ -800,6 +808,8 @@ class GridOutCursor(Cursor):
|
||||
|
||||
.. mongodoc:: cursors
|
||||
"""
|
||||
collection = _clear_entity_type_registry(collection)
|
||||
|
||||
# Hold on to the base "fs" collection to create GridOut objects later.
|
||||
self.__root_collection = collection
|
||||
|
||||
|
||||
@ -53,6 +53,7 @@ from pymongo.write_concern import WriteConcern
|
||||
|
||||
_NO_OBJ_ERROR = "No matching object found"
|
||||
_UJOIN = u"%s.%s"
|
||||
_FIND_AND_MODIFY_DOC_FIELDS = {'value': dict}
|
||||
|
||||
|
||||
class ReturnDocument(object):
|
||||
@ -202,7 +203,8 @@ class Collection(common.BaseObject):
|
||||
write_concern=None,
|
||||
collation=None,
|
||||
session=None,
|
||||
retryable_write=False):
|
||||
retryable_write=False,
|
||||
user_fields=None):
|
||||
"""Internal command helper.
|
||||
|
||||
:Parameters:
|
||||
@ -241,7 +243,8 @@ class Collection(common.BaseObject):
|
||||
collation=collation,
|
||||
session=s,
|
||||
client=self.__database.client,
|
||||
retryable_write=retryable_write)
|
||||
retryable_write=retryable_write,
|
||||
user_fields=user_fields)
|
||||
|
||||
def __create(self, options, collation, session):
|
||||
"""Sends a create command with the given options.
|
||||
@ -314,9 +317,8 @@ class Collection(common.BaseObject):
|
||||
"""
|
||||
return self.__database
|
||||
|
||||
def with_options(
|
||||
self, codec_options=None, read_preference=None,
|
||||
write_concern=None, read_concern=None):
|
||||
def with_options(self, codec_options=None, read_preference=None,
|
||||
write_concern=None, read_concern=None):
|
||||
"""Get a clone of this collection changing the specified settings.
|
||||
|
||||
>>> coll1.read_preference
|
||||
@ -2307,7 +2309,8 @@ class Collection(common.BaseObject):
|
||||
write_concern=write_concern,
|
||||
collation=collation,
|
||||
session=session,
|
||||
client=self.__database.client)
|
||||
client=self.__database.client,
|
||||
user_fields={'cursor': {'firstBatch': list}})
|
||||
|
||||
if "cursor" in result:
|
||||
cursor = result["cursor"]
|
||||
@ -2565,7 +2568,8 @@ class Collection(common.BaseObject):
|
||||
|
||||
with self._socket_for_reads(session=None) as (sock_info, slave_ok):
|
||||
return self._command(sock_info, cmd, slave_ok,
|
||||
collation=collation)["retval"]
|
||||
collation=collation,
|
||||
user_fields={'retval': list})["retval"]
|
||||
|
||||
def rename(self, new_name, session=None, **kwargs):
|
||||
"""Rename this collection.
|
||||
@ -2669,7 +2673,8 @@ class Collection(common.BaseObject):
|
||||
with self._socket_for_reads(session) as (sock_info, slave_ok):
|
||||
return self._command(sock_info, cmd, slave_ok,
|
||||
read_concern=self.read_concern,
|
||||
collation=collation, session=session)["values"]
|
||||
collation=collation,
|
||||
session=session)["values"]
|
||||
|
||||
def map_reduce(self, map, reduce, out, full_response=False, session=None,
|
||||
**kwargs):
|
||||
@ -2749,12 +2754,17 @@ class Collection(common.BaseObject):
|
||||
write_concern = self._write_concern_for(session)
|
||||
else:
|
||||
write_concern = None
|
||||
if inline:
|
||||
user_fields = {'results': list}
|
||||
else:
|
||||
user_fields = None
|
||||
|
||||
response = self._command(
|
||||
sock_info, cmd, slave_ok, read_pref,
|
||||
read_concern=read_concern,
|
||||
write_concern=write_concern,
|
||||
collation=collation, session=session)
|
||||
collation=collation, session=session,
|
||||
user_fields=user_fields)
|
||||
|
||||
if full_response or not response.get('result'):
|
||||
return response
|
||||
@ -2804,16 +2814,19 @@ class Collection(common.BaseObject):
|
||||
("map", map),
|
||||
("reduce", reduce),
|
||||
("out", {"inline": 1})])
|
||||
user_fields = {'results': list}
|
||||
collation = validate_collation_or_none(kwargs.pop('collation', None))
|
||||
cmd.update(kwargs)
|
||||
with self._socket_for_reads(session) as (sock_info, slave_ok):
|
||||
if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd:
|
||||
res = self._command(sock_info, cmd, slave_ok,
|
||||
read_concern=self.read_concern,
|
||||
collation=collation, session=session)
|
||||
collation=collation, session=session,
|
||||
user_fields=user_fields)
|
||||
else:
|
||||
res = self._command(sock_info, cmd, slave_ok,
|
||||
collation=collation, session=session)
|
||||
collation=collation, session=session,
|
||||
user_fields=user_fields)
|
||||
|
||||
if full_response:
|
||||
return res
|
||||
@ -2831,6 +2844,7 @@ class Collection(common.BaseObject):
|
||||
return_document=ReturnDocument.BEFORE,
|
||||
array_filters=None, session=None, **kwargs):
|
||||
"""Internal findAndModify helper."""
|
||||
|
||||
common.validate_is_mapping("filter", filter)
|
||||
if not isinstance(return_document, bool):
|
||||
raise ValueError("return_document must be "
|
||||
@ -2870,8 +2884,10 @@ class Collection(common.BaseObject):
|
||||
write_concern=write_concern,
|
||||
allowable_errors=[_NO_OBJ_ERROR],
|
||||
collation=collation, session=session,
|
||||
retryable_write=retryable_write)
|
||||
retryable_write=retryable_write,
|
||||
user_fields=_FIND_AND_MODIFY_DOC_FIELDS)
|
||||
_check_write_command_response(out)
|
||||
|
||||
return out.get("value")
|
||||
|
||||
return self.__database.client._retryable_write(
|
||||
@ -3287,7 +3303,8 @@ class Collection(common.BaseObject):
|
||||
result = self._command(
|
||||
sock_info, cmd, read_preference=ReadPreference.PRIMARY,
|
||||
allowable_errors=[_NO_OBJ_ERROR], collation=collation,
|
||||
session=session, retryable_write=retryable_write)
|
||||
session=session, retryable_write=retryable_write,
|
||||
user_fields=_FIND_AND_MODIFY_DOC_FIELDS)
|
||||
|
||||
_check_write_command_response(result)
|
||||
return result
|
||||
|
||||
@ -149,12 +149,17 @@ class CommandCursor(object):
|
||||
reply = response.data
|
||||
|
||||
try:
|
||||
docs = self._unpack_response(reply,
|
||||
self.__id,
|
||||
self.__collection.codec_options)
|
||||
user_fields = None
|
||||
legacy_response = True
|
||||
if from_command:
|
||||
user_fields = {'cursor': {'nextBatch': list}}
|
||||
legacy_response = False
|
||||
docs = self._unpack_response(
|
||||
reply, self.__id, self.__collection.codec_options,
|
||||
legacy_response=legacy_response, user_fields=user_fields)
|
||||
if from_command:
|
||||
first = docs[0]
|
||||
client._receive_cluster_time(first, self.__session)
|
||||
client._process_response(first, self.__session)
|
||||
helpers._check_command_response(first)
|
||||
|
||||
except OperationFailure as exc:
|
||||
@ -174,7 +179,6 @@ class CommandCursor(object):
|
||||
listeners.publish_command_failure(
|
||||
duration(), exc.details, "getMore", rqst_id, self.__address)
|
||||
|
||||
client._reset_server_and_request_check(self.address)
|
||||
raise
|
||||
except Exception as exc:
|
||||
if publish:
|
||||
@ -208,8 +212,10 @@ class CommandCursor(object):
|
||||
kill()
|
||||
self.__data = deque(documents)
|
||||
|
||||
def _unpack_response(self, response, cursor_id, codec_options):
|
||||
return response.unpack_response(cursor_id, codec_options)
|
||||
def _unpack_response(self, response, cursor_id, codec_options,
|
||||
user_fields=None, legacy_response=False):
|
||||
return response.unpack_response(cursor_id, codec_options, user_fields,
|
||||
legacy_response)
|
||||
|
||||
def _refresh(self):
|
||||
"""Refreshes the cursor with more data from the server.
|
||||
@ -330,7 +336,8 @@ class RawBatchCommandCursor(CommandCursor):
|
||||
collection, cursor_info, address, retrieved, batch_size,
|
||||
max_await_time_ms, session, explicit_session)
|
||||
|
||||
def _unpack_response(self, response, cursor_id, codec_options):
|
||||
def _unpack_response(self, response, cursor_id, codec_options,
|
||||
user_fields=None, legacy_response=False):
|
||||
return response.raw_response(cursor_id)
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
@ -40,7 +40,7 @@ from pymongo.message import (_convert_exception,
|
||||
_RawBatchGetMore,
|
||||
_Query,
|
||||
_RawBatchQuery)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
|
||||
|
||||
_QUERY_OPTIONS = {
|
||||
"tailable_cursor": 2,
|
||||
@ -50,6 +50,7 @@ _QUERY_OPTIONS = {
|
||||
"await_data": 32,
|
||||
"exhaust": 64,
|
||||
"partial": 128}
|
||||
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': list, 'nextBatch': list}}
|
||||
|
||||
|
||||
class CursorType(object):
|
||||
@ -995,12 +996,17 @@ class Cursor(object):
|
||||
raise
|
||||
|
||||
try:
|
||||
docs = self._unpack_response(response=reply,
|
||||
cursor_id=self.__id,
|
||||
codec_options=self.__codec_options)
|
||||
user_fields = None
|
||||
legacy_response = True
|
||||
if from_command:
|
||||
user_fields = _CURSOR_DOC_FIELDS
|
||||
legacy_response = False
|
||||
docs = self._unpack_response(
|
||||
reply, self.__id, self.__collection.codec_options,
|
||||
legacy_response=legacy_response, user_fields=user_fields)
|
||||
if from_command:
|
||||
first = docs[0]
|
||||
client._receive_cluster_time(first, self.__session)
|
||||
client._process_response(first, self.__session)
|
||||
helpers._check_command_response(first)
|
||||
except OperationFailure as exc:
|
||||
self.__killed = True
|
||||
@ -1031,7 +1037,6 @@ class Cursor(object):
|
||||
listeners.publish_command_failure(
|
||||
duration(), exc.details, cmd_name, rqst_id, self.__address)
|
||||
|
||||
client._reset_server_and_request_check(self.__address)
|
||||
raise
|
||||
except Exception as exc:
|
||||
if publish:
|
||||
@ -1085,8 +1090,10 @@ class Cursor(object):
|
||||
if self.__limit and self.__id and self.__limit <= self.__retrieved:
|
||||
self.__die()
|
||||
|
||||
def _unpack_response(self, response, cursor_id, codec_options):
|
||||
return response.unpack_response(cursor_id, codec_options)
|
||||
def _unpack_response(self, response, cursor_id, codec_options,
|
||||
user_fields=None, legacy_response=False):
|
||||
return response.unpack_response(cursor_id, codec_options, user_fields,
|
||||
legacy_response)
|
||||
|
||||
def _read_preference(self):
|
||||
if self.__read_preference is None:
|
||||
@ -1303,7 +1310,8 @@ class RawBatchCursor(Cursor):
|
||||
raise InvalidOperation(
|
||||
"Cannot use RawBatchCursor with manipulate=True")
|
||||
|
||||
def _unpack_response(self, response, cursor_id, codec_options):
|
||||
def _unpack_response(self, response, cursor_id, codec_options,
|
||||
user_fields=None, legacy_response=False):
|
||||
return response.raw_response(cursor_id)
|
||||
|
||||
def explain(self):
|
||||
|
||||
@ -222,6 +222,46 @@ class Database(common.BaseObject):
|
||||
return [manipulator.__class__.__name__
|
||||
for manipulator in self.__outgoing_copying_manipulators]
|
||||
|
||||
def with_options(self, codec_options=None, read_preference=None,
                 write_concern=None, read_concern=None):
    """Get a clone of this database changing the specified settings.

      >>> db1.read_preference
      Primary()
      >>> from pymongo import ReadPreference
      >>> db2 = db1.with_options(read_preference=ReadPreference.SECONDARY)
      >>> db1.read_preference
      Primary()
      >>> db2.read_preference
      Secondary(tag_sets=None)

    :Parameters:
      - `codec_options` (optional): An instance of
        :class:`~bson.codec_options.CodecOptions`. If ``None`` (the
        default) the :attr:`codec_options` of this :class:`Database`
        is used.
      - `read_preference` (optional): The read preference to use. If
        ``None`` (the default) the :attr:`read_preference` of this
        :class:`Database` is used. See :mod:`~pymongo.read_preferences`
        for options.
      - `write_concern` (optional): An instance of
        :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the
        default) the :attr:`write_concern` of this :class:`Database`
        is used.
      - `read_concern` (optional): An instance of
        :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the
        default) the :attr:`read_concern` of this :class:`Database`
        is used.

    .. versionadded:: 3.8
    """
    # Any option left unset falls back to this database's current value.
    effective_codec_options = codec_options or self.codec_options
    effective_read_preference = read_preference or self.read_preference
    effective_write_concern = write_concern or self.write_concern
    effective_read_concern = read_concern or self.read_concern
    return Database(self.client,
                    self.__name,
                    effective_codec_options,
                    effective_read_preference,
                    effective_write_concern,
                    effective_read_concern)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Database):
|
||||
return (self.__client == other.client and
|
||||
|
||||
@ -1398,7 +1398,8 @@ class _OpReply(object):
|
||||
return [self.documents]
|
||||
|
||||
def unpack_response(self, cursor_id=None,
|
||||
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS):
|
||||
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
user_fields=None, legacy_response=False):
|
||||
"""Unpack a response from the database and decode the BSON document(s).
|
||||
|
||||
Check the response for errors and unpack, returning a dictionary
|
||||
@ -1415,7 +1416,10 @@ class _OpReply(object):
|
||||
:class:`~bson.codec_options.CodecOptions`
|
||||
"""
|
||||
self.raw_response(cursor_id)
|
||||
return bson.decode_all(self.documents, codec_options)
|
||||
if legacy_response:
|
||||
return bson.decode_all(self.documents, codec_options)
|
||||
return bson._decode_all_selective(
|
||||
self.documents, codec_options, user_fields)
|
||||
|
||||
def command_response(self):
|
||||
"""Unpack a command response."""
|
||||
@ -1451,7 +1455,8 @@ class _OpMsg(object):
|
||||
raise NotImplementedError
|
||||
|
||||
def unpack_response(self, cursor_id=None,
|
||||
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS):
|
||||
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
user_fields=None, legacy_response=False):
|
||||
"""Unpack a OP_MSG command response.
|
||||
|
||||
:Parameters:
|
||||
@ -1459,7 +1464,10 @@ class _OpMsg(object):
|
||||
- `codec_options` (optional): an instance of
|
||||
:class:`~bson.codec_options.CodecOptions`
|
||||
"""
|
||||
return bson.decode_all(self.payload_document, codec_options)
|
||||
# If _OpMsg is in-use, this cannot be a legacy response.
|
||||
assert not legacy_response
|
||||
return bson._decode_all_selective(
|
||||
self.payload_document, codec_options, user_fields)
|
||||
|
||||
def command_response(self):
|
||||
"""Unpack a command response."""
|
||||
|
||||
@ -58,7 +58,8 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
|
||||
collation=None,
|
||||
compression_ctx=None,
|
||||
use_op_msg=False,
|
||||
unacknowledged=False):
|
||||
unacknowledged=False,
|
||||
user_fields=None):
|
||||
"""Execute a command over the socket, or raise socket.error.
|
||||
|
||||
:Parameters:
|
||||
@ -139,7 +140,8 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
|
||||
response_doc = {"ok": 1}
|
||||
else:
|
||||
reply = receive_message(sock, request_id)
|
||||
unpacked_docs = reply.unpack_response(codec_options=codec_options)
|
||||
unpacked_docs = reply.unpack_response(
|
||||
codec_options=codec_options, user_fields=user_fields)
|
||||
|
||||
response_doc = unpacked_docs[0]
|
||||
if client:
|
||||
|
||||
@ -512,7 +512,8 @@ class SocketInfo(object):
|
||||
session=None,
|
||||
client=None,
|
||||
retryable_write=False,
|
||||
publish_events=True):
|
||||
publish_events=True,
|
||||
user_fields=None):
|
||||
"""Execute a command or raise an error.
|
||||
|
||||
:Parameters:
|
||||
@ -576,7 +577,8 @@ class SocketInfo(object):
|
||||
collation=collation,
|
||||
compression_ctx=self.compression_context,
|
||||
use_op_msg=self.op_msg_enabled,
|
||||
unacknowledged=unacknowledged)
|
||||
unacknowledged=unacknowledged,
|
||||
user_fields=user_fields)
|
||||
except OperationFailure:
|
||||
raise
|
||||
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
|
||||
|
||||
@ -132,6 +132,12 @@ class TestCollection(IntegrationTest):
|
||||
def tearDownClass(cls):
|
||||
cls.db.drop_collection("test_large_limit")
|
||||
|
||||
def setUp(self):
    """Drop the ``test`` collection before each test."""
    collection = self.db.test
    collection.drop()
|
||||
|
||||
def tearDown(self):
    """Drop the ``test`` collection after each test."""
    collection = self.db.test
    collection.drop()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def write_concern_collection(self):
|
||||
if client_context.version.at_least(3, 3, 9) and client_context.is_rs:
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
|
||||
"""Test support for callbacks to encode/decode custom types."""
|
||||
|
||||
import datetime
|
||||
import sys
|
||||
import tempfile
|
||||
from decimal import Decimal
|
||||
@ -30,11 +31,21 @@ from bson import (BSON,
|
||||
_BUILT_IN_TYPES,
|
||||
_dict_to_bson,
|
||||
_bson_to_dict)
|
||||
from bson.code import Code
|
||||
from bson.codec_options import (CodecOptions, TypeCodec, TypeDecoder,
|
||||
TypeEncoder, TypeRegistry)
|
||||
from bson.errors import InvalidDocument
|
||||
from bson.int64 import Int64
|
||||
from bson.py3compat import text_type
|
||||
|
||||
from test import unittest
|
||||
from gridfs import GridIn, GridOut
|
||||
|
||||
from pymongo.collection import ReturnDocument
|
||||
from pymongo.errors import DuplicateKeyError
|
||||
|
||||
from test import client_context, unittest
|
||||
from test.test_client import IntegrationTest
|
||||
from test.utils import ignore_deprecations
|
||||
|
||||
|
||||
class DecimalEncoder(TypeEncoder):
|
||||
@ -59,7 +70,51 @@ class DecimalCodec(DecimalDecoder, DecimalEncoder):
|
||||
pass
|
||||
|
||||
|
||||
class CustomTypeTests(object):
|
||||
DECIMAL_CODECOPTS = CodecOptions(
|
||||
type_registry=TypeRegistry([DecimalCodec()]))
|
||||
|
||||
|
||||
class UndecipherableInt64Type(object):
    """Test-only custom Python type that wraps a single value.

    Instances compare equal to other instances holding an equal value and
    to the bare value itself.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if isinstance(other, type(self)):
            return self.value == other.value
        # Fall back to comparing the wrapped value with `other` directly.
        return self.value == other

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; define it
        # explicitly so both operators agree on every interpreter.
        return not self == other

    def __repr__(self):
        # Readable output in assertion failure messages.
        return "%s(%r)" % (type(self).__name__, self.value)
|
||||
|
||||
|
||||
class UndecipherableIntDecoder(TypeDecoder):
    """Type decoder declared for ``Int64`` values."""

    bson_type = Int64

    def transform_bson(self, value):
        """Wrap `value` in an :class:`UndecipherableInt64Type`."""
        wrapped = UndecipherableInt64Type(value)
        return wrapped
|
||||
|
||||
|
||||
class UndecipherableIntEncoder(TypeEncoder):
    """Type encoder declared for :class:`UndecipherableInt64Type`."""

    python_type = UndecipherableInt64Type

    def transform_python(self, value):
        """Return the wrapped ``value.value`` as an ``Int64``."""
        unwrapped = value.value
        return Int64(unwrapped)
|
||||
|
||||
|
||||
UNINT_DECODER_CODECOPTS = CodecOptions(
|
||||
type_registry=TypeRegistry([UndecipherableIntDecoder(), ]))
|
||||
|
||||
|
||||
UNINT_CODECOPTS = CodecOptions(type_registry=TypeRegistry(
|
||||
[UndecipherableIntDecoder(), UndecipherableIntEncoder()]))
|
||||
|
||||
|
||||
class UppercaseTextDecoder(TypeDecoder):
    """Type decoder declared for ``text_type`` values."""

    bson_type = text_type

    def transform_bson(self, value):
        """Return `value` converted to upper case."""
        uppercased = value.upper()
        return uppercased
|
||||
|
||||
|
||||
UPPERSTR_DECODER_CODECOPTS = CodecOptions(type_registry=TypeRegistry(
|
||||
[UppercaseTextDecoder(),]))
|
||||
|
||||
|
||||
class CustomBSONTypeTests(object):
|
||||
def test_encode_decode_roundtrip(self):
|
||||
document = {'average': Decimal('56.47')}
|
||||
bsonbytes = BSON().encode(document, codec_options=self.codecopts)
|
||||
@ -117,25 +172,24 @@ class CustomTypeTests(object):
|
||||
|
||||
fileobj.close()
|
||||
|
||||
class TestCustomPythonTypeToBSONMonolithicCodec(CustomTypeTests,
|
||||
unittest.TestCase):
|
||||
|
||||
class TestCustomPythonBSONTypeToBSONMonolithicCodec(CustomBSONTypeTests,
|
||||
unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
type_registry = TypeRegistry((DecimalCodec(),))
|
||||
codec_options = CodecOptions(type_registry=type_registry)
|
||||
cls.codecopts = DECIMAL_CODECOPTS
|
||||
|
||||
|
||||
class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests,
|
||||
unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
codec_options = CodecOptions(
|
||||
type_registry=TypeRegistry((DecimalEncoder(), DecimalDecoder())))
|
||||
cls.codecopts = codec_options
|
||||
|
||||
|
||||
class TestCustomPythonTypeToBSONMultiplexedCodec(CustomTypeTests,
|
||||
unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
type_registry = TypeRegistry((DecimalEncoder(), DecimalDecoder()))
|
||||
codec_options = CodecOptions(type_registry=type_registry)
|
||||
cls.codecopts = codec_options
|
||||
|
||||
|
||||
class TestFallbackEncoder(unittest.TestCase):
|
||||
class TestBSONFallbackEncoder(unittest.TestCase):
|
||||
def _get_codec_options(self, fallback_encoder):
|
||||
type_registry = TypeRegistry(fallback_encoder=fallback_encoder)
|
||||
return CodecOptions(type_registry=type_registry)
|
||||
@ -180,7 +234,7 @@ class TestFallbackEncoder(unittest.TestCase):
|
||||
BSON().encode(document, codec_options=codecopts)
|
||||
|
||||
|
||||
class TestTypeEnDeCodecs(unittest.TestCase):
|
||||
class TestBSONTypeEnDeCodecs(unittest.TestCase):
|
||||
def test_instantiation(self):
|
||||
msg = "Can't instantiate abstract class .* with abstract methods .*"
|
||||
def run_test(base, attrs, fail):
|
||||
@ -220,7 +274,7 @@ class TestTypeEnDeCodecs(unittest.TestCase):
|
||||
self.assertFalse(issubclass(TypeEncoder, TypeDecoder))
|
||||
|
||||
|
||||
class TestCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase):
|
||||
class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
class TypeA(object):
|
||||
@ -467,5 +521,192 @@ class TestTypeRegistry(unittest.TestCase):
|
||||
'transform_bson': lambda x: x})
|
||||
|
||||
|
||||
class TestCollectionWCustomType(IntegrationTest):
    """Collection operations on collections configured with custom codecs.

    Every test works on the ``test`` collection, which is dropped before
    and after each test.
    """

    def setUp(self):
        """Drop the ``test`` collection before each test."""
        self.db.test.drop()

    def tearDown(self):
        """Drop the ``test`` collection after each test."""
        self.db.test.drop()

    def test_command_errors_w_custom_type_decoder(self):
        """A duplicate insert still raises DuplicateKeyError with a decoder."""
        db = self.db
        test_doc = {'_id': 1, 'data': 'a'}
        test = db.get_collection('test',
                                 codec_options=UNINT_DECODER_CODECOPTS)

        result = test.insert_one(test_doc)
        self.assertEqual(result.inserted_id, test_doc['_id'])
        # Same _id again: the error must surface as DuplicateKeyError.
        with self.assertRaises(DuplicateKeyError):
            test.insert_one(test_doc)

    def test_find_w_custom_type_decoder(self):
        """Documents returned by find() have Int64 fields custom decoded."""
        db = self.db
        input_docs = [
            {'x': Int64(k)} for k in [1.0, 2.0, 3.0]]
        # Inserted through the default collection (no custom codec).
        for doc in input_docs:
            db.test.insert_one(doc)

        test = db.get_collection(
            'test', codec_options=UNINT_DECODER_CODECOPTS)
        # batch_size=1 with three documents; presumably this makes the
        # cursor fetch further batches as well as the first one.
        for doc in test.find({}, batch_size=1):
            self.assertIsInstance(doc['x'], UndecipherableInt64Type)

    @client_context.require_version_max(4, 1, 0, -1)
    def test_group_w_custom_type(self):
        """group() results compare equal to the custom-typed expectation."""
        db = self.db
        test = db.get_collection('test', codec_options=UNINT_CODECOPTS)
        test.insert_many([
            {'sku': 'a', 'qty': UndecipherableInt64Type(2)},
            {'sku': 'b', 'qty': UndecipherableInt64Type(5)},
            {'sku': 'a', 'qty': UndecipherableInt64Type(1)}])

        self.assertEqual([{'sku': 'b', 'qty': UndecipherableInt64Type(5)},],
                         test.group(["sku", "qty"], {"sku": "b"}, {},
                                    "function (obj, prev) { }"))

    def test_aggregate_w_custom_type_decoder(self):
        """aggregate() results have Int64 fields custom decoded."""
        db = self.db
        db.test.insert_many([
            {'status': 'in progress', 'qty': Int64(1)},
            {'status': 'complete', 'qty': Int64(10)},
            {'status': 'in progress', 'qty': Int64(1)},
            {'status': 'complete', 'qty': Int64(10)},
            {'status': 'in progress', 'qty': Int64(1)},])
        test = db.get_collection(
            'test', codec_options=UNINT_DECODER_CODECOPTS)

        pipeline = [
            {'$match': {'status': 'complete'}},
            {'$group': {'_id': "$status", 'total_qty': {"$sum": "$qty"}}},]
        result = test.aggregate(pipeline)

        res = list(result)[0]
        self.assertEqual(res['_id'], 'complete')
        self.assertIsInstance(res['total_qty'], UndecipherableInt64Type)
        # Two 'complete' documents with qty 10 each.
        self.assertEqual(res['total_qty'].value, 20)

    # collection.distinct does not support custom type decoding
    def test_distinct_w_custom_type(self):
        """distinct() values compare equal to the custom-typed expectation."""
        self.db.drop_collection("test")

        test = self.db.get_collection('test', codec_options=UNINT_CODECOPTS)
        test.insert_many([
            {"a": UndecipherableInt64Type(1)},
            {"a": UndecipherableInt64Type(2)},
            {"a": UndecipherableInt64Type(2)},
            {"a": UndecipherableInt64Type(2)},
            {"a": UndecipherableInt64Type(3)}])

        distinct = test.distinct("a")
        distinct.sort()

        # UndecipherableInt64Type.__eq__ also accepts the bare value, so
        # this holds whether or not the returned values are wrapped.
        self.assertEqual([
            UndecipherableInt64Type(1), UndecipherableInt64Type(2),
            UndecipherableInt64Type(3)], distinct)

    def test_map_reduce_w_custom_type(self):
        """Inline map/reduce results are decoded with the custom decoder."""
        test = self.db.get_collection(
            'test', codec_options=UPPERSTR_DECODER_CODECOPTS)

        test.insert_many([
            {'_id': 1, 'sku': 'abcd', 'qty': 1},
            {'_id': 2, 'sku': 'abcd', 'qty': 2},
            {'_id': 3, 'sku': 'abcd', 'qty': 3}])

        map = Code("function () {"
                   "  emit(this.sku, this.qty);"
                   "}")
        reduce = Code("function (key, values) {"
                      "  return Array.sum(values);"
                      "}")

        result = test.map_reduce(map, reduce, out={'inline': 1})
        self.assertTrue(isinstance(result, dict))
        self.assertTrue('results' in result)
        # The stored sku 'abcd' comes back upper-cased.
        self.assertEqual(result['results'][0], {'_id': 'ABCD', 'value': 6})

        result = test.inline_map_reduce(map, reduce)
        self.assertTrue(isinstance(result, list))
        self.assertEqual(1, len(result))
        self.assertEqual(result[0]["_id"], 'ABCD')

        full_result = test.inline_map_reduce(map, reduce,
                                             full_response=True)
        self.assertEqual(3, full_result["counts"]["emit"])

    def test_find_one_and__w_custom_type_decoder(self):
        """find_one_and_update/replace/delete return custom-decoded docs."""
        db = self.db
        c = db.get_collection('test', codec_options=UNINT_DECODER_CODECOPTS)
        c.insert_one({'_id': 1, 'x': Int64(1)})

        doc = c.find_one_and_update({'_id': 1}, {'$inc': {'x': 1}},
                                    return_document=ReturnDocument.AFTER)
        self.assertEqual(doc['_id'], 1)
        self.assertIsInstance(doc['x'], UndecipherableInt64Type)
        self.assertEqual(doc['x'].value, 2)

        doc = c.find_one_and_replace({'_id': 1}, {'x': Int64(3), 'y': True},
                                     return_document=ReturnDocument.AFTER)
        self.assertEqual(doc['_id'], 1)
        self.assertIsInstance(doc['x'], UndecipherableInt64Type)
        self.assertEqual(doc['x'].value, 3)
        self.assertEqual(doc['y'], True)

        doc = c.find_one_and_delete({'y': True})
        self.assertEqual(doc['_id'], 1)
        self.assertIsInstance(doc['x'], UndecipherableInt64Type)
        self.assertEqual(doc['x'].value, 3)
        # The only document was deleted above.
        self.assertIsNone(c.find_one())

    @ignore_deprecations
    def test_find_and_modify_w_custom_type_decoder(self):
        """find_and_modify returns a custom-decoded document."""
        db = self.db
        c = db.get_collection('test', codec_options=UNINT_DECODER_CODECOPTS)
        c.insert_one({'_id': 1, 'x': Int64(1)})

        doc = c.find_and_modify({'_id': 1}, {'$inc': {'x': Int64(10)}})
        self.assertEqual(doc['_id'], 1)
        self.assertIsInstance(doc['x'], UndecipherableInt64Type)
        # The returned document still holds the value from before the $inc.
        self.assertEqual(doc['x'].value, 1)

        doc = c.find_one()
        self.assertEqual(doc['_id'], 1)
        self.assertIsInstance(doc['x'], UndecipherableInt64Type)
        self.assertEqual(doc['x'].value, 11)
|
||||
|
||||
|
||||
class TestGridFileCustomType(IntegrationTest):
    """GridIn/GridOut on a database configured with a custom type decoder."""

    def setUp(self):
        """Drop the ``fs.files`` and ``fs.chunks`` collections."""
        self.db.drop_collection('fs.files')
        self.db.drop_collection('fs.chunks')

    def test_grid_out_custom_opts(self):
        """File attributes read through GridOut are not custom decoded."""
        # The database's decoder upper-cases text values; every string
        # asserted below is expected in its original case.
        db = self.db.with_options(codec_options=UPPERSTR_DECODER_CODECOPTS)
        one = GridIn(db.fs, _id=5, filename="my_file",
                     contentType="text/html", chunkSize=1000, aliases=["foo"],
                     metadata={"foo": 'red', "bar": 'blue'}, bar=3,
                     baz="hello")
        one.write(b"hello world")
        one.close()

        two = GridOut(db.fs, 5)

        self.assertEqual("my_file", two.name)
        self.assertEqual("my_file", two.filename)
        self.assertEqual(5, two._id)
        self.assertEqual(11, two.length)
        self.assertEqual("text/html", two.content_type)
        self.assertEqual(1000, two.chunk_size)
        self.assertTrue(isinstance(two.upload_date, datetime.datetime))
        self.assertEqual(["foo"], two.aliases)
        self.assertEqual({"foo": 'red', "bar": 'blue'}, two.metadata)
        self.assertEqual(3, two.bar)
        self.assertEqual("5eb63bbbe01eeed093cb22bb8f5acdc3", two.md5)

        # These attributes must reject assignment.
        for attr in ["_id", "name", "content_type", "length", "chunk_size",
                     "upload_date", "aliases", "metadata", "md5"]:
            self.assertRaises(AttributeError, setattr, two, attr, 5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -58,6 +58,7 @@ from test.utils import (ignore_deprecations,
|
||||
server_started_with_auth,
|
||||
IMPOSSIBLE_WRITE_CONCERN,
|
||||
OvertCommandListener)
|
||||
from test.test_custom_types import DECIMAL_CODECOPTS
|
||||
|
||||
|
||||
if PY3:
|
||||
@ -976,6 +977,36 @@ class TestDatabase(IntegrationTest):
|
||||
"maxTimeAlwaysTimeOut",
|
||||
mode="off")
|
||||
|
||||
def test_with_options(self):
    """Database.with_options keeps unset options and applies given ones."""
    original = self.client.get_database(
        'with_options_test',
        codec_options=DECIMAL_CODECOPTS,
        read_preference=ReadPreference.SECONDARY_PREFERRED,
        write_concern=WriteConcern(j=True),
        read_concern=ReadConcern(level="majority"))

    # Case 1: swap no options -- every compared attribute is unchanged.
    untouched = original.with_options()
    compared_attrs = ('name', 'client', 'codec_options',
                      'read_preference', 'write_concern', 'read_concern')
    for attr in compared_attrs:
        self.assertEqual(getattr(original, attr), getattr(untouched, attr))

    # Case 2: swap all options -- each one takes the new value.
    replacements = {'codec_options': CodecOptions(),
                    'read_preference': ReadPreference.PRIMARY,
                    'write_concern': WriteConcern(w=1),
                    'read_concern': ReadConcern(level="local")}
    swapped = original.with_options(**replacements)
    for attr, expected in replacements.items():
        self.assertEqual(getattr(swapped, attr), expected)
|
||||
|
||||
def test_current_op_codec_options(self):
|
||||
class MySON(SON):
|
||||
pass
|
||||
|
||||
Loading…
Reference in New Issue
Block a user