PYTHON-1818 TypeCodec support for ChangeStreams

This commit is contained in:
Prashant Mital 2019-04-17 11:37:54 -07:00
parent fbb56a2311
commit 2f2fe9db0d
No known key found for this signature in database
GPG Key ID: 3D2DAA9E483ABE51
4 changed files with 207 additions and 7 deletions

View File

@ -948,7 +948,13 @@ if _USE_C:
def _decode_selective(rawdoc, fields, codec_options):
doc = {}
if _raw_document_class(codec_options.document_class):
# If document_class is RawBSONDocument, use vanilla dictionary for
# decoding command response.
doc = {}
else:
# Else, use the specified document_class.
doc = codec_options.document_class()
for key, value in iteritems(rawdoc):
if key in fields:
if fields[key] == 1:

View File

@ -19,6 +19,7 @@ from bson import _elements_to_dict, _get_object_size
from bson.py3compat import abc, iteritems
from bson.codec_options import (
DEFAULT_CODEC_OPTIONS as DEFAULT, _RAW_BSON_DOCUMENT_MARKER)
from bson.son import SON
class RawBSONDocument(abc.Mapping):
@ -93,8 +94,9 @@ class RawBSONDocument(abc.Mapping):
if self.__inflated_doc is None:
# We already validated the object's size when this document was
# created, so no need to do that again.
# Use SON to preserve ordering of elements.
self.__inflated_doc = _elements_to_dict(
self.__raw, 4, len(self.__raw)-1, self.__codec_options, {})
self.__raw, 4, len(self.__raw)-1, self.__codec_options, SON())
return self.__inflated_doc
def __getitem__(self, item):

View File

@ -16,6 +16,8 @@
import copy
from bson import _bson_to_dict
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from pymongo import common
@ -60,7 +62,16 @@ class ChangeStream(object):
validate_collation_or_none(collation)
common.validate_non_negative_integer_or_none("batchSize", batch_size)
self._target = target
self._decode_custom = False
self._orig_codec_options = target.codec_options
if target.codec_options.type_registry._decoder_map:
self._decode_custom = True
self._target = target.with_options(
codec_options=target.codec_options.with_options(
document_class=RawBSONDocument, type_registry=None))
else:
self._target = target
self._pipeline = copy.deepcopy(pipeline)
self._full_document = full_document
self._resume_token = copy.deepcopy(resume_after)
@ -145,8 +156,7 @@ class ChangeStream(object):
aggregation_collection, cursor, sock_info.address,
batch_size=self._batch_size or 0,
max_await_time_ms=self._max_await_time_ms,
session=session, explicit_session=explicit_session
)
session=session, explicit_session=explicit_session)
def _create_cursor(self):
with self._database.client._tmp_session(self._session, close=False) as s:
@ -263,6 +273,9 @@ class ChangeStream(object):
"token is missing.")
self._resume_token = copy.copy(resume_token)
self._start_at_operation_time = None
if self._decode_custom:
return _bson_to_dict(change.raw, self._orig_codec_options)
return change
def __enter__(self):

View File

@ -17,6 +17,7 @@
import datetime
import sys
import tempfile
from collections import OrderedDict
from decimal import Decimal
from random import random
@ -36,16 +37,18 @@ from bson.codec_options import (CodecOptions, TypeCodec, TypeDecoder,
TypeEncoder, TypeRegistry)
from bson.errors import InvalidDocument
from bson.int64 import Int64
from bson.raw_bson import RawBSONDocument
from bson.py3compat import text_type
from gridfs import GridIn, GridOut
from pymongo.collection import ReturnDocument
from pymongo.errors import DuplicateKeyError
from pymongo.message import _CursorAddress
from test import client_context, unittest
from test.test_client import IntegrationTest
from test.utils import ignore_deprecations
from test.utils import ignore_deprecations, rs_client
class DecimalEncoder(TypeEncoder):
@ -115,6 +118,14 @@ UPPERSTR_DECODER_CODECOPTS = CodecOptions(type_registry=TypeRegistry(
[UppercaseTextDecoder(),]))
def type_obfuscating_decoder_factory(rt_type):
    """Build a decoder class that replaces values of ``rt_type`` with "NaN".

    :Parameters:
      - `rt_type`: the type assigned to the generated decoder's
        ``bson_type`` attribute.

    Returns the decoder *class* (not an instance); callers instantiate it
    before handing it to a type registry.
    """
    class ResumeTokenToNanDecoder(TypeDecoder):
        # The type this decoder is registered for.
        bson_type = rt_type

        def transform_bson(self, value):
            # The incoming value is discarded on purpose: every decoded
            # value collapses to the same placeholder string.
            return "NaN"

    return ResumeTokenToNanDecoder
class CustomBSONTypeTests(object):
def roundtrip(self, doc):
bsonbytes = BSON().encode(doc, codec_options=self.codecopts)
@ -549,7 +560,7 @@ class TestCollectionWCustomType(IntegrationTest):
def test_find_w_custom_type_decoder(self):
db = self.db
input_docs = [
{'x': Int64(k)} for k in [1.0, 2.0, 3.0]]
{'x': Int64(k)} for k in [1, 2, 3]]
for doc in input_docs:
db.test.insert_one(doc)
@ -558,6 +569,24 @@ class TestCollectionWCustomType(IntegrationTest):
for doc in test.find({}, batch_size=1):
self.assertIsInstance(doc['x'], UndecipherableInt64Type)
def test_find_w_custom_type_decoder_and_document_class(self):
    """find() applies a custom decoder together with a document_class.

    For each document class, every document returned by find() must be an
    instance of that class and its 'x' value must have been converted to
    UndecipherableInt64Type.
    """
    def check_document_class(doc_cls):
        database = self.db
        # Seed three Int64 values for the decoder to transform.
        seed_docs = [{'x': Int64(k)} for k in (1, 2, 3)]
        for seed in seed_docs:
            database.test.insert_one(seed)

        registry = TypeRegistry([UndecipherableIntDecoder()])
        codec_opts = CodecOptions(
            type_registry=registry, document_class=doc_cls)
        collection = database.get_collection(
            'test', codec_options=codec_opts)

        # batch_size=1 makes the cursor fetch the documents one per batch.
        for found in collection.find({}, batch_size=1):
            self.assertIsInstance(found, doc_cls)
            self.assertIsInstance(found['x'], UndecipherableInt64Type)

    for doc_cls in (RawBSONDocument, OrderedDict):
        check_document_class(doc_cls)
@client_context.require_version_max(4, 1, 0, -1)
def test_group_w_custom_type(self):
db = self.db
@ -709,5 +738,155 @@ class TestGridFileCustomType(IntegrationTest):
self.assertRaises(AttributeError, setattr, two, attr, 5)
class ChangeStreamsWCustomTypesTestMixin(object):
    """Change stream tests shared by several concrete test cases.

    Host classes must define ``create_targets(*args, **kwargs)``, which is
    expected to set ``self.watched_target`` (the object ``watch()`` is
    called on) and ``self.input_target`` (the collection that documents are
    inserted into).

    Every change stream opened by these tests is used as a context manager
    so that it is closed even when an assertion fails.
    """

    def change_stream(self, *args, **kwargs):
        """Open a change stream on ``self.watched_target``."""
        return self.watched_target.watch(*args, **kwargs)

    def insert_and_check(self, change_stream, insert_doc,
                         expected_doc):
        """Insert ``insert_doc`` and check the next change event.

        Asserts that the ``fullDocument`` field of the next change read
        from ``change_stream`` equals ``expected_doc``.
        """
        self.input_target.insert_one(insert_doc)
        change = next(change_stream)
        self.assertEqual(change['fullDocument'], expected_doc)

    def kill_change_stream_cursor(self, change_stream):
        """Close the cursor backing ``change_stream`` behind its back."""
        # Cause a cursor not found error on the next getMore.
        cursor = change_stream._cursor
        address = _CursorAddress(cursor.address, cursor._CommandCursor__ns)
        client = self.input_target.database.client
        client._close_cursor_now(cursor.cursor_id, address)

    def test_simple(self):
        """Changes are custom-decoded, including after cursor kills."""
        codecopts = CodecOptions(type_registry=TypeRegistry([
            UndecipherableIntEncoder(), UppercaseTextDecoder()]))
        self.create_targets(codec_options=codecopts)

        input_docs = [
            {'_id': UndecipherableInt64Type(1), 'data': 'hello'},
            {'_id': 2, 'data': 'world'},
            {'_id': UndecipherableInt64Type(3), 'data': '!'},
        ]
        expected_docs = [
            {'_id': 1, 'data': 'HELLO'},
            {'_id': 2, 'data': 'WORLD'},
            {'_id': 3, 'data': '!'},
        ]

        with self.change_stream() as change_stream:
            self.insert_and_check(
                change_stream, input_docs[0], expected_docs[0])
            self.kill_change_stream_cursor(change_stream)
            self.insert_and_check(
                change_stream, input_docs[1], expected_docs[1])
            self.kill_change_stream_cursor(change_stream)
            self.insert_and_check(
                change_stream, input_docs[2], expected_docs[2])

    def test_break_resume_token(self):
        """Streams keep working when the resume token type is custom-decoded."""
        # Get one document from a change stream to determine resumeToken type.
        self.create_targets()
        with self.change_stream() as change_stream:
            self.input_target.insert_one({"data": "test"})
            change = next(change_stream)
        resume_token_decoder = type_obfuscating_decoder_factory(
            type(change['_id']['_data']))

        # Custom-decoding the resumeToken type breaks resume tokens.
        codecopts = CodecOptions(type_registry=TypeRegistry([
            resume_token_decoder(), UndecipherableIntEncoder()]))

        # Re-create targets, change stream and proceed.
        self.create_targets(codec_options=codecopts)

        docs = [{'_id': 1}, {'_id': 2}, {'_id': 3}]
        with self.change_stream() as change_stream:
            self.insert_and_check(change_stream, docs[0], docs[0])
            self.kill_change_stream_cursor(change_stream)
            self.insert_and_check(change_stream, docs[1], docs[1])
            self.kill_change_stream_cursor(change_stream)
            self.insert_and_check(change_stream, docs[2], docs[2])

    def test_document_class(self):
        """Change events honour the configured document_class."""
        def run_test(doc_cls):
            codecopts = CodecOptions(type_registry=TypeRegistry([
                UppercaseTextDecoder(), UndecipherableIntEncoder()]),
                document_class=doc_cls)

            self.create_targets(codec_options=codecopts)
            with self.change_stream() as change_stream:
                doc = {'a': UndecipherableInt64Type(101), 'b': 'xyz'}
                self.input_target.insert_one(doc)
                change = next(change_stream)

            self.assertIsInstance(change, doc_cls)
            self.assertEqual(change['fullDocument']['a'], 101)
            self.assertEqual(change['fullDocument']['b'], 'XYZ')

        for doc_cls in [OrderedDict, RawBSONDocument]:
            run_test(doc_cls)
class TestCollectionChangeStreamsWCustomTypes(
        IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
    """Run the shared change stream tests by watching a collection."""

    @classmethod
    @client_context.require_version_min(3, 6, 0)
    @client_context.require_no_mmap
    @client_context.require_no_standalone
    def setUpClass(cls):
        super(TestCollectionChangeStreamsWCustomTypes, cls).setUpClass()

    def tearDown(self):
        # Remove whatever the test wrote.
        self.input_target.drop()

    def create_targets(self, *args, **kwargs):
        """Watch and write to the same 'test' collection.

        All arguments are forwarded to ``get_collection``.
        """
        collection = self.db.get_collection('test', *args, **kwargs)
        self.watched_target = self.input_target = collection
        # Insert a record to ensure db, coll are created.
        collection.insert_one({'data': 'dummy'})
class TestDatabaseChangeStreamsWCustomTypes(
        IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
    """Run the shared change stream tests by watching a database."""

    @classmethod
    @client_context.require_version_min(4, 0, 0)
    @client_context.require_no_mmap
    @client_context.require_no_standalone
    def setUpClass(cls):
        super(TestDatabaseChangeStreamsWCustomTypes, cls).setUpClass()

    def tearDown(self):
        # Drop the written collection first, then the watched database.
        self.input_target.drop()
        self.client.drop_database(self.watched_target)

    def create_targets(self, *args, **kwargs):
        """Watch a database handle and write to its 'test' collection.

        All arguments are forwarded to ``get_database``.
        """
        database = self.client.get_database(
            self.db.name, *args, **kwargs)
        self.watched_target = database
        self.input_target = database.test
        # Insert a record to ensure db, coll are created.
        self.input_target.insert_one({'data': 'dummy'})
class TestClusterChangeStreamsWCustomTypes(
        IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
    """Run the shared change stream tests by watching a whole client."""

    @classmethod
    @client_context.require_version_min(4, 0, 0)
    @client_context.require_no_mmap
    @client_context.require_no_standalone
    def setUpClass(cls):
        super(TestClusterChangeStreamsWCustomTypes, cls).setUpClass()

    def tearDown(self):
        self.input_target.drop()
        self.client.drop_database(self.db)

    def create_targets(self, *args, **kwargs):
        """Create a dedicated client to watch and a collection to write to.

        A ``codec_options`` keyword, when given, is unpacked into the
        ``type_registry`` and ``document_class`` keywords passed to
        ``rs_client``; all other arguments are forwarded unchanged.
        """
        codec_options = kwargs.pop('codec_options', None)
        if codec_options:
            kwargs['type_registry'] = codec_options.type_registry
            kwargs['document_class'] = codec_options.document_class
        self.watched_target = rs_client(*args, **kwargs)
        # A test may call create_targets more than once; register a cleanup
        # per client so that none of them is left open. Cleanups run after
        # tearDown, which still needs the most recent client.
        self.addCleanup(self.watched_target.close)
        self.input_target = self.watched_target[self.db.name].test
        # Insert a record to ensure db, coll are created.
        self.input_target.insert_one({'data': 'dummy'})
# Run this module's tests when it is executed as a script.
if __name__ == "__main__":
    unittest.main()