From 749116287ae970fd3c1d72efaa1c94c92f2b1a65 Mon Sep 17 00:00:00 2001 From: Prashant Mital Date: Tue, 2 Apr 2019 16:47:51 -0700 Subject: [PATCH] PYTHON-1783 Decode user-facing documents but not internal driver-server communications. --- bson/__init__.py | 55 +++- gridfs/__init__.py | 7 +- gridfs/grid_file.py | 14 +- pymongo/collection.py | 43 ++- pymongo/command_cursor.py | 21 +- pymongo/cursor.py | 23 +- pymongo/database.py | 40 +++ pymongo/message.py | 16 +- pymongo/network.py | 6 +- pymongo/pool.py | 6 +- test/test_collection.py | 6 + ...n_custom_types.py => test_custom_types.py} | 277 ++++++++++++++++-- test/test_database.py | 31 ++ 13 files changed, 488 insertions(+), 57 deletions(-) rename test/{test_bson_custom_types.py => test_custom_types.py} (63%) diff --git a/bson/__init__.py b/bson/__init__.py index 464fba92d..2fcd668c4 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -215,7 +215,7 @@ def _get_array(data, position, obj_end, opts, element_name): end -= 1 result = [] - # Avoid doing global and attibute lookups in the loop. + # Avoid doing global and attribute lookups in the loop. append = result.append index = data.index getter = _ELEMENT_GETTER @@ -940,6 +940,59 @@ if _USE_C: decode_all = _cbson.decode_all +def _decode_selective(rawdoc, fields, codec_options): + doc = codec_options.document_class() + for key, value in iteritems(rawdoc): + if key in fields: + if fields[key] == list: + doc[key] = [_bson_to_dict(r.raw, codec_options) for r in value] + elif fields[key] == dict: + doc[key] = _bson_to_dict(value.raw, codec_options) + else: + doc[key] = _decode_selective(value, fields[key], codec_options) + continue + doc[key] = value + return doc + + +def _decode_all_selective(data, codec_options, fields): + """Decode BSON data to a single document while using user-provided + custom decoding logic. + + `data` must be a string representing a valid, BSON-encoded document. 
+ + :Parameters: + - `data`: BSON data + - `codec_options`: An instance of + :class:`~bson.codec_options.CodecOptions` with user-specified type + decoders. If no decoders are found, this method is the same as + ``decode_all``. + - `fields`: Map of document namespaces where data that needs + to be custom decoded lives or None. For example, to custom decode a + list of objects in 'field1.subfield1', the specified value should be + ``{'field1': {'subfield1': list}}``. Use ``dict`` instead of ``list`` + if the field contains a single object to custom decode. If ``fields`` + is an empty map or None, this method is the same as ``decode_all``. + + :Returns: + - `document_list`: Single-member list containing the decoded document. + + .. versionadded:: 3.8 + """ + if not codec_options.type_registry._decoder_map: + return decode_all(data, codec_options) + + if not fields: + return decode_all(data, codec_options.with_options(type_registry=None)) + + # Decode documents for internal use. + from bson.raw_bson import RawBSONDocument + internal_codec_options = codec_options.with_options( + document_class=RawBSONDocument, type_registry=None) + _doc = _bson_to_dict(data, internal_codec_options) + return [_decode_selective(_doc, fields, codec_options,)] + + def decode_iter(data, codec_options=DEFAULT_CODEC_OPTIONS): """Decode BSON data to multiple documents as a generator. 
diff --git a/gridfs/__init__.py b/gridfs/__init__.py index 11084171d..6c56a605e 100644 --- a/gridfs/__init__.py +++ b/gridfs/__init__.py @@ -25,7 +25,8 @@ from gridfs.errors import NoFile from gridfs.grid_file import (GridIn, GridOut, GridOutCursor, - DEFAULT_CHUNK_SIZE) + DEFAULT_CHUNK_SIZE, + _clear_entity_type_registry) from pymongo import (ASCENDING, DESCENDING) from pymongo.common import UNAUTHORIZED_CODES, validate_string @@ -61,6 +62,8 @@ class GridFS(object): if not isinstance(database, Database): raise TypeError("database must be an instance of Database") + database = _clear_entity_type_registry(database) + if not database.write_concern.acknowledged: raise ConfigurationError('database must use ' 'acknowledged write_concern') @@ -443,6 +446,8 @@ class GridFSBucket(object): if not isinstance(db, Database): raise TypeError("database must be an instance of Database") + db = _clear_entity_type_registry(db) + wtc = write_concern if write_concern is not None else db.write_concern if not wtc.acknowledged: raise ConfigurationError('write concern must be acknowledged') diff --git a/gridfs/grid_file.py b/gridfs/grid_file.py index cdc815267..4ee93a4f2 100644 --- a/gridfs/grid_file.py +++ b/gridfs/grid_file.py @@ -98,6 +98,12 @@ def _grid_out_property(field_name, docstring): return property(getter, doc=docstring) +def _clear_entity_type_registry(entity, **kwargs): + """Clear the given database/collection object's type registry.""" + codecopts = entity.codec_options.with_options(type_registry=None) + return entity.with_options(codec_options=codecopts, **kwargs) + + class GridIn(object): """Class to write data to GridFS. 
""" @@ -168,8 +174,8 @@ class GridIn(object): if "chunk_size" in kwargs: kwargs["chunkSize"] = kwargs.pop("chunk_size") - coll = root_collection.with_options( - read_preference=ReadPreference.PRIMARY) + coll = _clear_entity_type_registry( + root_collection, read_preference=ReadPreference.PRIMARY) if not disable_md5: kwargs["md5"] = hashlib.md5() @@ -449,6 +455,8 @@ class GridOut(object): raise TypeError("root_collection must be an " "instance of Collection") + root_collection = _clear_entity_type_registry(root_collection) + self.__chunks = root_collection.chunks self.__files = root_collection.files self.__file_id = file_id @@ -800,6 +808,8 @@ class GridOutCursor(Cursor): .. mongodoc:: cursors """ + collection = _clear_entity_type_registry(collection) + # Hold on to the base "fs" collection to create GridOut objects later. self.__root_collection = collection diff --git a/pymongo/collection.py b/pymongo/collection.py index ad066ce00..ed948dd23 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -53,6 +53,7 @@ from pymongo.write_concern import WriteConcern _NO_OBJ_ERROR = "No matching object found" _UJOIN = u"%s.%s" +_FIND_AND_MODIFY_DOC_FIELDS = {'value': dict} class ReturnDocument(object): @@ -203,7 +204,8 @@ class Collection(common.BaseObject): write_concern=None, collation=None, session=None, - retryable_write=False): + retryable_write=False, + user_fields=None): """Internal command helper. :Parameters: @@ -242,7 +244,8 @@ class Collection(common.BaseObject): collation=collation, session=s, client=self.__database.client, - retryable_write=retryable_write) + retryable_write=retryable_write, + user_fields=user_fields) def __create(self, options, collation, session): """Sends a create command with the given options. 
@@ -315,9 +318,8 @@ class Collection(common.BaseObject): """ return self.__database - def with_options( - self, codec_options=None, read_preference=None, - write_concern=None, read_concern=None): + def with_options(self, codec_options=None, read_preference=None, + write_concern=None, read_concern=None): """Get a clone of this collection changing the specified settings. >>> coll1.read_preference @@ -2310,7 +2312,8 @@ class Collection(common.BaseObject): write_concern=write_concern, collation=collation, session=session, - client=self.__database.client) + client=self.__database.client, + user_fields={'cursor': {'firstBatch': list}}) if "cursor" in result: cursor = result["cursor"] @@ -2571,7 +2574,8 @@ class Collection(common.BaseObject): with self._socket_for_reads(session=None) as (sock_info, slave_ok): return self._command(sock_info, cmd, slave_ok, - collation=collation)["retval"] + collation=collation, + user_fields={'retval': list})["retval"] def rename(self, new_name, session=None, **kwargs): """Rename this collection. 
@@ -2675,7 +2679,8 @@ class Collection(common.BaseObject): with self._socket_for_reads(session) as (sock_info, slave_ok): return self._command(sock_info, cmd, slave_ok, read_concern=self.read_concern, - collation=collation, session=session)["values"] + collation=collation, + session=session)["values"] def map_reduce(self, map, reduce, out, full_response=False, session=None, **kwargs): @@ -2755,12 +2760,17 @@ class Collection(common.BaseObject): write_concern = self._write_concern_for(session) else: write_concern = None + if inline: + user_fields = {'results': list} + else: + user_fields = None response = self._command( sock_info, cmd, slave_ok, read_pref, read_concern=read_concern, write_concern=write_concern, - collation=collation, session=session) + collation=collation, session=session, + user_fields=user_fields) if full_response or not response.get('result'): return response @@ -2810,16 +2820,19 @@ class Collection(common.BaseObject): ("map", map), ("reduce", reduce), ("out", {"inline": 1})]) + user_fields = {'results': list} collation = validate_collation_or_none(kwargs.pop('collation', None)) cmd.update(kwargs) with self._socket_for_reads(session) as (sock_info, slave_ok): if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd: res = self._command(sock_info, cmd, slave_ok, read_concern=self.read_concern, - collation=collation, session=session) + collation=collation, session=session, + user_fields=user_fields) else: res = self._command(sock_info, cmd, slave_ok, - collation=collation, session=session) + collation=collation, session=session, + user_fields=user_fields) if full_response: return res @@ -2837,6 +2850,7 @@ class Collection(common.BaseObject): return_document=ReturnDocument.BEFORE, array_filters=None, session=None, **kwargs): """Internal findAndModify helper.""" + common.validate_is_mapping("filter", filter) if not isinstance(return_document, bool): raise ValueError("return_document must be " @@ -2876,8 +2890,10 @@ class 
Collection(common.BaseObject): write_concern=write_concern, allowable_errors=[_NO_OBJ_ERROR], collation=collation, session=session, - retryable_write=retryable_write) + retryable_write=retryable_write, + user_fields=_FIND_AND_MODIFY_DOC_FIELDS) _check_write_command_response(out) + return out.get("value") return self.__database.client._retryable_write( @@ -3293,7 +3309,8 @@ class Collection(common.BaseObject): result = self._command( sock_info, cmd, read_preference=ReadPreference.PRIMARY, allowable_errors=[_NO_OBJ_ERROR], collation=collation, - session=session, retryable_write=retryable_write) + session=session, retryable_write=retryable_write, + user_fields=_FIND_AND_MODIFY_DOC_FIELDS) _check_write_command_response(result) return result diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index ace4a6aee..3b5281d9b 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -150,14 +150,18 @@ class CommandCursor(object): try: with client._reset_on_error(self.__address, self.__session): - docs = self._unpack_response(reply, - self.__id, - self.__collection.codec_options) + user_fields = None + legacy_response = True + if from_command: + user_fields = {'cursor': {'nextBatch': list}} + legacy_response = False + docs = self._unpack_response( + reply, self.__id, self.__collection.codec_options, + legacy_response=legacy_response, user_fields=user_fields) if from_command: first = docs[0] client._process_response(first, self.__session) helpers._check_command_response(first) - except OperationFailure as exc: kill() @@ -208,8 +212,10 @@ class CommandCursor(object): kill() self.__data = deque(documents) - def _unpack_response(self, response, cursor_id, codec_options): - return response.unpack_response(cursor_id, codec_options) + def _unpack_response(self, response, cursor_id, codec_options, + user_fields=None, legacy_response=False): + return response.unpack_response(cursor_id, codec_options, user_fields, + legacy_response) def _refresh(self): 
"""Refreshes the cursor with more data from the server. @@ -330,7 +336,8 @@ class RawBatchCommandCursor(CommandCursor): collection, cursor_info, address, retrieved, batch_size, max_await_time_ms, session, explicit_session) - def _unpack_response(self, response, cursor_id, codec_options): + def _unpack_response(self, response, cursor_id, codec_options, + user_fields=None, legacy_response=False): return response.raw_response(cursor_id) def __getitem__(self, index): diff --git a/pymongo/cursor.py b/pymongo/cursor.py index da9bef392..481a132b4 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -40,7 +40,7 @@ from pymongo.message import (_convert_exception, _RawBatchGetMore, _Query, _RawBatchQuery) -from pymongo.read_preferences import ReadPreference + _QUERY_OPTIONS = { "tailable_cursor": 2, @@ -50,6 +50,7 @@ _QUERY_OPTIONS = { "await_data": 32, "exhaust": 64, "partial": 128} +_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': list, 'nextBatch': list}} class CursorType(object): @@ -996,9 +997,14 @@ class Cursor(object): try: with client._reset_on_error(self.__address, self.__session): - docs = self._unpack_response(reply, - self.__id, - self.__collection.codec_options) + user_fields = None + legacy_response = True + if from_command: + user_fields = _CURSOR_DOC_FIELDS + legacy_response = False + docs = self._unpack_response( + reply, self.__id, self.__collection.codec_options, + legacy_response=legacy_response, user_fields=user_fields) if from_command: first = docs[0] client._process_response(first, self.__session) @@ -1085,8 +1091,10 @@ class Cursor(object): if self.__limit and self.__id and self.__limit <= self.__retrieved: self.__die() - def _unpack_response(self, response, cursor_id, codec_options): - return response.unpack_response(cursor_id, codec_options) + def _unpack_response(self, response, cursor_id, codec_options, + user_fields=None, legacy_response=False): + return response.unpack_response(cursor_id, codec_options, user_fields, + legacy_response) def 
_read_preference(self): if self.__read_preference is None: @@ -1303,7 +1311,8 @@ class RawBatchCursor(Cursor): raise InvalidOperation( "Cannot use RawBatchCursor with manipulate=True") - def _unpack_response(self, response, cursor_id, codec_options): + def _unpack_response(self, response, cursor_id, codec_options, + user_fields=None, legacy_response=False): return response.raw_response(cursor_id) def explain(self): diff --git a/pymongo/database.py b/pymongo/database.py index 8ad91c4e1..c4e061dfa 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -222,6 +222,46 @@ class Database(common.BaseObject): return [manipulator.__class__.__name__ for manipulator in self.__outgoing_copying_manipulators] + def with_options(self, codec_options=None, read_preference=None, + write_concern=None, read_concern=None): + """Get a clone of this database changing the specified settings. + + >>> db1.read_preference + Primary() + >>> from pymongo import ReadPreference + >>> db2 = db1.with_options(read_preference=ReadPreference.SECONDARY) + >>> db1.read_preference + Primary() + >>> db2.read_preference + Secondary(tag_sets=None) + + :Parameters: + - `codec_options` (optional): An instance of + :class:`~bson.codec_options.CodecOptions`. If ``None`` (the + default) the :attr:`codec_options` of this :class:`Database` + is used. + - `read_preference` (optional): The read preference to use. If + ``None`` (the default) the :attr:`read_preference` of this + :class:`Database` is used. See :mod:`~pymongo.read_preferences` + for options. + - `write_concern` (optional): An instance of + :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the + default) the :attr:`write_concern` of this :class:`Database` + is used. + - `read_concern` (optional): An instance of + :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the + default) the :attr:`read_concern` of this :class:`Database` + is used. + + .. 
versionadded:: 3.8 + """ + return Database(self.client, + self.__name, + codec_options or self.codec_options, + read_preference or self.read_preference, + write_concern or self.write_concern, + read_concern or self.read_concern) + def __eq__(self, other): if isinstance(other, Database): return (self.__client == other.client and diff --git a/pymongo/message.py b/pymongo/message.py index 05f9e4601..693d01932 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -1398,7 +1398,8 @@ class _OpReply(object): return [self.documents] def unpack_response(self, cursor_id=None, - codec_options=_UNICODE_REPLACE_CODEC_OPTIONS): + codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, + user_fields=None, legacy_response=False): """Unpack a response from the database and decode the BSON document(s). Check the response for errors and unpack, returning a dictionary @@ -1415,7 +1416,10 @@ class _OpReply(object): :class:`~bson.codec_options.CodecOptions` """ self.raw_response(cursor_id) - return bson.decode_all(self.documents, codec_options) + if legacy_response: + return bson.decode_all(self.documents, codec_options) + return bson._decode_all_selective( + self.documents, codec_options, user_fields) def command_response(self): """Unpack a command response.""" @@ -1451,7 +1455,8 @@ class _OpMsg(object): raise NotImplementedError def unpack_response(self, cursor_id=None, - codec_options=_UNICODE_REPLACE_CODEC_OPTIONS): + codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, + user_fields=None, legacy_response=False): """Unpack a OP_MSG command response. :Parameters: @@ -1459,7 +1464,10 @@ class _OpMsg(object): - `codec_options` (optional): an instance of :class:`~bson.codec_options.CodecOptions` """ - return bson.decode_all(self.payload_document, codec_options) + # If _OpMsg is in-use, this cannot be a legacy response. 
+ assert not legacy_response + return bson._decode_all_selective( + self.payload_document, codec_options, user_fields) def command_response(self): """Unpack a command response.""" diff --git a/pymongo/network.py b/pymongo/network.py index 3d42fa3c3..20c120c12 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -58,7 +58,8 @@ def command(sock, dbname, spec, slave_ok, is_mongos, collation=None, compression_ctx=None, use_op_msg=False, - unacknowledged=False): + unacknowledged=False, + user_fields=None): """Execute a command over the socket, or raise socket.error. :Parameters: @@ -139,7 +140,8 @@ def command(sock, dbname, spec, slave_ok, is_mongos, response_doc = {"ok": 1} else: reply = receive_message(sock, request_id) - unpacked_docs = reply.unpack_response(codec_options=codec_options) + unpacked_docs = reply.unpack_response( + codec_options=codec_options, user_fields=user_fields) response_doc = unpacked_docs[0] if client: diff --git a/pymongo/pool.py b/pymongo/pool.py index 78c74034b..8fe8f81f2 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -512,7 +512,8 @@ class SocketInfo(object): session=None, client=None, retryable_write=False, - publish_events=True): + publish_events=True, + user_fields=None): """Execute a command or raise an error. :Parameters: @@ -576,7 +577,8 @@ class SocketInfo(object): collation=collation, compression_ctx=self.compression_context, use_op_msg=self.op_msg_enabled, - unacknowledged=unacknowledged) + unacknowledged=unacknowledged, + user_fields=user_fields) except OperationFailure: raise # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. 
diff --git a/test/test_collection.py b/test/test_collection.py index 2f3358e9b..38a77eb84 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -132,6 +132,12 @@ class TestCollection(IntegrationTest): def tearDownClass(cls): cls.db.drop_collection("test_large_limit") + def setUp(self): + self.db.test.drop() + + def tearDown(self): + self.db.test.drop() + @contextlib.contextmanager def write_concern_collection(self): if client_context.version.at_least(3, 3, 9) and client_context.is_rs: diff --git a/test/test_bson_custom_types.py b/test/test_custom_types.py similarity index 63% rename from test/test_bson_custom_types.py rename to test/test_custom_types.py index 14c54c441..93275d542 100644 --- a/test/test_bson_custom_types.py +++ b/test/test_custom_types.py @@ -14,6 +14,7 @@ """Test support for callbacks to encode/decode custom types.""" +import datetime import sys import tempfile from decimal import Decimal @@ -30,11 +31,21 @@ from bson import (BSON, _BUILT_IN_TYPES, _dict_to_bson, _bson_to_dict) +from bson.code import Code from bson.codec_options import (CodecOptions, TypeCodec, TypeDecoder, TypeEncoder, TypeRegistry) from bson.errors import InvalidDocument +from bson.int64 import Int64 +from bson.py3compat import text_type -from test import unittest +from gridfs import GridIn, GridOut + +from pymongo.collection import ReturnDocument +from pymongo.errors import DuplicateKeyError + +from test import client_context, unittest +from test.test_client import IntegrationTest +from test.utils import ignore_deprecations class DecimalEncoder(TypeEncoder): @@ -59,7 +70,51 @@ class DecimalCodec(DecimalDecoder, DecimalEncoder): pass -class CustomTypeTests(object): +DECIMAL_CODECOPTS = CodecOptions( + type_registry=TypeRegistry([DecimalCodec()])) + + +class UndecipherableInt64Type(object): + def __init__(self, value): + self.value = value + + def __eq__(self, other): + if isinstance(other, type(self)): + return self.value == other.value + return self.value == 
other + + +class UndecipherableIntDecoder(TypeDecoder): + bson_type = Int64 + def transform_bson(self, value): + return UndecipherableInt64Type(value) + + +class UndecipherableIntEncoder(TypeEncoder): + python_type = UndecipherableInt64Type + def transform_python(self, value): + return Int64(value.value) + + +UNINT_DECODER_CODECOPTS = CodecOptions( + type_registry=TypeRegistry([UndecipherableIntDecoder(), ])) + + +UNINT_CODECOPTS = CodecOptions(type_registry=TypeRegistry( + [UndecipherableIntDecoder(), UndecipherableIntEncoder()])) + + +class UppercaseTextDecoder(TypeDecoder): + bson_type = text_type + def transform_bson(self, value): + return value.upper() + + +UPPERSTR_DECODER_CODECOPTS = CodecOptions(type_registry=TypeRegistry( + [UppercaseTextDecoder(),])) + + +class CustomBSONTypeTests(object): def test_encode_decode_roundtrip(self): document = {'average': Decimal('56.47')} bsonbytes = BSON().encode(document, codec_options=self.codecopts) @@ -117,25 +172,24 @@ class CustomTypeTests(object): fileobj.close() -class TestCustomPythonTypeToBSONMonolithicCodec(CustomTypeTests, - unittest.TestCase): + +class TestCustomPythonBSONTypeToBSONMonolithicCodec(CustomBSONTypeTests, + unittest.TestCase): @classmethod def setUpClass(cls): - type_registry = TypeRegistry((DecimalCodec(),)) - codec_options = CodecOptions(type_registry=type_registry) + cls.codecopts = DECIMAL_CODECOPTS + + +class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests, + unittest.TestCase): + @classmethod + def setUpClass(cls): + codec_options = CodecOptions( + type_registry=TypeRegistry((DecimalEncoder(), DecimalDecoder()))) cls.codecopts = codec_options -class TestCustomPythonTypeToBSONMultiplexedCodec(CustomTypeTests, - unittest.TestCase): - @classmethod - def setUpClass(cls): - type_registry = TypeRegistry((DecimalEncoder(), DecimalDecoder())) - codec_options = CodecOptions(type_registry=type_registry) - cls.codecopts = codec_options - - -class 
TestFallbackEncoder(unittest.TestCase): +class TestBSONFallbackEncoder(unittest.TestCase): def _get_codec_options(self, fallback_encoder): type_registry = TypeRegistry(fallback_encoder=fallback_encoder) return CodecOptions(type_registry=type_registry) @@ -180,7 +234,7 @@ class TestFallbackEncoder(unittest.TestCase): BSON().encode(document, codec_options=codecopts) -class TestTypeEnDeCodecs(unittest.TestCase): +class TestBSONTypeEnDeCodecs(unittest.TestCase): def test_instantiation(self): msg = "Can't instantiate abstract class .* with abstract methods .*" def run_test(base, attrs, fail): @@ -220,7 +274,7 @@ class TestTypeEnDeCodecs(unittest.TestCase): self.assertFalse(issubclass(TypeEncoder, TypeDecoder)) -class TestCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase): +class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase): @classmethod def setUpClass(cls): class TypeA(object): @@ -467,5 +521,192 @@ class TestTypeRegistry(unittest.TestCase): 'transform_bson': lambda x: x}) +class TestCollectionWCustomType(IntegrationTest): + def setUp(self): + self.db.test.drop() + + def tearDown(self): + self.db.test.drop() + + def test_command_errors_w_custom_type_decoder(self): + db = self.db + test_doc = {'_id': 1, 'data': 'a'} + test = db.get_collection('test', + codec_options=UNINT_DECODER_CODECOPTS) + + result = test.insert_one(test_doc) + self.assertEqual(result.inserted_id, test_doc['_id']) + with self.assertRaises(DuplicateKeyError): + test.insert_one(test_doc) + + def test_find_w_custom_type_decoder(self): + db = self.db + input_docs = [ + {'x': Int64(k)} for k in [1.0, 2.0, 3.0]] + for doc in input_docs: + db.test.insert_one(doc) + + test = db.get_collection( + 'test', codec_options=UNINT_DECODER_CODECOPTS) + for doc in test.find({}, batch_size=1): + self.assertIsInstance(doc['x'], UndecipherableInt64Type) + + @client_context.require_version_max(4, 1, 0, -1) + def test_group_w_custom_type(self): + db = self.db + test = 
db.get_collection('test', codec_options=UNINT_CODECOPTS) + test.insert_many([ + {'sku': 'a', 'qty': UndecipherableInt64Type(2)}, + {'sku': 'b', 'qty': UndecipherableInt64Type(5)}, + {'sku': 'a', 'qty': UndecipherableInt64Type(1)}]) + + self.assertEqual([{'sku': 'b', 'qty': UndecipherableInt64Type(5)},], + test.group(["sku", "qty"], {"sku": "b"}, {}, + "function (obj, prev) { }")) + + def test_aggregate_w_custom_type_decoder(self): + db = self.db + db.test.insert_many([ + {'status': 'in progress', 'qty': Int64(1)}, + {'status': 'complete', 'qty': Int64(10)}, + {'status': 'in progress', 'qty': Int64(1)}, + {'status': 'complete', 'qty': Int64(10)}, + {'status': 'in progress', 'qty': Int64(1)},]) + test = db.get_collection( + 'test', codec_options=UNINT_DECODER_CODECOPTS) + + pipeline = [ + {'$match': {'status': 'complete'}}, + {'$group': {'_id': "$status", 'total_qty': {"$sum": "$qty"}}},] + result = test.aggregate(pipeline) + + res = list(result)[0] + self.assertEqual(res['_id'], 'complete') + self.assertIsInstance(res['total_qty'], UndecipherableInt64Type) + self.assertEqual(res['total_qty'].value, 20) + + # collection.distinct does not support custom type decoding + def test_distinct_w_custom_type(self): + self.db.drop_collection("test") + + test = self.db.get_collection('test', codec_options=UNINT_CODECOPTS) + test.insert_many([ + {"a": UndecipherableInt64Type(1)}, + {"a": UndecipherableInt64Type(2)}, + {"a": UndecipherableInt64Type(2)}, + {"a": UndecipherableInt64Type(2)}, + {"a": UndecipherableInt64Type(3)}]) + + distinct = test.distinct("a") + distinct.sort() + + self.assertEqual([ + UndecipherableInt64Type(1), UndecipherableInt64Type(2), + UndecipherableInt64Type(3)], distinct) + + def test_map_reduce_w_custom_type(self): + test = self.db.get_collection( + 'test', codec_options=UPPERSTR_DECODER_CODECOPTS) + + test.insert_many([ + {'_id': 1, 'sku': 'abcd', 'qty': 1}, + {'_id': 2, 'sku': 'abcd', 'qty': 2}, + {'_id': 3, 'sku': 'abcd', 'qty': 3}]) + + map = 
Code("function () {" + " emit(this.sku, this.qty);" + "}") + reduce = Code("function (key, values) {" + " return Array.sum(values);" + "}") + + result = test.map_reduce(map, reduce, out={'inline': 1}) + self.assertTrue(isinstance(result, dict)) + self.assertTrue('results' in result) + self.assertEqual(result['results'][0], {'_id': 'ABCD', 'value': 6}) + + result = test.inline_map_reduce(map, reduce) + self.assertTrue(isinstance(result, list)) + self.assertEqual(1, len(result)) + self.assertEqual(result[0]["_id"], 'ABCD') + + full_result = test.inline_map_reduce(map, reduce, + full_response=True) + self.assertEqual(3, full_result["counts"]["emit"]) + + def test_find_one_and__w_custom_type_decoder(self): + db = self.db + c = db.get_collection('test', codec_options=UNINT_DECODER_CODECOPTS) + c.insert_one({'_id': 1, 'x': Int64(1)}) + + doc = c.find_one_and_update({'_id': 1}, {'$inc': {'x': 1}}, + return_document=ReturnDocument.AFTER) + self.assertEqual(doc['_id'], 1) + self.assertIsInstance(doc['x'], UndecipherableInt64Type) + self.assertEqual(doc['x'].value, 2) + + doc = c.find_one_and_replace({'_id': 1}, {'x': Int64(3), 'y': True}, + return_document=ReturnDocument.AFTER) + self.assertEqual(doc['_id'], 1) + self.assertIsInstance(doc['x'], UndecipherableInt64Type) + self.assertEqual(doc['x'].value, 3) + self.assertEqual(doc['y'], True) + + doc = c.find_one_and_delete({'y': True}) + self.assertEqual(doc['_id'], 1) + self.assertIsInstance(doc['x'], UndecipherableInt64Type) + self.assertEqual(doc['x'].value, 3) + self.assertIsNone(c.find_one()) + + @ignore_deprecations + def test_find_and_modify_w_custom_type_decoder(self): + db = self.db + c = db.get_collection('test', codec_options=UNINT_DECODER_CODECOPTS) + c.insert_one({'_id': 1, 'x': Int64(1)}) + + doc = c.find_and_modify({'_id': 1}, {'$inc': {'x': Int64(10)}}) + self.assertEqual(doc['_id'], 1) + self.assertIsInstance(doc['x'], UndecipherableInt64Type) + self.assertEqual(doc['x'].value, 1) + + doc = c.find_one() + 
self.assertEqual(doc['_id'], 1) + self.assertIsInstance(doc['x'], UndecipherableInt64Type) + self.assertEqual(doc['x'].value, 11) + + +class TestGridFileCustomType(IntegrationTest): + def setUp(self): + self.db.drop_collection('fs.files') + self.db.drop_collection('fs.chunks') + + def test_grid_out_custom_opts(self): + db = self.db.with_options(codec_options=UPPERSTR_DECODER_CODECOPTS) + one = GridIn(db.fs, _id=5, filename="my_file", + contentType="text/html", chunkSize=1000, aliases=["foo"], + metadata={"foo": 'red', "bar": 'blue'}, bar=3, + baz="hello") + one.write(b"hello world") + one.close() + + two = GridOut(db.fs, 5) + + self.assertEqual("my_file", two.name) + self.assertEqual("my_file", two.filename) + self.assertEqual(5, two._id) + self.assertEqual(11, two.length) + self.assertEqual("text/html", two.content_type) + self.assertEqual(1000, two.chunk_size) + self.assertTrue(isinstance(two.upload_date, datetime.datetime)) + self.assertEqual(["foo"], two.aliases) + self.assertEqual({"foo": 'red', "bar": 'blue'}, two.metadata) + self.assertEqual(3, two.bar) + self.assertEqual("5eb63bbbe01eeed093cb22bb8f5acdc3", two.md5) + + for attr in ["_id", "name", "content_type", "length", "chunk_size", + "upload_date", "aliases", "metadata", "md5"]: + self.assertRaises(AttributeError, setattr, two, attr, 5) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_database.py b/test/test_database.py index b6f66cada..b256c9be5 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -58,6 +58,7 @@ from test.utils import (ignore_deprecations, server_started_with_auth, IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener) +from test.test_custom_types import DECIMAL_CODECOPTS if PY3: @@ -976,6 +977,36 @@ class TestDatabase(IntegrationTest): "maxTimeAlwaysTimeOut", mode="off") + def test_with_options(self): + codec_options = DECIMAL_CODECOPTS + read_preference = ReadPreference.SECONDARY_PREFERRED + write_concern = WriteConcern(j=True) + read_concern = 
ReadConcern(level="majority") + + # List of all options to compare. + allopts = ['name', 'client', 'codec_options', + 'read_preference', 'write_concern', 'read_concern'] + + db1 = self.client.get_database( + 'with_options_test', codec_options=codec_options, + read_preference=read_preference, write_concern=write_concern, + read_concern=read_concern) + + # Case 1: swap no options + db2 = db1.with_options() + for opt in allopts: + self.assertEqual(getattr(db1, opt), getattr(db2, opt)) + + # Case 2: swap all options + newopts = {'codec_options': CodecOptions(), + 'read_preference': ReadPreference.PRIMARY, + 'write_concern': WriteConcern(w=1), + 'read_concern': ReadConcern(level="local")} + db2 = db1.with_options(**newopts) + for opt in newopts: + self.assertEqual( + getattr(db2, opt), newopts.get(opt, getattr(db1, opt))) + def test_current_op_codec_options(self): class MySON(SON): pass