From 2f2fe9db0db78bfac59ef3317533c344c2aa9a5f Mon Sep 17 00:00:00 2001 From: Prashant Mital Date: Wed, 17 Apr 2019 11:37:54 -0700 Subject: [PATCH] PYTHON-1818 TypeCodec support for ChangeStreams --- bson/__init__.py | 8 +- bson/raw_bson.py | 4 +- pymongo/change_stream.py | 19 +++- test/test_custom_types.py | 183 +++++++++++++++++++++++++++++++++++++- 4 files changed, 207 insertions(+), 7 deletions(-) diff --git a/bson/__init__.py b/bson/__init__.py index 8b0b928de..98c784f2b 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -948,7 +948,13 @@ if _USE_C: def _decode_selective(rawdoc, fields, codec_options): - doc = {} + if _raw_document_class(codec_options.document_class): + # If document_class is RawBSONDocument, use vanilla dictionary for + # decoding command response. + doc = {} + else: + # Else, use the specified document_class. + doc = codec_options.document_class() for key, value in iteritems(rawdoc): if key in fields: if fields[key] == 1: diff --git a/bson/raw_bson.py b/bson/raw_bson.py index c13b1a9b3..429b2acc3 100644 --- a/bson/raw_bson.py +++ b/bson/raw_bson.py @@ -19,6 +19,7 @@ from bson import _elements_to_dict, _get_object_size from bson.py3compat import abc, iteritems from bson.codec_options import ( DEFAULT_CODEC_OPTIONS as DEFAULT, _RAW_BSON_DOCUMENT_MARKER) +from bson.son import SON class RawBSONDocument(abc.Mapping): @@ -93,8 +94,9 @@ class RawBSONDocument(abc.Mapping): if self.__inflated_doc is None: # We already validated the object's size when this document was # created, so no need to do that again. + # Use SON to preserve ordering of elements. self.__inflated_doc = _elements_to_dict( - self.__raw, 4, len(self.__raw)-1, self.__codec_options, {}) + self.__raw, 4, len(self.__raw)-1, self.__codec_options, SON()) return self.__inflated_doc def __getitem__(self, item): diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index a48980c52..ad3dbaa61 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -16,6 +16,8 @@ import copy +from bson import _bson_to_dict +from bson.raw_bson import RawBSONDocument from bson.son import SON from pymongo import common @@ -60,7 +62,16 @@ class ChangeStream(object): validate_collation_or_none(collation) common.validate_non_negative_integer_or_none("batchSize", batch_size) - self._target = target + self._decode_custom = False + self._orig_codec_options = target.codec_options + if target.codec_options.type_registry._decoder_map: + self._decode_custom = True + self._target = target.with_options( + codec_options=target.codec_options.with_options( + document_class=RawBSONDocument, type_registry=None)) + else: + self._target = target + self._pipeline = copy.deepcopy(pipeline) self._full_document = full_document self._resume_token = copy.deepcopy(resume_after) @@ -145,8 +156,7 @@ class ChangeStream(object): aggregation_collection, cursor, sock_info.address, batch_size=self._batch_size or 0, max_await_time_ms=self._max_await_time_ms, - session=session, explicit_session=explicit_session - ) + session=session, explicit_session=explicit_session) def _create_cursor(self): with self._database.client._tmp_session(self._session, close=False) as s: @@ -263,6 +273,9 @@ class ChangeStream(object): "token is missing.") self._resume_token = copy.copy(resume_token) self._start_at_operation_time = None + + if self._decode_custom: + return _bson_to_dict(change.raw, self._orig_codec_options) return change def __enter__(self): diff --git a/test/test_custom_types.py b/test/test_custom_types.py index c830258c9..cc83c1e6d 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -17,6 +17,7 @@ import datetime import sys import tempfile +from collections import OrderedDict from decimal import Decimal from random import random @@ -36,16 +37,18 @@ from bson.codec_options import (CodecOptions, TypeCodec, TypeDecoder, TypeEncoder, TypeRegistry) from bson.errors import InvalidDocument from bson.int64 import Int64 +from bson.raw_bson import RawBSONDocument from bson.py3compat import text_type from gridfs import GridIn, GridOut from pymongo.collection import ReturnDocument from pymongo.errors import DuplicateKeyError +from pymongo.message import _CursorAddress from test import client_context, unittest from test.test_client import IntegrationTest -from test.utils import ignore_deprecations +from test.utils import ignore_deprecations, rs_client class DecimalEncoder(TypeEncoder): @@ -115,6 +118,14 @@ UPPERSTR_DECODER_CODECOPTS = CodecOptions(type_registry=TypeRegistry( [UppercaseTextDecoder(),])) +def type_obfuscating_decoder_factory(rt_type): + class ResumeTokenToNanDecoder(TypeDecoder): + bson_type = rt_type + def transform_bson(self, value): + return "NaN" + return ResumeTokenToNanDecoder + + class CustomBSONTypeTests(object): def roundtrip(self, doc): bsonbytes = BSON().encode(doc, codec_options=self.codecopts) @@ -549,7 +560,7 @@ class TestCollectionWCustomType(IntegrationTest): def test_find_w_custom_type_decoder(self): db = self.db input_docs = [ - {'x': Int64(k)} for k in [1.0, 2.0, 3.0]] + {'x': Int64(k)} for k in [1, 2, 3]] for doc in input_docs: db.test.insert_one(doc) @@ -558,6 +569,24 @@ class TestCollectionWCustomType(IntegrationTest): for doc in test.find({}, batch_size=1): self.assertIsInstance(doc['x'], UndecipherableInt64Type) + def test_find_w_custom_type_decoder_and_document_class(self): + def run_test(doc_cls): + db = self.db + input_docs = [ + {'x': Int64(k)} for k in [1, 2, 3]] + for doc in input_docs: + db.test.insert_one(doc) + + test = db.get_collection('test', codec_options=CodecOptions( + type_registry=TypeRegistry([UndecipherableIntDecoder()]), + document_class=doc_cls)) + for doc in test.find({}, batch_size=1): + self.assertIsInstance(doc, doc_cls) + self.assertIsInstance(doc['x'], UndecipherableInt64Type) + + for doc_cls in [RawBSONDocument, OrderedDict]: + run_test(doc_cls) + @client_context.require_version_max(4, 1, 0, -1) def test_group_w_custom_type(self): db = self.db @@ -709,5 +738,155 @@ class TestGridFileCustomType(IntegrationTest): self.assertRaises(AttributeError, setattr, two, attr, 5) +class ChangeStreamsWCustomTypesTestMixin(object): + def change_stream(self, *args, **kwargs): + return self.watched_target.watch(*args, **kwargs) + + def insert_and_check(self, change_stream, insert_doc, + expected_doc): + self.input_target.insert_one(insert_doc) + change = next(change_stream) + self.assertEqual(change['fullDocument'], expected_doc) + + def kill_change_stream_cursor(self, change_stream): + # Cause a cursor not found error on the next getMore. + cursor = change_stream._cursor + address = _CursorAddress(cursor.address, cursor._CommandCursor__ns) + client = self.input_target.database.client + client._close_cursor_now(cursor.cursor_id, address) + + def test_simple(self): + codecopts = CodecOptions(type_registry=TypeRegistry([ + UndecipherableIntEncoder(), UppercaseTextDecoder()])) + self.create_targets(codec_options=codecopts) + + input_docs = [ + {'_id': UndecipherableInt64Type(1), 'data': 'hello'}, + {'_id': 2, 'data': 'world'}, + {'_id': UndecipherableInt64Type(3), 'data': '!'},] + expected_docs = [ + {'_id': 1, 'data': 'HELLO'}, + {'_id': 2, 'data': 'WORLD'}, + {'_id': 3, 'data': '!'},] + + change_stream = self.change_stream() + + self.insert_and_check(change_stream, input_docs[0], expected_docs[0]) + self.kill_change_stream_cursor(change_stream) + self.insert_and_check(change_stream, input_docs[1], expected_docs[1]) + self.kill_change_stream_cursor(change_stream) + self.insert_and_check(change_stream, input_docs[2], expected_docs[2]) + + def test_break_resume_token(self): + # Get one document from a change stream to determine resumeToken type. + self.create_targets() + change_stream = self.change_stream() + self.input_target.insert_one({"data": "test"}) + change = next(change_stream) + resume_token_decoder = type_obfuscating_decoder_factory( + type(change['_id']['_data'])) + + # Custom-decoding the resumeToken type breaks resume tokens. + codecopts = CodecOptions(type_registry=TypeRegistry([ + resume_token_decoder(), UndecipherableIntEncoder()])) + + # Re-create targets, change stream and proceed. + self.create_targets(codec_options=codecopts) + + docs = [{'_id': 1}, {'_id': 2}, {'_id': 3}] + + change_stream = self.change_stream() + self.insert_and_check(change_stream, docs[0], docs[0]) + self.kill_change_stream_cursor(change_stream) + self.insert_and_check(change_stream, docs[1], docs[1]) + self.kill_change_stream_cursor(change_stream) + self.insert_and_check(change_stream, docs[2], docs[2]) + + def test_document_class(self): + def run_test(doc_cls): + codecopts = CodecOptions(type_registry=TypeRegistry([ + UppercaseTextDecoder(), UndecipherableIntEncoder()]), + document_class=doc_cls) + + self.create_targets(codec_options=codecopts) + change_stream = self.change_stream() + + doc = {'a': UndecipherableInt64Type(101), 'b': 'xyz'} + self.input_target.insert_one(doc) + change = next(change_stream) + + self.assertIsInstance(change, doc_cls) + self.assertEqual(change['fullDocument']['a'], 101) + self.assertEqual(change['fullDocument']['b'], 'XYZ') + + for doc_cls in [OrderedDict, RawBSONDocument]: + run_test(doc_cls) + + +class TestCollectionChangeStreamsWCustomTypes( + IntegrationTest, ChangeStreamsWCustomTypesTestMixin): + @classmethod + @client_context.require_version_min(3, 6, 0) + @client_context.require_no_mmap + @client_context.require_no_standalone + def setUpClass(cls): + super(TestCollectionChangeStreamsWCustomTypes, cls).setUpClass() + + def tearDown(self): + self.input_target.drop() + + def create_targets(self, *args, **kwargs): + self.watched_target = self.db.get_collection( + 'test', *args, **kwargs) + self.input_target = self.watched_target + # Insert a record to ensure db, coll are created. + self.input_target.insert_one({'data': 'dummy'}) + + +class TestDatabaseChangeStreamsWCustomTypes( + IntegrationTest, ChangeStreamsWCustomTypesTestMixin): + @classmethod + @client_context.require_version_min(4, 0, 0) + @client_context.require_no_mmap + @client_context.require_no_standalone + def setUpClass(cls): + super(TestDatabaseChangeStreamsWCustomTypes, cls).setUpClass() + + def tearDown(self): + self.input_target.drop() + self.client.drop_database(self.watched_target) + + def create_targets(self, *args, **kwargs): + self.watched_target = self.client.get_database( + self.db.name, *args, **kwargs) + self.input_target = self.watched_target.test + # Insert a record to ensure db, coll are created. + self.input_target.insert_one({'data': 'dummy'}) + + +class TestClusterChangeStreamsWCustomTypes( + IntegrationTest, ChangeStreamsWCustomTypesTestMixin): + @classmethod + @client_context.require_version_min(4, 0, 0) + @client_context.require_no_mmap + @client_context.require_no_standalone + def setUpClass(cls): + super(TestClusterChangeStreamsWCustomTypes, cls).setUpClass() + + def tearDown(self): + self.input_target.drop() + self.client.drop_database(self.db) + + def create_targets(self, *args, **kwargs): + codec_options = kwargs.pop('codec_options', None) + if codec_options: + kwargs['type_registry'] = codec_options.type_registry + kwargs['document_class'] = codec_options.document_class + self.watched_target = rs_client(*args, **kwargs) + self.input_target = self.watched_target[self.db.name].test + # Insert a record to ensure db, coll are created. + self.input_target.insert_one({'data': 'dummy'}) + + if __name__ == "__main__": unittest.main()