PYTHON-4763 Migrate test_change_stream.py to async (#1853)

This commit is contained in:
Iris 2024-09-16 10:20:34 -07:00 committed by GitHub
parent 9b9cf73368
commit 0c0633da23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 1274 additions and 21 deletions

File diff suppressed because it is too large Load Diff

View File

@ -15,6 +15,7 @@
"""Test the change_stream module."""
from __future__ import annotations
import asyncio
import os
import random
import string
@ -48,8 +49,11 @@ from pymongo.errors import (
from pymongo.message import _CursorAddress
from pymongo.read_concern import ReadConcern
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.helpers import next
from pymongo.write_concern import WriteConcern
_IS_SYNC = True
class TestChangeStreamBase(IntegrationTest):
RUN_ON_LOAD_BALANCER = True
@ -97,17 +101,17 @@ class TestChangeStreamBase(IntegrationTest):
if isinstance(cs._target, MongoClient):
self.skipTest("cluster-level change streams cannot be invalidated")
self.generate_invalidate_event(cs)
return cs.next()["_id"]
return (cs.next())["_id"]
else:
with self.change_stream() as cs:
coll.insert_one({"data": 1})
return cs.next()["_id"]
return (cs.next())["_id"]
def get_start_at_operation_time(self):
"""Get an operationTime. Advances the operation clock beyond the most
recently returned timestamp.
"""
optime = self.client.admin.command("ping")["operationTime"]
optime = (self.client.admin.command("ping"))["operationTime"]
return Timestamp(optime.time, optime.inc + 1)
def insert_one_and_check(self, change_stream, doc):
@ -158,10 +162,14 @@ class APITestsMixin:
with self.change_stream(max_await_time_ms=250) as stream:
self.assertIsNone(stream.try_next()) # No changes initially.
coll.insert_one({}) # Generate a change.
# On sharded clusters, even majority-committed changes only show
# up once an event that sorts after it shows up on the other
# shard. So, we wait on try_next to eventually return changes.
wait_until(lambda: stream.try_next() is not None, "get change from try_next")
def _wait_until():
return stream.try_next() is not None
wait_until(_wait_until, "get change from try_next")
@no_type_check
def test_try_next_runs_one_getmore(self):
@ -192,7 +200,11 @@ class APITestsMixin:
# Get at least one change before resuming.
coll.insert_one({"_id": 2})
wait_until(lambda: stream.try_next() is not None, "get change from try_next")
def _wait_until():
return stream.try_next() is not None
wait_until(_wait_until, "get change from try_next")
listener.reset()
# Cause the next request to initiate the resume process.
@ -209,7 +221,11 @@ class APITestsMixin:
# Stream still works after a resume.
coll.insert_one({"_id": 3})
wait_until(lambda: stream.try_next() is not None, "get change from try_next")
def _wait_until():
return stream.try_next() is not None
wait_until(_wait_until, "get change from try_next")
self.assertEqual(set(listener.started_command_names()), {"getMore"})
self.assertIsNone(stream.try_next())
@ -289,6 +305,7 @@ class APITestsMixin:
self._test_invalidate_stops_iteration(change_stream)
@no_type_check
@client_context.require_sync
def _test_next_blocks(self, change_stream):
inserted_doc = {"_id": ObjectId()}
changes = []
@ -308,13 +325,15 @@ class APITestsMixin:
self.assertEqual(changes[0]["fullDocument"], inserted_doc)
@no_type_check
@client_context.require_sync
def test_next_blocks(self):
"""Test that next blocks until a change is readable"""
# Use a short await time to speed up the test.
# Use a short wait time to speed up the test.
with self.change_stream(max_await_time_ms=250) as change_stream:
self._test_next_blocks(change_stream)
@no_type_check
@client_context.require_sync
def test_aggregate_cursor_blocks(self):
"""Test that an aggregate cursor blocks until a change is readable."""
with self.watched_collection().aggregate(
@ -323,9 +342,10 @@ class APITestsMixin:
self._test_next_blocks(change_stream)
@no_type_check
@client_context.require_sync
def test_concurrent_close(self):
"""Ensure a ChangeStream can be closed from another thread."""
# Use a short await time to speed up the test.
# Use a short wait time to speed up the test.
with self.change_stream(max_await_time_ms=250) as change_stream:
def iterate_cursor():
@ -798,15 +818,15 @@ class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
@classmethod
@client_context.require_version_min(4, 0, 0, -1)
@client_context.require_change_streams
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
cls.dbs = [cls.db, cls.client.pymongo_test_2]
@classmethod
def tearDownClass(cls):
def _tearDown_class(cls):
for db in cls.dbs:
cls.client.drop_database(db)
super().tearDownClass()
super()._tearDown_class()
def change_stream_with_client(self, client, *args, **kwargs):
return client.watch(*args, **kwargs)
@ -841,6 +861,7 @@ class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
for db, collname in product(self.dbs, collnames):
self._insert_and_check(change_stream, db, collname, {"_id": collname})
@client_context.require_sync
def test_aggregate_cursor_blocks(self):
"""Test that an aggregate cursor blocks until a change is readable."""
with self.client.admin.aggregate(
@ -859,8 +880,8 @@ class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
@classmethod
@client_context.require_version_min(4, 0, 0, -1)
@client_context.require_change_streams
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
def change_stream_with_client(self, client, *args, **kwargs):
return client[self.db.name].watch(*args, **kwargs)
@ -944,8 +965,8 @@ class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin):
@classmethod
@client_context.require_change_streams
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
def setUp(self):
# Use a new collection for each test.
@ -1023,7 +1044,7 @@ class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecT
@client_context.require_version_min(4, 0) # Needed for start_at_operation_time.
def test_uuid_representations(self):
"""Test with uuid document _ids and different uuid_representation."""
optime = self.db.command("ping")["operationTime"]
optime = (self.db.command("ping"))["operationTime"]
self.watched_collection().insert_many(
[
{"_id": Binary(uuid.uuid4().bytes, id_subtype)}
@ -1087,15 +1108,15 @@ class TestAllLegacyScenarios(IntegrationTest):
@classmethod
@client_context.require_connection
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
cls.listener = AllowListEventListener("aggregate", "getMore")
cls.client = rs_or_single_client(event_listeners=[cls.listener])
@classmethod
def tearDownClass(cls):
def _tearDown_class(cls):
cls.client.close()
super().tearDownClass()
super()._tearDown_class()
def setUp(self):
super().setUp()

View File

@ -164,6 +164,7 @@ converted_tests = [
"test_auth.py",
"test_auth_spec.py",
"test_bulk.py",
"test_change_stream.py",
"test_client.py",
"test_client_bulk_write.py",
"test_collection.py",