diff --git a/doc/changelog.rst b/doc/changelog.rst index 89f10e69f..08e4fb06b 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -62,6 +62,8 @@ Version 3.9 adds support for MongoDB 4.2. Highlights include: :meth:`~pymongo.operations.UpdateMany`. - :class:`~bson.binary.Binary` now supports any bytes-like type that implements the buffer protocol. +- Resume tokens can now be accessed from a ``ChangeStream`` cursor using the + :attr:`~pymongo.change_stream.ChangeStream.resume_token` attribute. .. _URI options specification: https://github.com/mongodb/specifications/blob/master/source/uri-options/uri-options.rst diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index 0c5e2e31b..dc72d0bda 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -77,13 +77,16 @@ class ChangeStream(object): self._pipeline = copy.deepcopy(pipeline) self._full_document = full_document - self._resume_token = copy.deepcopy(resume_after) + self._uses_start_after = start_after is not None + self._uses_resume_after = resume_after is not None + self._resume_token = copy.deepcopy(start_after or resume_after) self._max_await_time_ms = max_await_time_ms self._batch_size = batch_size self._collation = collation self._start_at_operation_time = start_at_operation_time self._session = session - self._start_after = copy.deepcopy(start_after) + + # Initialize cursor. self._cursor = self._create_cursor() @property @@ -102,10 +105,14 @@ class ChangeStream(object): options = {} if self._full_document is not None: options['fullDocument'] = self._full_document - if self._resume_token is not None: - options['resumeAfter'] = self._resume_token - if self._start_after is not None: - options['startAfter'] = self._start_after + + resume_token = self.resume_token + if resume_token is not None: + if self._uses_start_after: + options['startAfter'] = resume_token + if self._uses_resume_after: + options['resumeAfter'] = resume_token + if self._start_at_operation_time is not None: options['startAtOperationTime'] = self._start_at_operation_time return options @@ -127,12 +134,18 @@ class ChangeStream(object): return full_pipeline def _process_result(self, result, session, server, sock_info, slave_ok): - """Callback that records a change stream cursor's operationTime.""" - if (self._start_at_operation_time is None and - self._resume_token is None and - self._start_after is None and - sock_info.max_wire_version >= 7): - self._start_at_operation_time = result["operationTime"] + """Callback that caches the startAtOperationTime from a changeStream + aggregate command response containing an empty batch of change + documents. + + This is implemented as a callback because we need access to the wire + version in order to determine whether to cache this value. + """ + if not result['cursor']['firstBatch']: + if (self._start_at_operation_time is None and + self.resume_token is None and + sock_info.max_wire_version >= 7): + self._start_at_operation_time = result["operationTime"] def _run_aggregation_cmd(self, session, explicit_session): """Run the full aggregation pipeline for this ChangeStream and return @@ -168,6 +181,15 @@ class ChangeStream(object): def __iter__(self): return self + @property + def resume_token(self): + """The cached resume token that will be used to resume after the most + recently returned change. + + .. versionadded:: 3.9 + """ + return copy.deepcopy(self._resume_token) + def next(self): """Advance the cursor. @@ -249,10 +271,18 @@ class ChangeStream(object): self._resume() change = self._cursor._try_next(False) - # No changes are available. + # If no changes are available. if change is None: - return None + # We have either iterated over all documents in the cursor, + # OR the most-recently returned batch is empty. In either case, + # update the cached resume token with the postBatchResumeToken if + # one was returned. We also clear the startAtOperationTime. + if self._cursor._post_batch_resume_token is not None: + self._resume_token = self._cursor._post_batch_resume_token + self._start_at_operation_time = None + return change + # Else, changes are available. try: resume_token = change['_id'] except KeyError: @@ -260,9 +290,20 @@ class ChangeStream(object): raise InvalidOperation( "Cannot provide resume functionality when the resume " "token is missing.") - self._resume_token = copy.copy(resume_token) + + # If this is the last change document from the current batch, cache the + # postBatchResumeToken. + if (not self._cursor._has_next() and + self._cursor._post_batch_resume_token): + resume_token = self._cursor._post_batch_resume_token + + # Hereafter, don't use startAfter; instead use resumeAfter. + self._uses_start_after = False + self._uses_resume_after = True + + # Cache the resume token and clear startAtOperationTime. + self._resume_token = resume_token self._start_at_operation_time = None - self._start_after = None if self._decode_custom: return _bson_to_dict(change.raw, self._orig_codec_options) diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index 40b1ad8ae..196d94dda 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -14,18 +14,14 @@ """CommandCursor class to iterate over command results.""" -import datetime - from collections import deque from bson.py3compat import integer_types -from pymongo import helpers from pymongo.errors import (ConnectionFailure, InvalidOperation, NotMasterError, OperationFailure) -from pymongo.message import (_convert_exception, - _CursorAddress, +from pymongo.message import (_CursorAddress, _GetMore, _RawBatchGetMore) @@ -43,8 +39,9 @@ class CommandCursor(object): """ self.__collection = collection self.__id = cursor_info['id'] - self.__address = address self.__data = deque(cursor_info['firstBatch']) + self.__postbatchresumetoken = cursor_info.get('postBatchResumeToken') + self.__address = address self.__batch_size = batch_size self.__max_await_time_ms = max_await_time_ms self.__session = session @@ -119,6 +116,17 @@ class CommandCursor(object): self.__batch_size = batch_size == 1 and 2 or batch_size return self + def _has_next(self): + """Returns `True` if the cursor has documents remaining from the + previous batch.""" + return len(self.__data) > 0 + + @property + def _post_batch_resume_token(self): + """Retrieve the postBatchResumeToken from the response to a + changeStream aggregate or getMore.""" + return self.__postbatchresumetoken + def __send_message(self, operation): """Send a getmore message and handle the response. """ @@ -157,6 +165,7 @@ class CommandCursor(object): if from_command: cursor = docs[0]['cursor'] documents = cursor['nextBatch'] + self.__postbatchresumetoken = cursor.get('postBatchResumeToken') self.__id = cursor['id'] else: documents = docs diff --git a/test/test_change_stream.py b/test/test_change_stream.py index 8d0825f2f..c9415c463 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -23,12 +23,11 @@ import threading import time import uuid -from contextlib import contextmanager from itertools import product sys.path[0:0] = [''] -from bson import BSON, ObjectId, SON, json_util +from bson import BSON, ObjectId, SON, Timestamp, json_util from bson.binary import (ALL_UUID_REPRESENTATIONS, Binary, STANDARD, @@ -36,6 +35,7 @@ from bson.binary import (ALL_UUID_REPRESENTATIONS, from bson.py3compat import iteritems from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument +from pymongo import MongoClient from pymongo.change_stream import _NON_RESUMABLE_GETMORE_ERRORS from pymongo.command_cursor import CommandCursor from pymongo.errors import (InvalidOperation, OperationFailure, @@ -46,29 +46,105 @@ from pymongo.write_concern import WriteConcern from test import client_context, unittest, IntegrationTest from test.utils import ( - EventListener, WhiteListEventListener, rs_or_single_client -) + EventListener, WhiteListEventListener, rs_or_single_client, wait_until) -class ChangeStreamTryNextMixin(object): - +class TestChangeStreamBase(IntegrationTest): def change_stream_with_client(self, client, *args, **kwargs): + """Create a change stream using the given client and return it.""" raise NotImplementedError def change_stream(self, *args, **kwargs): + """Create a change stream using the default client and return it.""" return self.change_stream_with_client(self.client, *args, **kwargs) - def watched_collection(self): + def client_with_listener(self, *commands): + """Return a client with a WhiteListEventListener.""" + listener = WhiteListEventListener(*commands) + client = rs_or_single_client(event_listeners=[listener]) + self.addCleanup(client.close) + return client, listener + + def watched_collection(self, *args, **kwargs): """Return a collection that is watched by self.change_stream().""" + # Construct a unique collection for each test. + collname = '.'.join(self.id().rsplit('.', 2)[1:]) + return self.db.get_collection(collname, *args, **kwargs) + + def generate_invalidate_event(self, change_stream): + """Cause a change stream invalidate event.""" + raise NotImplementedError + + def generate_unique_collnames(self, numcolls): + """Generate numcolls collection names unique to a test.""" + collnames = [] + for idx in range(1, numcolls + 1): + collnames.append(self.id() + '_' + str(idx)) + return collnames + + def get_resume_token(self, invalidate=False): + """Get a resume token to use for starting a change stream.""" + # Ensure targeted collection exists before starting. + coll = self.watched_collection(write_concern=WriteConcern('majority')) + coll.insert_one({}) + + if invalidate: + with self.change_stream( + [{'$match': {'operationType': 'invalidate'}}]) as cs: + if isinstance(cs._target, MongoClient): + self.skipTest( + "cluster-level change streams cannot be invalidated") + self.generate_invalidate_event(cs) + return cs.next()['_id'] + else: + with self.change_stream() as cs: + coll.insert_one({'data': 1}) + return cs.next()['_id'] + + def get_start_at_operation_time(self): + """Get an operationTime. Advances the operation clock beyond the most + recently returned timestamp.""" + optime = self.client.admin.command("ping")["operationTime"] + return Timestamp(optime.time, optime.inc + 1) + + def insert_one_and_check(self, change_stream, doc): + """Insert a document and check that it shows up in the change stream.""" raise NotImplementedError def kill_change_stream_cursor(self, change_stream): - # Cause a cursor not found error on the next getMore. + """Cause a cursor not found error on the next getMore.""" cursor = change_stream._cursor address = _CursorAddress(cursor.address, cursor._CommandCursor__ns) client = self.watched_collection().database.client client._close_cursor_now(cursor.cursor_id, address) + +class APITestsMixin(object): + def test_watch(self): + with self.change_stream( + [{'$project': {'foo': 0}}], full_document='updateLookup', + max_await_time_ms=1000, batch_size=100) as change_stream: + self.assertEqual([{'$project': {'foo': 0}}], + change_stream._pipeline) + self.assertEqual('updateLookup', change_stream._full_document) + self.assertIsNone(change_stream.resume_token) + self.assertEqual(1000, change_stream._max_await_time_ms) + self.assertEqual(100, change_stream._batch_size) + self.assertIsInstance(change_stream._cursor, CommandCursor) + self.assertEqual( + 1000, change_stream._cursor._CommandCursor__max_await_time_ms) + self.watched_collection( + write_concern=WriteConcern("majority")).insert_one({}) + _ = change_stream.next() + resume_token = change_stream.resume_token + with self.assertRaises(TypeError): + self.change_stream(pipeline={}) + with self.assertRaises(TypeError): + self.change_stream(full_document={}) + # No Error. + with self.change_stream(resume_after=resume_token): + pass + def test_try_next(self): # ChangeStreams only read majority committed data so use w:majority. coll = self.watched_collection().with_options( @@ -77,13 +153,13 @@ class ChangeStreamTryNextMixin(object): coll.insert_one({}) self.addCleanup(coll.drop) with self.change_stream(max_await_time_ms=250) as stream: - self.assertIsNone(stream.try_next()) - self.assertIsNone(stream._resume_token) - coll.insert_one({}) - change = stream.try_next() - self.assertEqual(change['_id'], stream._resume_token) - self.assertIsNone(stream.try_next()) - self.assertEqual(change['_id'], stream._resume_token) + self.assertIsNone(stream.try_next()) # No changes initially. + coll.insert_one({}) # Generate a change. + # On sharded clusters, even majority-committed changes only show + # up once an event that sorts after it shows up on the other + # shard. So, we wait on try_next to eventually return changes. + wait_until(lambda: stream.try_next() is not None, + "get change from try_next") def test_try_next_runs_one_getmore(self): listener = EventListener() @@ -115,8 +191,8 @@ class ChangeStreamTryNextMixin(object): # Get at least one change before resuming. coll.insert_one({'_id': 2}) - change = stream.try_next() - self.assertEqual(change['_id'], stream._resume_token) + wait_until(lambda: stream.try_next() is not None, + "get change from try_next") listener.results.clear() # Cause the next request to initiate the resume process. @@ -134,9 +210,10 @@ class ChangeStreamTryNextMixin(object): # Stream still works after a resume. coll.insert_one({'_id': 3}) - change = stream.try_next() - self.assertEqual(change['_id'], stream._resume_token) - self.assertEqual(listener.started_command_names(), ["getMore"]) + wait_until(lambda: stream.try_next() is not None, + "get change from try_next") + self.assertEqual(set(listener.started_command_names()), + set(["getMore"])) self.assertIsNone(stream.try_next()) def test_batch_size_is_honored(self): @@ -167,9 +244,504 @@ class ChangeStreamTryNextMixin(object): key = next(iter(expected)) self.assertEqual(expected[key], cmd[key]) + # $changeStream.startAtOperationTime was added in 4.0.0. + @client_context.require_version_min(4, 0, 0) + def test_start_at_operation_time(self): + optime = self.get_start_at_operation_time() -class TestClusterChangeStream(IntegrationTest, ChangeStreamTryNextMixin): + coll = self.watched_collection( + write_concern=WriteConcern("majority")) + ndocs = 3 + coll.insert_many([{"data": i} for i in range(ndocs)]) + with self.change_stream(start_at_operation_time=optime) as cs: + for i in range(ndocs): + cs.next() + + def _test_full_pipeline(self, expected_cs_stage): + client, listener = self.client_with_listener("aggregate") + results = listener.results + with self.change_stream_with_client( + client, [{'$project': {'foo': 0}}]) as _: + pass + + self.assertEqual(1, len(results['started'])) + command = results['started'][0] + self.assertEqual('aggregate', command.command_name) + self.assertEqual([ + {'$changeStream': expected_cs_stage}, + {'$project': {'foo': 0}}], + command.command['pipeline']) + + def test_full_pipeline(self): + """$changeStream must be the first stage in a change stream pipeline + sent to the server. + """ + self._test_full_pipeline({}) + + def test_iteration(self): + with self.change_stream(batch_size=2) as change_stream: + num_inserted = 10 + self.watched_collection().insert_many( + [{} for _ in range(num_inserted)]) + inserts_received = 0 + for change in change_stream: + self.assertEqual(change['operationType'], 'insert') + inserts_received += 1 + if inserts_received == num_inserted: + break + self._test_invalidate_stops_iteration(change_stream) + + def _test_next_blocks(self, change_stream): + inserted_doc = {'_id': ObjectId()} + changes = [] + t = threading.Thread( + target=lambda: changes.append(change_stream.next())) + t.start() + # Sleep for a bit to prove that the call to next() blocks. + time.sleep(1) + self.assertTrue(t.is_alive()) + self.assertFalse(changes) + self.watched_collection().insert_one(inserted_doc) + # Join with large timeout to give the server time to return the change, + # in particular for shard clusters. + t.join(30) + self.assertFalse(t.is_alive()) + self.assertEqual(1, len(changes)) + self.assertEqual(changes[0]['operationType'], 'insert') + self.assertEqual(changes[0]['fullDocument'], inserted_doc) + + def test_next_blocks(self): + """Test that next blocks until a change is readable""" + # Use a short await time to speed up the test. + with self.change_stream(max_await_time_ms=250) as change_stream: + self._test_next_blocks(change_stream) + + def test_aggregate_cursor_blocks(self): + """Test that an aggregate cursor blocks until a change is readable.""" + with self.watched_collection().aggregate( + [{'$changeStream': {}}], maxAwaitTimeMS=250) as change_stream: + self._test_next_blocks(change_stream) + + def test_concurrent_close(self): + """Ensure a ChangeStream can be closed from another thread.""" + # Use a short await time to speed up the test. + with self.change_stream(max_await_time_ms=250) as change_stream: + def iterate_cursor(): + for _ in change_stream: + pass + t = threading.Thread(target=iterate_cursor) + t.start() + self.watched_collection().insert_one({}) + time.sleep(1) + change_stream.close() + t.join(3) + self.assertFalse(t.is_alive()) + + def test_unknown_full_document(self): + """Must rely on the server to raise an error on unknown fullDocument. + """ + try: + with self.change_stream(full_document='notValidatedByPyMongo'): + pass + except OperationFailure: + pass + + def test_change_operations(self): + """Test each operation type.""" + expected_ns = {'db': self.watched_collection().database.name, + 'coll': self.watched_collection().name} + with self.change_stream() as change_stream: + # Insert. + inserted_doc = {'_id': ObjectId(), 'foo': 'bar'} + self.watched_collection().insert_one(inserted_doc) + change = change_stream.next() + self.assertTrue(change['_id']) + self.assertEqual(change['operationType'], 'insert') + self.assertEqual(change['ns'], expected_ns) + self.assertEqual(change['fullDocument'], inserted_doc) + # Update. + update_spec = {'$set': {'new': 1}, '$unset': {'foo': 1}} + self.watched_collection().update_one(inserted_doc, update_spec) + change = change_stream.next() + self.assertTrue(change['_id']) + self.assertEqual(change['operationType'], 'update') + self.assertEqual(change['ns'], expected_ns) + self.assertNotIn('fullDocument', change) + self.assertEqual({'updatedFields': {'new': 1}, + 'removedFields': ['foo']}, + change['updateDescription']) + # Replace. + self.watched_collection().replace_one({'new': 1}, {'foo': 'bar'}) + change = change_stream.next() + self.assertTrue(change['_id']) + self.assertEqual(change['operationType'], 'replace') + self.assertEqual(change['ns'], expected_ns) + self.assertEqual(change['fullDocument'], inserted_doc) + # Delete. + self.watched_collection().delete_one({'foo': 'bar'}) + change = change_stream.next() + self.assertTrue(change['_id']) + self.assertEqual(change['operationType'], 'delete') + self.assertEqual(change['ns'], expected_ns) + self.assertNotIn('fullDocument', change) + # Invalidate. + self._test_get_invalidate_event(change_stream) + + @client_context.require_version_min(4, 1, 1) + def test_start_after(self): + resume_token = self.get_resume_token(invalidate=True) + + # resume_after cannot resume after invalidate. + with self.assertRaises(OperationFailure): + self.change_stream(resume_after=resume_token) + + # start_after can resume after invalidate. + with self.change_stream(start_after=resume_token) as change_stream: + self.watched_collection().insert_one({'_id': 2}) + change = change_stream.next() + self.assertEqual(change['operationType'], 'insert') + self.assertEqual(change['fullDocument'], {'_id': 2}) + + @client_context.require_version_min(4, 1, 1) + def test_start_after_resume_process_with_changes(self): + resume_token = self.get_resume_token(invalidate=True) + + with self.change_stream(start_after=resume_token, + max_await_time_ms=250) as change_stream: + self.watched_collection().insert_one({'_id': 2}) + change = change_stream.next() + self.assertEqual(change['operationType'], 'insert') + self.assertEqual(change['fullDocument'], {'_id': 2}) + + self.assertIsNone(change_stream.try_next()) + self.kill_change_stream_cursor(change_stream) + + self.watched_collection().insert_one({'_id': 3}) + change = change_stream.next() + self.assertEqual(change['operationType'], 'insert') + self.assertEqual(change['fullDocument'], {'_id': 3}) + + @client_context.require_no_mongos # Remove after SERVER-41196 + @client_context.require_version_min(4, 1, 1) + def test_start_after_resume_process_without_changes(self): + resume_token = self.get_resume_token(invalidate=True) + + with self.change_stream(start_after=resume_token, + max_await_time_ms=250) as change_stream: + self.assertIsNone(change_stream.try_next()) + self.kill_change_stream_cursor(change_stream) + + self.watched_collection().insert_one({'_id': 2}) + change = change_stream.next() + self.assertEqual(change['operationType'], 'insert') + self.assertEqual(change['fullDocument'], {'_id': 2}) + + +class ProseSpecTestsMixin(object): + def _client_with_listener(self, *commands): + listener = WhiteListEventListener(*commands) + client = rs_or_single_client(event_listeners=[listener]) + self.addCleanup(client.close) + return client, listener + + def _populate_and_exhaust_change_stream(self, change_stream, batch_size=3): + self.watched_collection().insert_many( + [{"data": k} for k in range(batch_size)]) + for _ in range(batch_size): + change = next(change_stream) + return change + + def _get_expected_resume_token_legacy(self, stream, + listener, previous_change=None): + """Predicts what the resume token should currently be for server + versions that don't support postBatchResumeToken. Assumes the stream + has never returned any changes if previous_change is None.""" + if previous_change is None: + agg_cmd = listener.results['started'][0] + stage = agg_cmd.command["pipeline"][0]["$changeStream"] + return stage.get("resumeAfter") or stage.get("startAfter") + + return previous_change['_id'] + + def _get_expected_resume_token(self, stream, listener, + previous_change=None): + """Predicts what the resume token should currently be for server + versions that support postBatchResumeToken. Assumes the stream has + never returned any changes if previous_change is None. Assumes + listener is a WhiteListEventListener that listens for aggregate and + getMore commands.""" + if previous_change is None or stream._cursor._has_next(): + return self._get_expected_resume_token_legacy( + stream, listener, previous_change) + + response = listener.results['succeeded'][-1].reply + return response['cursor']['postBatchResumeToken'] + + def _test_raises_error_on_missing_id(self, expected_exception): + """ChangeStream will raise an exception if the server response is + missing the resume token. + """ + with self.change_stream([{'$project': {'_id': 0}}]) as change_stream: + self.watched_collection().insert_one({}) + with self.assertRaises(expected_exception): + next(change_stream) + # The cursor should now be closed. + with self.assertRaises(StopIteration): + next(change_stream) + + def _test_update_resume_token(self, expected_rt_getter): + """ChangeStream must continuously track the last seen resumeToken.""" + client, listener = self._client_with_listener("aggregate", "getMore") + coll = self.watched_collection(write_concern=WriteConcern('majority')) + with self.change_stream_with_client(client) as change_stream: + self.assertEqual( + change_stream.resume_token, + expected_rt_getter(change_stream, listener)) + for _ in range(3): + coll.insert_one({}) + change = next(change_stream) + self.assertEqual( + change_stream.resume_token, + expected_rt_getter(change_stream, listener, change)) + + # Prose test no. 1 + @client_context.require_version_min(4, 0, 7) + def test_update_resume_token(self): + self._test_update_resume_token(self._get_expected_resume_token) + + # Prose test no. 1 + @client_context.require_version_max(4, 0, 7) + def test_update_resume_token_legacy(self): + self._test_update_resume_token(self._get_expected_resume_token_legacy) + + # Prose test no. 2 + @client_context.require_version_min(4, 1, 8) + def test_raises_error_on_missing_id_418plus(self): + # Server returns an error on 4.1.8+ + self._test_raises_error_on_missing_id(OperationFailure) + + # Prose test no. 2 + @client_context.require_version_max(4, 1, 8) + def test_raises_error_on_missing_id_418minus(self): + # PyMongo raises an error + self._test_raises_error_on_missing_id(InvalidOperation) + + # Prose test no. 3 + def test_resume_on_error(self): + with self.change_stream() as change_stream: + self.insert_one_and_check(change_stream, {'_id': 1}) + # Cause a cursor not found error on the next getMore. + self.kill_change_stream_cursor(change_stream) + self.insert_one_and_check(change_stream, {'_id': 2}) + + # Prose test no. 5 + def test_does_not_resume_fatal_errors(self): + """ChangeStream will not attempt to resume fatal server errors.""" + for code in _NON_RESUMABLE_GETMORE_ERRORS: + with self.change_stream() as change_stream: + self.watched_collection().insert_one({}) + + def mock_try_next(*args, **kwargs): + change_stream._cursor.close() + raise OperationFailure('Mock server error', code=code) + + original_try_next = change_stream._cursor._try_next + change_stream._cursor._try_next = mock_try_next + + with self.assertRaises(OperationFailure): + next(change_stream) + change_stream._cursor._try_next = original_try_next + with self.assertRaises(StopIteration): + next(change_stream) + + # Prose test no. 7 + def test_initial_empty_batch(self): + with self.change_stream() as change_stream: + # The first batch should be empty. + self.assertFalse(change_stream._cursor._has_next()) + cursor_id = change_stream._cursor.cursor_id + self.assertTrue(cursor_id) + self.insert_one_and_check(change_stream, {}) + # Make sure we're still using the same cursor. + self.assertEqual(cursor_id, change_stream._cursor.cursor_id) + + # Prose test no. 8 + def test_kill_cursors(self): + def raise_error(): + raise ServerSelectionTimeoutError('mock error') + with self.change_stream() as change_stream: + self.insert_one_and_check(change_stream, {'_id': 1}) + # Cause a cursor not found error on the next getMore. + cursor = change_stream._cursor + self.kill_change_stream_cursor(change_stream) + cursor.close = raise_error + self.insert_one_and_check(change_stream, {'_id': 2}) + + # Prose test no. 9 + @client_context.require_version_min(4, 0, 0) + @client_context.require_version_max(4, 0, 7) + def test_start_at_operation_time_caching(self): + # Case 1: change stream not started with startAtOperationTime + client, listener = self.client_with_listener("aggregate") + with self.change_stream_with_client(client) as cs: + self.kill_change_stream_cursor(cs) + cs.try_next() + cmd = listener.results['started'][-1].command + self.assertIsNotNone(cmd["pipeline"][0]["$changeStream"].get( + "startAtOperationTime")) + + # Case 2: change stream started with startAtOperationTime + listener.results.clear() + optime = self.get_start_at_operation_time() + with self.change_stream_with_client( + client, start_at_operation_time=optime) as cs: + self.kill_change_stream_cursor(cs) + cs.try_next() + cmd = listener.results['started'][-1].command + self.assertEqual(cmd["pipeline"][0]["$changeStream"].get( + "startAtOperationTime"), optime, str([k.command for k in + listener.results['started']])) + + # Prose test no. 11 + @client_context.require_version_min(4, 0, 7) + def test_resumetoken_empty_batch(self): + client, listener = self._client_with_listener("getMore") + with self.change_stream_with_client(client) as change_stream: + self.assertIsNone(change_stream.try_next()) + resume_token = change_stream.resume_token + + response = listener.results['succeeded'][0].reply + self.assertEqual(resume_token, + response["cursor"]["postBatchResumeToken"]) + + # Prose test no. 11 + @client_context.require_version_min(4, 0, 7) + def test_resumetoken_exhausted_batch(self): + client, listener = self._client_with_listener("getMore") + with self.change_stream_with_client(client) as change_stream: + self._populate_and_exhaust_change_stream(change_stream) + resume_token = change_stream.resume_token + + response = listener.results['succeeded'][-1].reply + self.assertEqual(resume_token, + response["cursor"]["postBatchResumeToken"]) + + # Prose test no. 12 + @client_context.require_version_max(4, 0, 7) + def test_resumetoken_empty_batch_legacy(self): + resume_point = self.get_resume_token() + + # Empty resume token when neither resumeAfter or startAfter specified. + with self.change_stream() as change_stream: + change_stream.try_next() + self.assertIsNone(change_stream.resume_token) + + # Resume token value is same as resumeAfter. + with self.change_stream(resume_after=resume_point) as change_stream: + change_stream.try_next() + resume_token = change_stream.resume_token + self.assertEqual(resume_token, resume_point) + + # Prose test no. 12 + @client_context.require_version_max(4, 0, 7) + def test_resumetoken_exhausted_batch_legacy(self): + # Resume token is _id of last change. + with self.change_stream() as change_stream: + change = self._populate_and_exhaust_change_stream(change_stream) + self.assertEqual(change_stream.resume_token, change["_id"]) + resume_point = change['_id'] + + # Resume token is _id of last change even if resumeAfter is specified. + with self.change_stream(resume_after=resume_point) as change_stream: + change = self._populate_and_exhaust_change_stream(change_stream) + self.assertEqual(change_stream.resume_token, change["_id"]) + + # Prose test no. 13 + def test_resumetoken_partially_iterated_batch(self): + # When batch has been iterated up to but not including the last element. + # Resume token should be _id of previous change document. + with self.change_stream() as change_stream: + self.watched_collection( + write_concern=WriteConcern('majority')).insert_many( + [{"data": k} for k in range(3)]) + for _ in range(2): + change = next(change_stream) + resume_token = change_stream.resume_token + + self.assertEqual(resume_token, change["_id"]) + + def _test_resumetoken_uniterated_nonempty_batch(self, resume_option): + # When the batch is not empty and hasn't been iterated at all. + # Resume token should be same as the resume option used. + resume_point = self.get_resume_token() + + # Insert some documents so that firstBatch isn't empty. + self.watched_collection( + write_concern=WriteConcern("majority")).insert_many( + [{'a': 1}, {'b': 2}, {'c': 3}]) + + # Resume token should be same as the resume option. + with self.change_stream( + **{resume_option: resume_point}) as change_stream: + self.assertTrue(change_stream._cursor._has_next()) + resume_token = change_stream.resume_token + self.assertEqual(resume_token, resume_point) + + # Prose test no. 14 + @client_context.require_no_mongos + def test_resumetoken_uniterated_nonempty_batch_resumeafter(self): + self._test_resumetoken_uniterated_nonempty_batch("resume_after") + + # Prose test no. 14 + @client_context.require_no_mongos + @client_context.require_version_min(4, 1, 1) + def test_resumetoken_uniterated_nonempty_batch_startafter(self): + self._test_resumetoken_uniterated_nonempty_batch("start_after") + + # Prose test no. 17 + @client_context.require_version_min(4, 1, 1) + def test_startafter_resume_uses_startafter_after_empty_getMore(self): + # Resume should use startAfter after no changes have been returned. + resume_point = self.get_resume_token() + + client, listener = self._client_with_listener("aggregate") + with self.change_stream_with_client( + client, start_after=resume_point) as change_stream: + self.assertFalse(change_stream._cursor._has_next()) # No changes + change_stream.try_next() # No changes + self.kill_change_stream_cursor(change_stream) + change_stream.try_next() # Resume attempt + + response = listener.results['started'][-1] + self.assertIsNone( + response.command["pipeline"][0]["$changeStream"].get("resumeAfter")) + self.assertIsNotNone( + response.command["pipeline"][0]["$changeStream"].get("startAfter")) + + # Prose test no. 18 + @client_context.require_version_min(4, 1, 1) + def test_startafter_resume_uses_resumeafter_after_nonempty_getMore(self): + # Resume should use resumeAfter after some changes have been returned. + resume_point = self.get_resume_token() + + client, listener = self._client_with_listener("aggregate") + with self.change_stream_with_client( + client, start_after=resume_point) as change_stream: + self.assertFalse(change_stream._cursor._has_next()) # No changes + self.watched_collection().insert_one({}) + next(change_stream) # Changes + self.kill_change_stream_cursor(change_stream) + change_stream.try_next() # Resume attempt + + response = listener.results['started'][-1] + self.assertIsNotNone( + response.command["pipeline"][0]["$changeStream"].get("resumeAfter")) + self.assertIsNone( + response.command["pipeline"][0]["$changeStream"].get("startAfter")) + + +class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin): @classmethod @client_context.require_version_min(4, 0, 0, -1) @client_context.require_no_mmap @@ -187,17 +759,18 @@ class TestClusterChangeStream(IntegrationTest, ChangeStreamTryNextMixin): def change_stream_with_client(self, client, *args, **kwargs): return client.watch(*args, **kwargs) - def watched_collection(self): - return self.db.test + def generate_invalidate_event(self, change_stream): + self.skipTest("cluster-level change streams cannot be invalidated") - def generate_unique_collnames(self, numcolls): - # Generate N collection names unique to a test. - collnames = [] - for idx in range(1, numcolls + 1): - collnames.append(self.id() + '_' + str(idx)) - return collnames + def _test_get_invalidate_event(self, change_stream): + # Cluster-level change streams don't get invalidated. + pass - def insert_and_check(self, change_stream, db, collname, doc): + def _test_invalidate_stops_iteration(self, change_stream): + # Cluster-level change streams don't get invalidated. + pass + + def _insert_and_check(self, change_stream, db, collname, doc): coll = db[collname] coll.insert_one(doc) change = next(change_stream) @@ -206,17 +779,34 @@ class TestClusterChangeStream(IntegrationTest, ChangeStreamTryNextMixin): 'coll': collname}) self.assertEqual(change['fullDocument'], doc) + def insert_one_and_check(self, change_stream, doc): + db = random.choice(self.dbs) + collname = self.id() + self._insert_and_check(change_stream, db, collname, doc) + def test_simple(self): collnames = self.generate_unique_collnames(3) with self.change_stream() as change_stream: for db, collname in product(self.dbs, collnames): - self.insert_and_check( + self._insert_and_check( change_stream, db, collname, {'_id': collname} ) + def test_aggregate_cursor_blocks(self): + """Test that an aggregate cursor blocks until a change is readable.""" + with self.client.admin.aggregate( + [{'$changeStream': {'allChangesForCluster': True}}], + maxAwaitTimeMS=250) as change_stream: + self._test_next_blocks(change_stream) -class TestDatabaseChangeStream(IntegrationTest, ChangeStreamTryNextMixin): + def test_full_pipeline(self): + """$changeStream must be the first stage in a change stream pipeline + sent to the server. + """ + self._test_full_pipeline({'allChangesForCluster': True}) + +class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin): @classmethod @client_context.require_version_min(4, 0, 0, -1) @client_context.require_no_mmap @@ -227,17 +817,54 @@ class TestDatabaseChangeStream(IntegrationTest, ChangeStreamTryNextMixin): def change_stream_with_client(self, client, *args, **kwargs): return client[self.db.name].watch(*args, **kwargs) - def watched_collection(self): - return self.db.test + def generate_invalidate_event(self, change_stream): + # Dropping the database invalidates the change stream. + change_stream._client.drop_database(self.db.name) - def generate_unique_collnames(self, numcolls): - # Generate N collection names unique to a test. - collnames = [] - for idx in range(1, numcolls + 1): - collnames.append(self.id() + '_' + str(idx)) - return collnames + def _test_get_invalidate_event(self, change_stream): + # Cache collection names. + dropped_colls = self.db.list_collection_names() + # Drop the watched database to get an invalidate event. + self.generate_invalidate_event(change_stream) + change = change_stream.next() + # 4.1+ returns "drop" events for each collection in dropped database + # and a "dropDatabase" event for the database itself. + if change['operationType'] == 'drop': + self.assertTrue(change['_id']) + for _ in range(len(dropped_colls)): + ns = change['ns'] + self.assertEqual(ns['db'], change_stream._target.name) + self.assertIn(ns['coll'], dropped_colls) + change = change_stream.next() + self.assertEqual(change['operationType'], 'dropDatabase') + self.assertTrue(change['_id']) + self.assertEqual(change['ns'], {'db': change_stream._target.name}) + # Get next change. + change = change_stream.next() + self.assertTrue(change['_id']) + self.assertEqual(change['operationType'], 'invalidate') + self.assertNotIn('ns', change) + self.assertNotIn('fullDocument', change) + # The ChangeStream should be dead. + with self.assertRaises(StopIteration): + change_stream.next() - def insert_and_check(self, change_stream, collname, doc): + def _test_invalidate_stops_iteration(self, change_stream): + # Drop the watched database to get an invalidate event. + change_stream._client.drop_database(self.db.name) + # Check drop and dropDatabase events. + for change in change_stream: + self.assertIn(change['operationType'], ( + 'drop', 'dropDatabase', 'invalidate')) + # Last change must be invalidate. + self.assertEqual(change['operationType'], 'invalidate') + # Change stream must not allow further iteration. + with self.assertRaises(StopIteration): + change_stream.next() + with self.assertRaises(StopIteration): + next(change_stream) + + def _insert_and_check(self, change_stream, collname, doc): coll = self.db[collname] coll.insert_one(doc) change = next(change_stream) @@ -246,357 +873,121 @@ class TestDatabaseChangeStream(IntegrationTest, ChangeStreamTryNextMixin): 'coll': collname}) self.assertEqual(change['fullDocument'], doc) + def insert_one_and_check(self, change_stream, doc): + self._insert_and_check(change_stream, self.id(), doc) + def test_simple(self): collnames = self.generate_unique_collnames(3) with self.change_stream() as change_stream: for collname in collnames: - self.insert_and_check( - change_stream, collname, {'_id': uuid.uuid4()} - ) + self._insert_and_check( + change_stream, collname, {'_id': uuid.uuid4()}) def test_isolation(self): # Ensure inserts to other dbs don't show up in our ChangeStream. other_db = self.client.pymongo_test_temp self.assertNotEqual( - other_db, self.db, msg="Isolation must be tested on separate DBs" - ) + other_db, self.db, msg="Isolation must be tested on separate DBs") collname = self.id() with self.change_stream() as change_stream: other_db[collname].insert_one({'_id': uuid.uuid4()}) - self.insert_and_check( - change_stream, collname, {'_id': uuid.uuid4()} - ) + self._insert_and_check( + change_stream, collname, {'_id': uuid.uuid4()}) self.client.drop_database(other_db) -class TestCollectionChangeStream(IntegrationTest, ChangeStreamTryNextMixin): - +class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, + ProseSpecTestsMixin): @classmethod @client_context.require_version_min(3, 5, 11) @client_context.require_no_mmap @client_context.require_no_standalone def setUpClass(cls): super(TestCollectionChangeStream, cls).setUpClass() - cls.coll = cls.db.change_stream_test - # SERVER-31885 On a mongos the database must exist in order to create - # a changeStream cursor. However, WiredTiger drops the database when - # there are no more collections. Let's prevent that. - cls.db.prevent_implicit_database_deletion.insert_one({}) - - @classmethod - def tearDownClass(cls): - cls.db.prevent_implicit_database_deletion.drop() - super(TestCollectionChangeStream, cls).tearDownClass() def setUp(self): # Use a new collection for each test. - self.coll = self.db[self.id()] - - def tearDown(self): - self.coll.drop() + self.watched_collection().drop() + self.watched_collection().insert_one({}) def change_stream_with_client(self, client, *args, **kwargs): - return client[self.db.name].test.watch(*args, **kwargs) + return client[self.db.name].get_collection( + self.watched_collection().name).watch(*args, **kwargs) - def watched_collection(self): - return self.db.test + def generate_invalidate_event(self, change_stream): + # Dropping the collection invalidates the change stream. + change_stream._target.drop() - def insert_and_check(self, change_stream, doc): - self.coll.insert_one(doc) + def _test_invalidate_stops_iteration(self, change_stream): + self.generate_invalidate_event(change_stream) + # Check drop and dropDatabase events. + for change in change_stream: + self.assertIn(change['operationType'], ('drop', 'invalidate')) + # Last change must be invalidate. + self.assertEqual(change['operationType'], 'invalidate') + # Change stream must not allow further iteration. + with self.assertRaises(StopIteration): + change_stream.next() + with self.assertRaises(StopIteration): + next(change_stream) + + def _test_get_invalidate_event(self, change_stream): + # Drop the watched database to get an invalidate event. + change_stream._target.drop() + change = change_stream.next() + # 4.1+ returns a "drop" change document. + if change['operationType'] == 'drop': + self.assertTrue(change['_id']) + self.assertEqual(change['ns'], { + 'db': change_stream._target.database.name, + 'coll': change_stream._target.name}) + # Last change should be invalidate. + change = change_stream.next() + self.assertTrue(change['_id']) + self.assertEqual(change['operationType'], 'invalidate') + self.assertNotIn('ns', change) + self.assertNotIn('fullDocument', change) + # The ChangeStream should be dead. + with self.assertRaises(StopIteration): + change_stream.next() + + def insert_one_and_check(self, change_stream, doc): + self.watched_collection().insert_one(doc) change = next(change_stream) self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['ns'], {'db': self.coll.database.name, - 'coll': self.coll.name}) + self.assertEqual( + change['ns'], {'db': self.watched_collection().database.name, + 'coll': self.watched_collection().name}) self.assertEqual(change['fullDocument'], doc) - def test_watch(self): - with self.coll.watch( - [{'$project': {'foo': 0}}], full_document='updateLookup', - max_await_time_ms=1000, batch_size=100) as change_stream: - self.assertEqual([{'$project': {'foo': 0}}], - change_stream._pipeline) - self.assertEqual('updateLookup', change_stream._full_document) - self.assertIsNone(change_stream._resume_token) - self.assertEqual(1000, change_stream._max_await_time_ms) - self.assertEqual(100, change_stream._batch_size) - self.assertIsInstance(change_stream._cursor, CommandCursor) - self.assertEqual( - 1000, change_stream._cursor._CommandCursor__max_await_time_ms) - self.coll.insert_one({}) - change = change_stream.next() - resume_token = change['_id'] - with self.assertRaises(TypeError): - self.coll.watch(pipeline={}) - with self.assertRaises(TypeError): - self.coll.watch(full_document={}) - # No Error. - with self.coll.watch(resume_after=resume_token): - pass - - def test_full_pipeline(self): - """$changeStream must be the first stage in a change stream pipeline - sent to the server. - """ - listener = WhiteListEventListener("aggregate") - results = listener.results - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) - coll = client[self.db.name][self.coll.name] - with coll.watch([{'$project': {'foo': 0}}]) as _: - pass - - self.assertEqual(1, len(results['started'])) - command = results['started'][0] - self.assertEqual('aggregate', command.command_name) - self.assertEqual([ - {'$changeStream': {}}, - {'$project': {'foo': 0}}], - command.command['pipeline']) - - def test_iteration(self): - with self.coll.watch(batch_size=2) as change_stream: - num_inserted = 10 - self.coll.insert_many([{} for _ in range(num_inserted)]) - self.coll.drop() - inserts_received = 0 - for change in change_stream: - if change['operationType'] not in ('drop', 'invalidate'): - self.assertEqual(change['operationType'], 'insert') - inserts_received += 1 - self.assertEqual(num_inserted, inserts_received) - # Last change should be invalidate. - self.assertEqual(change['operationType'], 'invalidate') - with self.assertRaises(StopIteration): - change_stream.next() - with self.assertRaises(StopIteration): - next(change_stream) - - def _test_next_blocks(self, change_stream): - inserted_doc = {'_id': ObjectId()} - changes = [] - t = threading.Thread( - target=lambda: changes.append(change_stream.next())) - t.start() - # Sleep for a bit to prove that the call to next() blocks. - time.sleep(1) - self.assertTrue(t.is_alive()) - self.assertFalse(changes) - self.coll.insert_one(inserted_doc) - # Join with large timeout to give the server time to return the change, - # in particular for shard clusters. - t.join(30) - self.assertFalse(t.is_alive()) - self.assertEqual(1, len(changes)) - self.assertEqual(changes[0]['operationType'], 'insert') - self.assertEqual(changes[0]['fullDocument'], inserted_doc) - - def test_next_blocks(self): - """Test that next blocks until a change is readable""" - # Use a short await time to speed up the test. - with self.coll.watch(max_await_time_ms=250) as change_stream: - self._test_next_blocks(change_stream) - - def test_aggregate_cursor_blocks(self): - """Test that an aggregate cursor blocks until a change is readable.""" - with self.coll.aggregate([{'$changeStream': {}}], - maxAwaitTimeMS=250) as change_stream: - self._test_next_blocks(change_stream) - - def test_concurrent_close(self): - """Ensure a ChangeStream can be closed from another thread.""" - # Use a short await time to speed up the test. - with self.coll.watch(max_await_time_ms=250) as change_stream: - def iterate_cursor(): - for _ in change_stream: - pass - t = threading.Thread(target=iterate_cursor) - t.start() - self.coll.insert_one({}) - time.sleep(1) - change_stream.close() - t.join(3) - self.assertFalse(t.is_alive()) - - def test_update_resume_token(self): - """ChangeStream must continuously track the last seen resumeToken.""" - with self.coll.watch() as change_stream: - self.assertIsNone(change_stream._resume_token) - for _ in range(3): - self.coll.insert_one({}) - change = next(change_stream) - self.assertEqual(change['_id'], change_stream._resume_token) - - def _test_raises_error_on_missing_id(self, expected_exception): - """ChangeStream will raise an exception if the server response is - missing the resume token. - """ - with self.coll.watch([{'$project': {'_id': 0}}]) as change_stream: - self.coll.insert_one({}) - with self.assertRaises(expected_exception): - next(change_stream) - # The cursor should now be closed. - with self.assertRaises(StopIteration): - next(change_stream) - - @client_context.require_version_min(4, 1, 8) - def test_raises_error_on_missing_id_418plus(self): - # Server returns an error on 4.1.8+ - self._test_raises_error_on_missing_id(OperationFailure) - - @client_context.require_version_max(4, 1, 8) - def test_raises_error_on_missing_id_418minus(self): - # PyMongo raises an error - self._test_raises_error_on_missing_id(InvalidOperation) - - def test_resume_on_error(self): - """ChangeStream will automatically resume one time on a resumable - error (including not master) with the initial pipeline and options, - except for the addition/update of a resumeToken. - """ - with self.coll.watch([]) as change_stream: - self.insert_and_check(change_stream, {'_id': 1}) - # Cause a cursor not found error on the next getMore. - self.kill_change_stream_cursor(change_stream) - self.insert_and_check(change_stream, {'_id': 2}) - - def test_does_not_resume_fatal_errors(self): - """ChangeStream will not attempt to resume fatal server errors.""" - for code in _NON_RESUMABLE_GETMORE_ERRORS: - with self.coll.watch() as change_stream: - self.coll.insert_one({}) - - def mock_try_next(*args, **kwargs): - change_stream._cursor.close() - raise OperationFailure('Mock server error', code=code) - - original_try_next = change_stream._cursor._try_next - change_stream._cursor._try_next = mock_try_next - - with self.assertRaises(OperationFailure): - next(change_stream) - change_stream._cursor._try_next = original_try_next - with self.assertRaises(StopIteration): - next(change_stream) - - def test_initial_empty_batch(self): - """Ensure that a cursor returned from an aggregate command with a - cursor id, and an initial empty batch, is not closed on the driver - side. - """ - with self.coll.watch() as change_stream: - # The first batch should be empty. - self.assertEqual( - 0, len(change_stream._cursor._CommandCursor__data)) - cursor_id = change_stream._cursor.cursor_id - self.assertTrue(cursor_id) - self.insert_and_check(change_stream, {}) - # Make sure we're still using the same cursor. - self.assertEqual(cursor_id, change_stream._cursor.cursor_id) - - def test_kill_cursors(self): - """The killCursors command sent during the resume process must not be - allowed to raise an exception. - """ - def raise_error(): - raise ServerSelectionTimeoutError('mock error') - with self.coll.watch([]) as change_stream: - self.insert_and_check(change_stream, {'_id': 1}) - # Cause a cursor not found error on the next getMore. - cursor = change_stream._cursor - self.kill_change_stream_cursor(change_stream) - cursor.close = raise_error - self.insert_and_check(change_stream, {'_id': 2}) - - def test_unknown_full_document(self): - """Must rely on the server to raise an error on unknown fullDocument. - """ - try: - with self.coll.watch(full_document='notValidatedByPyMongo'): - pass - except OperationFailure: - pass - - def test_change_operations(self): - """Test each operation type.""" - expected_ns = {'db': self.coll.database.name, 'coll': self.coll.name} - with self.coll.watch() as change_stream: - # Insert. - inserted_doc = {'_id': ObjectId(), 'foo': 'bar'} - self.coll.insert_one(inserted_doc) - change = change_stream.next() - self.assertTrue(change['_id']) - self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['ns'], expected_ns) - self.assertEqual(change['fullDocument'], inserted_doc) - # Update. - update_spec = {'$set': {'new': 1}, '$unset': {'foo': 1}} - self.coll.update_one(inserted_doc, update_spec) - change = change_stream.next() - self.assertTrue(change['_id']) - self.assertEqual(change['operationType'], 'update') - self.assertEqual(change['ns'], expected_ns) - self.assertNotIn('fullDocument', change) - self.assertEqual({'updatedFields': {'new': 1}, - 'removedFields': ['foo']}, - change['updateDescription']) - # Replace. - self.coll.replace_one({'new': 1}, {'foo': 'bar'}) - change = change_stream.next() - self.assertTrue(change['_id']) - self.assertEqual(change['operationType'], 'replace') - self.assertEqual(change['ns'], expected_ns) - self.assertEqual(change['fullDocument'], inserted_doc) - # Delete. - self.coll.delete_one({'foo': 'bar'}) - change = change_stream.next() - self.assertTrue(change['_id']) - self.assertEqual(change['operationType'], 'delete') - self.assertEqual(change['ns'], expected_ns) - self.assertNotIn('fullDocument', change) - # Invalidate. - self.coll.drop() - change = change_stream.next() - # 4.1 returns a "drop" change document. - if change['operationType'] == 'drop': - self.assertTrue(change['_id']) - self.assertEqual(change['ns'], expected_ns) - # Last change should be invalidate. - change = change_stream.next() - self.assertTrue(change['_id']) - self.assertEqual(change['operationType'], 'invalidate') - self.assertNotIn('ns', change) - self.assertNotIn('fullDocument', change) - # The ChangeStream should be dead. - with self.assertRaises(StopIteration): - change_stream.next() - def test_raw(self): """Test with RawBSONDocument.""" - raw_coll = self.coll.with_options( + raw_coll = self.watched_collection( codec_options=DEFAULT_RAW_BSON_OPTIONS) with raw_coll.watch() as change_stream: raw_doc = RawBSONDocument(BSON.encode({'_id': 1})) - self.coll.insert_one(raw_doc) + self.watched_collection().insert_one(raw_doc) change = next(change_stream) self.assertIsInstance(change, RawBSONDocument) self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['ns']['db'], self.coll.database.name) - self.assertEqual(change['ns']['coll'], self.coll.name) + self.assertEqual( + change['ns']['db'], self.watched_collection().database.name) + self.assertEqual( + change['ns']['coll'], self.watched_collection().name) self.assertEqual(change['fullDocument'], raw_doc) - self.assertEqual(change['_id'], change_stream._resume_token) def test_uuid_representations(self): """Test with uuid document _ids and different uuid_representation.""" for uuid_representation in ALL_UUID_REPRESENTATIONS: for id_subtype in (STANDARD, PYTHON_LEGACY): - resume_token = None - options = self.coll.codec_options.with_options( + options = self.watched_collection().codec_options.with_options( uuid_representation=uuid_representation) - coll = self.coll.with_options(codec_options=options) + coll = self.watched_collection(codec_options=options) with coll.watch() as change_stream: coll.insert_one( {'_id': Binary(uuid.uuid4().bytes, id_subtype)}) - resume_token = change_stream.next()['_id'] + _ = change_stream.next() + resume_token = change_stream.resume_token # Should not error. coll.watch(resume_after=resume_token) @@ -607,12 +998,13 @@ class TestCollectionChangeStream(IntegrationTest, ChangeStreamTryNextMixin): len(string.ascii_letters)) random_doc = {'_id': SON([(key, key) for key in random_keys])} for document_class in (dict, SON, RawBSONDocument): - options = self.coll.codec_options.with_options( + options = self.watched_collection().codec_options.with_options( document_class=document_class) - coll = self.coll.with_options(codec_options=options) + coll = self.watched_collection(codec_options=options) with coll.watch() as change_stream: coll.insert_one(random_doc) - resume_token = change_stream.next()['_id'] + _ = change_stream.next() + resume_token = change_stream.resume_token # The resume token is always a document. self.assertIsInstance(resume_token, document_class) @@ -623,73 +1015,15 @@ class TestCollectionChangeStream(IntegrationTest, ChangeStreamTryNextMixin): def test_read_concern(self): """Test readConcern is not validated by the driver.""" # Read concern 'local' is not allowed for $changeStream. - coll = self.coll.with_options(read_concern=ReadConcern('local')) + coll = self.watched_collection(read_concern=ReadConcern('local')) with self.assertRaises(OperationFailure): coll.watch() # Does not error. - coll = self.coll.with_options(read_concern=ReadConcern('majority')) + coll = self.watched_collection(read_concern=ReadConcern('majority')) with coll.watch(): pass - def invalidate_resume_token(self): - with self.coll.watch( - [{'$match': {'operationType': 'invalidate'}}]) as cs: - self.coll.insert_one({'_id': 1}) - self.coll.drop() - resume_token = cs.next()['_id'] - self.assertFalse(cs.alive) - return resume_token - - @client_context.require_version_min(4, 1, 1) - def test_start_after(self): - resume_token = self.invalidate_resume_token() - - # resume_after cannot resume after invalidate. - with self.assertRaises(OperationFailure): - self.coll.watch(resume_after=resume_token) - - # start_after can resume after invalidate. - with self.coll.watch(start_after=resume_token) as change_stream: - self.coll.insert_one({'_id': 2}) - change = change_stream.next() - self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['fullDocument'], {'_id': 2}) - - @client_context.require_version_min(4, 1, 1) - def test_start_after_resume_process_with_changes(self): - resume_token = self.invalidate_resume_token() - - with self.coll.watch(start_after=resume_token, - max_await_time_ms=250) as change_stream: - self.coll.insert_one({'_id': 2}) - change = change_stream.next() - self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['fullDocument'], {'_id': 2}) - - self.assertIsNone(change_stream.try_next()) - self.kill_change_stream_cursor(change_stream) - - self.coll.insert_one({'_id': 3}) - change = change_stream.next() - self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['fullDocument'], {'_id': 3}) - - @client_context.require_no_mongos # Remove after SERVER-41196 - @client_context.require_version_min(4, 1, 1) - def test_start_after_resume_process_without_changes(self): - resume_token = self.invalidate_resume_token() - - with self.coll.watch(start_after=resume_token, - max_await_time_ms=250) as change_stream: - self.assertIsNone(change_stream.try_next()) - self.kill_change_stream_cursor(change_stream) - - self.coll.insert_one({'_id': 2}) - change = change_stream.next() - self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['fullDocument'], {'_id': 2}) - class TestAllScenarios(unittest.TestCase): @@ -701,7 +1035,7 @@ class TestAllScenarios(unittest.TestCase): @classmethod def tearDownClass(cls): - cls.client + cls.client.close() def setUp(self): self.listener.results.clear()