PYTHON-1818 TypeCodec support for ChangeStreams
This commit is contained in:
parent
fbb56a2311
commit
2f2fe9db0d
@ -948,7 +948,13 @@ if _USE_C:
|
||||
|
||||
|
||||
def _decode_selective(rawdoc, fields, codec_options):
|
||||
doc = {}
|
||||
if _raw_document_class(codec_options.document_class):
|
||||
# If document_class is RawBSONDocument, use vanilla dictionary for
|
||||
# decoding command response.
|
||||
doc = {}
|
||||
else:
|
||||
# Else, use the specified document_class.
|
||||
doc = codec_options.document_class()
|
||||
for key, value in iteritems(rawdoc):
|
||||
if key in fields:
|
||||
if fields[key] == 1:
|
||||
|
||||
@ -19,6 +19,7 @@ from bson import _elements_to_dict, _get_object_size
|
||||
from bson.py3compat import abc, iteritems
|
||||
from bson.codec_options import (
|
||||
DEFAULT_CODEC_OPTIONS as DEFAULT, _RAW_BSON_DOCUMENT_MARKER)
|
||||
from bson.son import SON
|
||||
|
||||
|
||||
class RawBSONDocument(abc.Mapping):
|
||||
@ -93,8 +94,9 @@ class RawBSONDocument(abc.Mapping):
|
||||
if self.__inflated_doc is None:
|
||||
# We already validated the object's size when this document was
|
||||
# created, so no need to do that again.
|
||||
# Use SON to preserve ordering of elements.
|
||||
self.__inflated_doc = _elements_to_dict(
|
||||
self.__raw, 4, len(self.__raw)-1, self.__codec_options, {})
|
||||
self.__raw, 4, len(self.__raw)-1, self.__codec_options, SON())
|
||||
return self.__inflated_doc
|
||||
|
||||
def __getitem__(self, item):
|
||||
|
||||
@ -16,6 +16,8 @@
|
||||
|
||||
import copy
|
||||
|
||||
from bson import _bson_to_dict
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.son import SON
|
||||
|
||||
from pymongo import common
|
||||
@ -60,7 +62,16 @@ class ChangeStream(object):
|
||||
validate_collation_or_none(collation)
|
||||
common.validate_non_negative_integer_or_none("batchSize", batch_size)
|
||||
|
||||
self._target = target
|
||||
self._decode_custom = False
|
||||
self._orig_codec_options = target.codec_options
|
||||
if target.codec_options.type_registry._decoder_map:
|
||||
self._decode_custom = True
|
||||
self._target = target.with_options(
|
||||
codec_options=target.codec_options.with_options(
|
||||
document_class=RawBSONDocument, type_registry=None))
|
||||
else:
|
||||
self._target = target
|
||||
|
||||
self._pipeline = copy.deepcopy(pipeline)
|
||||
self._full_document = full_document
|
||||
self._resume_token = copy.deepcopy(resume_after)
|
||||
@ -145,8 +156,7 @@ class ChangeStream(object):
|
||||
aggregation_collection, cursor, sock_info.address,
|
||||
batch_size=self._batch_size or 0,
|
||||
max_await_time_ms=self._max_await_time_ms,
|
||||
session=session, explicit_session=explicit_session
|
||||
)
|
||||
session=session, explicit_session=explicit_session)
|
||||
|
||||
def _create_cursor(self):
|
||||
with self._database.client._tmp_session(self._session, close=False) as s:
|
||||
@ -263,6 +273,9 @@ class ChangeStream(object):
|
||||
"token is missing.")
|
||||
self._resume_token = copy.copy(resume_token)
|
||||
self._start_at_operation_time = None
|
||||
|
||||
if self._decode_custom:
|
||||
return _bson_to_dict(change.raw, self._orig_codec_options)
|
||||
return change
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
import datetime
|
||||
import sys
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
from decimal import Decimal
|
||||
from random import random
|
||||
|
||||
@ -36,16 +37,18 @@ from bson.codec_options import (CodecOptions, TypeCodec, TypeDecoder,
|
||||
TypeEncoder, TypeRegistry)
|
||||
from bson.errors import InvalidDocument
|
||||
from bson.int64 import Int64
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.py3compat import text_type
|
||||
|
||||
from gridfs import GridIn, GridOut
|
||||
|
||||
from pymongo.collection import ReturnDocument
|
||||
from pymongo.errors import DuplicateKeyError
|
||||
from pymongo.message import _CursorAddress
|
||||
|
||||
from test import client_context, unittest
|
||||
from test.test_client import IntegrationTest
|
||||
from test.utils import ignore_deprecations
|
||||
from test.utils import ignore_deprecations, rs_client
|
||||
|
||||
|
||||
class DecimalEncoder(TypeEncoder):
|
||||
@ -115,6 +118,14 @@ UPPERSTR_DECODER_CODECOPTS = CodecOptions(type_registry=TypeRegistry(
|
||||
[UppercaseTextDecoder(),]))
|
||||
|
||||
|
||||
def type_obfuscating_decoder_factory(rt_type):
|
||||
class ResumeTokenToNanDecoder(TypeDecoder):
|
||||
bson_type = rt_type
|
||||
def transform_bson(self, value):
|
||||
return "NaN"
|
||||
return ResumeTokenToNanDecoder
|
||||
|
||||
|
||||
class CustomBSONTypeTests(object):
|
||||
def roundtrip(self, doc):
|
||||
bsonbytes = BSON().encode(doc, codec_options=self.codecopts)
|
||||
@ -549,7 +560,7 @@ class TestCollectionWCustomType(IntegrationTest):
|
||||
def test_find_w_custom_type_decoder(self):
|
||||
db = self.db
|
||||
input_docs = [
|
||||
{'x': Int64(k)} for k in [1.0, 2.0, 3.0]]
|
||||
{'x': Int64(k)} for k in [1, 2, 3]]
|
||||
for doc in input_docs:
|
||||
db.test.insert_one(doc)
|
||||
|
||||
@ -558,6 +569,24 @@ class TestCollectionWCustomType(IntegrationTest):
|
||||
for doc in test.find({}, batch_size=1):
|
||||
self.assertIsInstance(doc['x'], UndecipherableInt64Type)
|
||||
|
||||
def test_find_w_custom_type_decoder_and_document_class(self):
|
||||
def run_test(doc_cls):
|
||||
db = self.db
|
||||
input_docs = [
|
||||
{'x': Int64(k)} for k in [1, 2, 3]]
|
||||
for doc in input_docs:
|
||||
db.test.insert_one(doc)
|
||||
|
||||
test = db.get_collection('test', codec_options=CodecOptions(
|
||||
type_registry=TypeRegistry([UndecipherableIntDecoder()]),
|
||||
document_class=doc_cls))
|
||||
for doc in test.find({}, batch_size=1):
|
||||
self.assertIsInstance(doc, doc_cls)
|
||||
self.assertIsInstance(doc['x'], UndecipherableInt64Type)
|
||||
|
||||
for doc_cls in [RawBSONDocument, OrderedDict]:
|
||||
run_test(doc_cls)
|
||||
|
||||
@client_context.require_version_max(4, 1, 0, -1)
|
||||
def test_group_w_custom_type(self):
|
||||
db = self.db
|
||||
@ -709,5 +738,155 @@ class TestGridFileCustomType(IntegrationTest):
|
||||
self.assertRaises(AttributeError, setattr, two, attr, 5)
|
||||
|
||||
|
||||
class ChangeStreamsWCustomTypesTestMixin(object):
|
||||
def change_stream(self, *args, **kwargs):
|
||||
return self.watched_target.watch(*args, **kwargs)
|
||||
|
||||
def insert_and_check(self, change_stream, insert_doc,
|
||||
expected_doc):
|
||||
self.input_target.insert_one(insert_doc)
|
||||
change = next(change_stream)
|
||||
self.assertEqual(change['fullDocument'], expected_doc)
|
||||
|
||||
def kill_change_stream_cursor(self, change_stream):
|
||||
# Cause a cursor not found error on the next getMore.
|
||||
cursor = change_stream._cursor
|
||||
address = _CursorAddress(cursor.address, cursor._CommandCursor__ns)
|
||||
client = self.input_target.database.client
|
||||
client._close_cursor_now(cursor.cursor_id, address)
|
||||
|
||||
def test_simple(self):
|
||||
codecopts = CodecOptions(type_registry=TypeRegistry([
|
||||
UndecipherableIntEncoder(), UppercaseTextDecoder()]))
|
||||
self.create_targets(codec_options=codecopts)
|
||||
|
||||
input_docs = [
|
||||
{'_id': UndecipherableInt64Type(1), 'data': 'hello'},
|
||||
{'_id': 2, 'data': 'world'},
|
||||
{'_id': UndecipherableInt64Type(3), 'data': '!'},]
|
||||
expected_docs = [
|
||||
{'_id': 1, 'data': 'HELLO'},
|
||||
{'_id': 2, 'data': 'WORLD'},
|
||||
{'_id': 3, 'data': '!'},]
|
||||
|
||||
change_stream = self.change_stream()
|
||||
|
||||
self.insert_and_check(change_stream, input_docs[0], expected_docs[0])
|
||||
self.kill_change_stream_cursor(change_stream)
|
||||
self.insert_and_check(change_stream, input_docs[1], expected_docs[1])
|
||||
self.kill_change_stream_cursor(change_stream)
|
||||
self.insert_and_check(change_stream, input_docs[2], expected_docs[2])
|
||||
|
||||
def test_break_resume_token(self):
|
||||
# Get one document from a change stream to determine resumeToken type.
|
||||
self.create_targets()
|
||||
change_stream = self.change_stream()
|
||||
self.input_target.insert_one({"data": "test"})
|
||||
change = next(change_stream)
|
||||
resume_token_decoder = type_obfuscating_decoder_factory(
|
||||
type(change['_id']['_data']))
|
||||
|
||||
# Custom-decoding the resumeToken type breaks resume tokens.
|
||||
codecopts = CodecOptions(type_registry=TypeRegistry([
|
||||
resume_token_decoder(), UndecipherableIntEncoder()]))
|
||||
|
||||
# Re-create targets, change stream and proceed.
|
||||
self.create_targets(codec_options=codecopts)
|
||||
|
||||
docs = [{'_id': 1}, {'_id': 2}, {'_id': 3}]
|
||||
|
||||
change_stream = self.change_stream()
|
||||
self.insert_and_check(change_stream, docs[0], docs[0])
|
||||
self.kill_change_stream_cursor(change_stream)
|
||||
self.insert_and_check(change_stream, docs[1], docs[1])
|
||||
self.kill_change_stream_cursor(change_stream)
|
||||
self.insert_and_check(change_stream, docs[2], docs[2])
|
||||
|
||||
def test_document_class(self):
|
||||
def run_test(doc_cls):
|
||||
codecopts = CodecOptions(type_registry=TypeRegistry([
|
||||
UppercaseTextDecoder(), UndecipherableIntEncoder()]),
|
||||
document_class=doc_cls)
|
||||
|
||||
self.create_targets(codec_options=codecopts)
|
||||
change_stream = self.change_stream()
|
||||
|
||||
doc = {'a': UndecipherableInt64Type(101), 'b': 'xyz'}
|
||||
self.input_target.insert_one(doc)
|
||||
change = next(change_stream)
|
||||
|
||||
self.assertIsInstance(change, doc_cls)
|
||||
self.assertEqual(change['fullDocument']['a'], 101)
|
||||
self.assertEqual(change['fullDocument']['b'], 'XYZ')
|
||||
|
||||
for doc_cls in [OrderedDict, RawBSONDocument]:
|
||||
run_test(doc_cls)
|
||||
|
||||
|
||||
class TestCollectionChangeStreamsWCustomTypes(
|
||||
IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
||||
@classmethod
|
||||
@client_context.require_version_min(3, 6, 0)
|
||||
@client_context.require_no_mmap
|
||||
@client_context.require_no_standalone
|
||||
def setUpClass(cls):
|
||||
super(TestCollectionChangeStreamsWCustomTypes, cls).setUpClass()
|
||||
|
||||
def tearDown(self):
|
||||
self.input_target.drop()
|
||||
|
||||
def create_targets(self, *args, **kwargs):
|
||||
self.watched_target = self.db.get_collection(
|
||||
'test', *args, **kwargs)
|
||||
self.input_target = self.watched_target
|
||||
# Insert a record to ensure db, coll are created.
|
||||
self.input_target.insert_one({'data': 'dummy'})
|
||||
|
||||
|
||||
class TestDatabaseChangeStreamsWCustomTypes(
|
||||
IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
||||
@classmethod
|
||||
@client_context.require_version_min(4, 0, 0)
|
||||
@client_context.require_no_mmap
|
||||
@client_context.require_no_standalone
|
||||
def setUpClass(cls):
|
||||
super(TestDatabaseChangeStreamsWCustomTypes, cls).setUpClass()
|
||||
|
||||
def tearDown(self):
|
||||
self.input_target.drop()
|
||||
self.client.drop_database(self.watched_target)
|
||||
|
||||
def create_targets(self, *args, **kwargs):
|
||||
self.watched_target = self.client.get_database(
|
||||
self.db.name, *args, **kwargs)
|
||||
self.input_target = self.watched_target.test
|
||||
# Insert a record to ensure db, coll are created.
|
||||
self.input_target.insert_one({'data': 'dummy'})
|
||||
|
||||
|
||||
class TestClusterChangeStreamsWCustomTypes(
|
||||
IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
||||
@classmethod
|
||||
@client_context.require_version_min(4, 0, 0)
|
||||
@client_context.require_no_mmap
|
||||
@client_context.require_no_standalone
|
||||
def setUpClass(cls):
|
||||
super(TestClusterChangeStreamsWCustomTypes, cls).setUpClass()
|
||||
|
||||
def tearDown(self):
|
||||
self.input_target.drop()
|
||||
self.client.drop_database(self.db)
|
||||
|
||||
def create_targets(self, *args, **kwargs):
|
||||
codec_options = kwargs.pop('codec_options', None)
|
||||
if codec_options:
|
||||
kwargs['type_registry'] = codec_options.type_registry
|
||||
kwargs['document_class'] = codec_options.document_class
|
||||
self.watched_target = rs_client(*args, **kwargs)
|
||||
self.input_target = self.watched_target[self.db.name].test
|
||||
# Insert a record to ensure db, coll are created.
|
||||
self.input_target.insert_one({'data': 'dummy'})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user