From 2b5e2d08a24b76812573535e983c6f8bc5f80448 Mon Sep 17 00:00:00 2001
From: Rohan Pota
Date: Wed, 13 May 2026 16:17:41 -0400
Subject: [PATCH] feat: fix linting errors and modified code to fit server output

---
 poc_asp.py                                  | 153 ++++++++++++++++++++
 pymongo/__init__.py                         |  35 ++---
 pymongo/asynchronous/stream_processing.py   |  50 +++----
 pymongo/stream_processing_options.py        |  15 +-
 pymongo/synchronous/stream_processing.py    |  52 +++----
 test/asynchronous/test_stream_processing.py |  46 +++---
 test/test_stream_processing.py              |  64 ++++----
 7 files changed, 273 insertions(+), 142 deletions(-)
 create mode 100644 poc_asp.py

diff --git a/poc_asp.py b/poc_asp.py
new file mode 100644
index 000000000..bd0d2d87f
--- /dev/null
+++ b/poc_asp.py
@@ -0,0 +1,153 @@
+"""
+POC: Atlas Stream Processing — create / start / stop / drop a stream processor.
+
+Fill in the FILL_ME values below before running:
+    python3 poc_asp.py
+
+Pipeline used:
+    $source → sample_stream_solar
+    $emit → __testLog
+"""
+from __future__ import annotations
+
+import asyncio
+import pprint
+
+from pymongo import AsyncStreamProcessingClient
+from pymongo.errors import OperationFailure
+
+# ---------------------------------------------------------------------------
+# Configuration — fill these in locally before running; never commit real values
+# ---------------------------------------------------------------------------
+
+# Workspace connection string from Atlas UI:
+#   Stream Processing → your workspace → Connect
+# Format: mongodb://<host>/ (no mongodb+srv://)
+WORKSPACE_URI = "FILL_ME"  # e.g. "mongodb://atlas-stream-<id>.<region>.a.query.mongodb.net/"
+
+# Atlas DB user credentials (must have atlasAdmin role on the workspace project)
+USERNAME = "FILL_ME"
+PASSWORD = "FILL_ME"
+
+# ---------------------------------------------------------------------------
+# Pipeline — hardcoded per your setup
+# ---------------------------------------------------------------------------
+
+PROCESSOR_NAME = "simpletestSP"
+
+PIPELINE = [
+    {
+        "$source": {
+            "connectionName": "sample_stream_solar",
+        }
+    },
+    {
+        "$emit": {
+            "connectionName": "__testLog",
+        }
+    },
+]
+
+# ---------------------------------------------------------------------------
+# POC steps
+# ---------------------------------------------------------------------------
+
+async def main() -> None:
+    if "FILL_ME" in (WORKSPACE_URI, USERNAME, PASSWORD):
+        raise SystemExit("Fill in WORKSPACE_URI, USERNAME, and PASSWORD at the top of this file.")
+
+    async with AsyncStreamProcessingClient(WORKSPACE_URI, username=USERNAME, password=PASSWORD) as client:
+        sps = client.stream_processors()
+
+        # ------------------------------------------------------------------
+        # 1. Create
+        # ------------------------------------------------------------------
+        print(f"\n[1] Creating processor '{PROCESSOR_NAME}' ...")
+        try:
+            await sps.create(PROCESSOR_NAME, pipeline=PIPELINE)
+            print(" Created OK")
+        except OperationFailure as e:
+            raise SystemExit(f" Create failed (code {e.code}): {e}") from e
+
+        # ------------------------------------------------------------------
+        # 2. Inspect before starting
+        # ------------------------------------------------------------------
+        print("\n[2] Getting info ...")
+        info = await sps.get_info(PROCESSOR_NAME)
+        print(f" state : {info.state}")
+        print(f" pipeline_version : {info.pipeline_version}")
+        print(f" has_started : {info.has_started}")
+
+        # ------------------------------------------------------------------
+        # 3. Start
+        # ------------------------------------------------------------------
+        proc = sps.get(PROCESSOR_NAME)
+        print("\n[3] Starting processor ...")
+        try:
+            await proc.start()
+            print(" Start command sent OK")
+        except OperationFailure as e:
+            raise SystemExit(f" Start failed (code {e.code}): {e}") from e
+
+        # Give the server a moment to transition state
+        await asyncio.sleep(2)
+
+        info = await sps.get_info(PROCESSOR_NAME)
+        print(f" state after start: {info.state}")
+
+        # ------------------------------------------------------------------
+        # 4. Stats
+        # ------------------------------------------------------------------
+        print("\n[4] Fetching stats ...")
+        try:
+            raw_stats = await proc.stats()
+            pprint.pprint(raw_stats)
+        except OperationFailure as e:
+            print(f" Stats unavailable (code {e.code}): {e}")
+
+        # ------------------------------------------------------------------
+        # 5. Sample (up to 5 docs)
+        # Note: breaking manually after N docs because the dev server does not
+        # signal cursor exhaustion with cursorId=0 as the spec requires.
+        # ------------------------------------------------------------------
+        print("\n[5] Sampling up to 5 documents ...")
+        try:
+            count = 0
+            async for doc in proc.sample():
+                print(f" doc: {doc}")
+                count += 1
+                if count >= 5:
+                    break
+            print(f" Sampled {count} document(s)")
+        except OperationFailure as e:
+            print(f" Sample unavailable (code {e.code}): {e}")
+
+        # ------------------------------------------------------------------
+        # 6. Stop
+        # ------------------------------------------------------------------
+        print("\n[6] Stopping processor ...")
+        try:
+            await proc.stop()
+            print(" Stop command sent OK")
+        except OperationFailure as e:
+            raise SystemExit(f" Stop failed (code {e.code}): {e}") from e
+
+        await asyncio.sleep(1)
+        info = await sps.get_info(PROCESSOR_NAME)
+        print(f" state after stop : {info.state}")
+
+        # ------------------------------------------------------------------
+        # 7. Drop (permanent — comment out to keep the processor alive)
+        # ------------------------------------------------------------------
+        print("\n[7] Dropping processor ...")
+        try:
+            await proc.drop()
+            print(" Dropped OK")
+        except OperationFailure as e:
+            raise SystemExit(f" Drop failed (code {e.code}): {e}") from e
+
+    print("\nDone.")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/pymongo/__init__.py b/pymongo/__init__.py
index b988338c6..e56fed34a 100644
--- a/pymongo/__init__.py
+++ b/pymongo/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """Python driver for MongoDB."""
+
 from __future__ import annotations
 
 from typing import ContextManager, Optional
@@ -104,6 +105,14 @@ TEXT = "text"
 from pymongo import _csot
 from pymongo._version import __version__, get_version_string, version_tuple
 from pymongo.asynchronous.mongo_client import AsyncMongoClient
+
+# Atlas Stream Processing (experimental)
+from pymongo.asynchronous.stream_processing import (
+    AsyncSampleCursor,
+    AsyncStreamProcessingClient,
+    AsyncStreamProcessor,
+    AsyncStreamProcessors,
+)
 from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION, has_c
 from pymongo.cursor import CursorType
 from pymongo.operations import (
@@ -116,23 +125,6 @@ from pymongo.operations import (
     UpdateOne,
 )
 from pymongo.read_preferences import ReadPreference
-from pymongo.synchronous.collection import ReturnDocument
-from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.write_concern import WriteConcern
-
-# Atlas Stream Processing (experimental)
-from pymongo.asynchronous.stream_processing import (
-    AsyncSampleCursor,
-    AsyncStreamProcessingClient,
-    AsyncStreamProcessor,
-    AsyncStreamProcessors,
-)
-from pymongo.synchronous.stream_processing import (
-    SampleCursor,
-    StreamProcessingClient,
-    StreamProcessor,
-    StreamProcessors,
-)
 from pymongo.stream_processing_options import (
     CreateStreamProcessorOptions,
     GetStreamProcessorSamplesOptions,
@@ 
-141,6 +133,15 @@ from pymongo.stream_processing_options import ( StartStreamProcessorOptions, StreamProcessorInfo, ) +from pymongo.synchronous.collection import ReturnDocument +from pymongo.synchronous.mongo_client import MongoClient +from pymongo.synchronous.stream_processing import ( + SampleCursor, + StreamProcessingClient, + StreamProcessor, + StreamProcessors, +) +from pymongo.write_concern import WriteConcern # Public module compatibility imports # isort: off diff --git a/pymongo/asynchronous/stream_processing.py b/pymongo/asynchronous/stream_processing.py index ce207e0f2..ecf85d38a 100644 --- a/pymongo/asynchronous/stream_processing.py +++ b/pymongo/asynchronous/stream_processing.py @@ -51,6 +51,7 @@ do not maintain a closed list of valid codes — applications should branch on ``exc.code`` only when they need to react to a specific known code, and should always have a generic fallback for unknown codes. """ + from __future__ import annotations from typing import TYPE_CHECKING, Any, Mapping, Optional @@ -146,7 +147,7 @@ class AsyncStreamProcessingClient: if not uri_has_auth_source and not any(k.lower() == "authsource" for k in kwargs): kwargs["authSource"] = "admin" - self._client: AsyncMongoClient = AsyncMongoClient(*args, **kwargs) + self._client: AsyncMongoClient[Any] = AsyncMongoClient(*args, **kwargs) # NOTE: Per the ASP driver spec, server errors MUST be surfaced as-is. 
# Do NOT introduce error-code branching, rewrapping, or filtering anywhere @@ -180,7 +181,7 @@ class AsyncStreamProcessingClient: return await admin._retryable_read_command(cmd, session=session, operation=operation) return await admin.command(cmd, session=session) - def stream_processors(self) -> "AsyncStreamProcessors": + def stream_processors(self) -> AsyncStreamProcessors: """Return a handle for managing stream processors in this workspace.""" return AsyncStreamProcessors(self) @@ -188,14 +189,14 @@ class AsyncStreamProcessingClient: """Close the underlying client and release all resources.""" await self._client.close() - async def __aenter__(self) -> "AsyncStreamProcessingClient": + async def __aenter__(self) -> AsyncStreamProcessingClient: return self async def __aexit__( self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], - exc_tb: Optional["TracebackType"], + exc_tb: Optional[TracebackType], ) -> None: await self.close() @@ -215,7 +216,7 @@ class AsyncStreamProcessors: Obtained via :meth:`AsyncStreamProcessingClient.stream_processors`. """ - def __init__(self, client: "AsyncStreamProcessingClient") -> None: + def __init__(self, client: AsyncStreamProcessingClient) -> None: self._client = client async def create( @@ -261,7 +262,7 @@ class AsyncStreamProcessors: cmd["options"] = opts await self._client._command(cmd, session=session) - def get(self, name: str) -> "AsyncStreamProcessor": + def get(self, name: str) -> AsyncStreamProcessor: """Return a handle for an existing stream processor by name. 
This is a pure factory method — it does **not** contact the server @@ -297,7 +298,8 @@ class AsyncStreamProcessors: raise InvalidOperation("Stream processor name must be a non-empty string.") cmd: dict[str, Any] = {_Op.GET_STREAM_PROCESSOR: name} response = await self._client._command(cmd, retryable_read=True, session=session) - return StreamProcessorInfo.from_response(response) + doc = response.get("result", response) + return StreamProcessorInfo.from_response(doc) class AsyncStreamProcessor: @@ -307,7 +309,7 @@ class AsyncStreamProcessor: Obtain via :meth:`AsyncStreamProcessors.get`. """ - def __init__(self, *, client: "AsyncStreamProcessingClient", name: str) -> None: + def __init__(self, *, client: AsyncStreamProcessingClient, name: str) -> None: if not name or not name.strip(): raise InvalidOperation("Stream processor name must be a non-empty string.") self._client = client @@ -374,9 +376,7 @@ class AsyncStreamProcessor: :class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use for this operation. """ - await self._client._command( - {_Op.STOP_STREAM_PROCESSOR: self.name}, session=session - ) + await self._client._command({_Op.STOP_STREAM_PROCESSOR: self.name}, session=session) async def drop( self, @@ -393,9 +393,7 @@ class AsyncStreamProcessor: :class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use for this operation. """ - await self._client._command( - {_Op.DROP_STREAM_PROCESSOR: self.name}, session=session - ) + await self._client._command({_Op.DROP_STREAM_PROCESSOR: self.name}, session=session) async def stats( self, @@ -461,9 +459,7 @@ class AsyncStreamProcessor: options = GetStreamProcessorSamplesOptions() if options.cursor_id == 0: - raise InvalidOperation( - "Sample cursor is exhausted; cursor_id 0 cannot be continued." 
- ) + raise InvalidOperation("Sample cursor is exhausted; cursor_id 0 cannot be continued.") if options.cursor_id is None: cmd: dict[str, Any] = {_Op.START_SAMPLE_STREAM_PROCESSOR: self.name} @@ -472,7 +468,7 @@ class AsyncStreamProcessor: resp = await self._client._command(cmd, session=session) return GetStreamProcessorSamplesResult( cursor_id=int(resp["cursorId"]), - documents=list(resp["firstBatch"]), + documents=list(resp.get("firstBatch", [])), ) else: cmd = { @@ -484,7 +480,7 @@ class AsyncStreamProcessor: resp = await self._client._command(cmd, session=session) return GetStreamProcessorSamplesResult( cursor_id=int(resp["cursorId"]), - documents=list(resp["nextBatch"]), + documents=list(resp.get("nextBatch", resp.get("messages", []))), ) def sample( @@ -493,7 +489,7 @@ class AsyncStreamProcessor: batch_size: Optional[int] = None, *, session: Optional[AsyncClientSession] = None, - ) -> "AsyncSampleCursor": + ) -> AsyncSampleCursor: """Open a sample cursor over this stream processor's output. 
Returns an async iterator that yields sampled documents until the @@ -541,10 +537,10 @@ class AsyncSampleCursor: def __init__( self, *, - processor: "AsyncStreamProcessor", + processor: AsyncStreamProcessor, limit: Optional[int] = None, batch_size: Optional[int] = None, - session: Optional["AsyncClientSession"] = None, + session: Optional[AsyncClientSession] = None, ) -> None: self._processor = processor self._limit = limit @@ -587,9 +583,7 @@ class AsyncSampleCursor: batch_size=self._batch_size, ) - result = await self._processor.get_stream_processor_samples( - opts, session=self._session - ) + result = await self._processor.get_stream_processor_samples(opts, session=self._session) self._cursor_id = result.cursor_id self._buffer.extend(result.documents) @@ -597,7 +591,7 @@ class AsyncSampleCursor: if result.cursor_id == 0: self._exhausted = True - def __aiter__(self) -> "AsyncSampleCursor": + def __aiter__(self) -> AsyncSampleCursor: return self async def __anext__(self) -> Mapping[str, Any]: @@ -627,13 +621,13 @@ class AsyncSampleCursor: """ self._closed = True - async def __aenter__(self) -> "AsyncSampleCursor": + async def __aenter__(self) -> AsyncSampleCursor: return self async def __aexit__( self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], - exc_tb: Optional["TracebackType"], + exc_tb: Optional[TracebackType], ) -> None: await self.close() diff --git a/pymongo/stream_processing_options.py b/pymongo/stream_processing_options.py index 162859cb3..519980300 100644 --- a/pymongo/stream_processing_options.py +++ b/pymongo/stream_processing_options.py @@ -17,6 +17,7 @@ These classes are shared between the synchronous and asynchronous APIs — no async/sync split is needed for plain dataclasses. 
""" + from __future__ import annotations from dataclasses import dataclass, field @@ -55,7 +56,7 @@ class StartStreamProcessorOptions: workers: Optional[int] = None clear_checkpoints: Optional[bool] = None - start_at_operation_time: Optional["Timestamp"] = None + start_at_operation_time: Optional[Timestamp] = None start_after: Optional[Mapping[str, Any]] = None tier: Optional[str] = None enable_auto_scaling: Optional[bool] = None @@ -138,11 +139,11 @@ class StreamProcessorInfo: so that unknown or future fields are not silently discarded. """ - id: str + id: Optional[str] name: str state: str # plain str — drivers MUST NOT hard-code this as an Enum pipeline: list[Mapping[str, Any]] - pipeline_version: int + pipeline_version: Optional[int] tier: Optional[str] = None dlq: Optional[Mapping[str, Any]] = None stream_meta_field_name: Optional[str] = None @@ -161,18 +162,18 @@ class StreamProcessorInfo: raw: Mapping[str, Any] = field(default_factory=dict) @classmethod - def from_response(cls, doc: Mapping[str, Any]) -> "StreamProcessorInfo": + def from_response(cls, doc: Mapping[str, Any]) -> StreamProcessorInfo: """Construct a :class:`StreamProcessorInfo` from a server response document. Maps camelCase server keys to Python snake_case fields and stashes the full *doc* in :attr:`raw` so no fields are silently dropped. 
""" return cls( - id=doc["id"], + id=doc.get("id"), name=doc["name"], state=doc["state"], - pipeline=doc["pipeline"], - pipeline_version=doc["pipelineVersion"], + pipeline=doc.get("pipeline", []), + pipeline_version=doc.get("pipelineVersion"), tier=doc.get("tier"), dlq=doc.get("dlq"), stream_meta_field_name=doc.get("streamMetaFieldName"), diff --git a/pymongo/synchronous/stream_processing.py b/pymongo/synchronous/stream_processing.py index 246eeca75..acfc6476c 100644 --- a/pymongo/synchronous/stream_processing.py +++ b/pymongo/synchronous/stream_processing.py @@ -51,11 +51,11 @@ do not maintain a closed list of valid codes — applications should branch on ``exc.code`` only when they need to react to a specific known code, and should always have a generic fallback for unknown codes. """ + from __future__ import annotations from typing import TYPE_CHECKING, Any, Mapping, Optional -from pymongo.synchronous.mongo_client import MongoClient from pymongo.errors import ConfigurationError, InvalidOperation from pymongo.operations import _Op from pymongo.stream_processing_options import ( @@ -66,6 +66,7 @@ from pymongo.stream_processing_options import ( StartStreamProcessorOptions, StreamProcessorInfo, ) +from pymongo.synchronous.mongo_client import MongoClient from pymongo.uri_parser_shared import SRV_SCHEME, _validate_uri if TYPE_CHECKING: @@ -146,7 +147,7 @@ class StreamProcessingClient: if not uri_has_auth_source and not any(k.lower() == "authsource" for k in kwargs): kwargs["authSource"] = "admin" - self._client: MongoClient = MongoClient(*args, **kwargs) + self._client: MongoClient[Any] = MongoClient(*args, **kwargs) # NOTE: Per the ASP driver spec, server errors MUST be surfaced as-is. 
# Do NOT introduce error-code branching, rewrapping, or filtering anywhere @@ -180,7 +181,7 @@ class StreamProcessingClient: return admin._retryable_read_command(cmd, session=session, operation=operation) return admin.command(cmd, session=session) - def stream_processors(self) -> "StreamProcessors": + def stream_processors(self) -> StreamProcessors: """Return a handle for managing stream processors in this workspace.""" return StreamProcessors(self) @@ -188,14 +189,14 @@ class StreamProcessingClient: """Close the underlying client and release all resources.""" self._client.close() - def __enter__(self) -> "StreamProcessingClient": + def __enter__(self) -> StreamProcessingClient: return self def __exit__( self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], - exc_tb: Optional["TracebackType"], + exc_tb: Optional[TracebackType], ) -> None: self.close() @@ -215,7 +216,7 @@ class StreamProcessors: Obtained via :meth:`StreamProcessingClient.stream_processors`. """ - def __init__(self, client: "StreamProcessingClient") -> None: + def __init__(self, client: StreamProcessingClient) -> None: self._client = client def create( @@ -261,7 +262,7 @@ class StreamProcessors: cmd["options"] = opts self._client._command(cmd, session=session) - def get(self, name: str) -> "StreamProcessor": + def get(self, name: str) -> StreamProcessor: """Return a handle for an existing stream processor by name. This is a pure factory method — it does **not** contact the server @@ -297,7 +298,8 @@ class StreamProcessors: raise InvalidOperation("Stream processor name must be a non-empty string.") cmd: dict[str, Any] = {_Op.GET_STREAM_PROCESSOR: name} response = self._client._command(cmd, retryable_read=True, session=session) - return StreamProcessorInfo.from_response(response) + doc = response.get("result", response) + return StreamProcessorInfo.from_response(doc) class StreamProcessor: @@ -307,7 +309,7 @@ class StreamProcessor: Obtain via :meth:`StreamProcessors.get`. 
""" - def __init__(self, *, client: "StreamProcessingClient", name: str) -> None: + def __init__(self, *, client: StreamProcessingClient, name: str) -> None: if not name or not name.strip(): raise InvalidOperation("Stream processor name must be a non-empty string.") self._client = client @@ -374,9 +376,7 @@ class StreamProcessor: :class:`~pymongo.client_session.ClientSession` to use for this operation. """ - self._client._command( - {_Op.STOP_STREAM_PROCESSOR: self.name}, session=session - ) + self._client._command({_Op.STOP_STREAM_PROCESSOR: self.name}, session=session) def drop( self, @@ -393,9 +393,7 @@ class StreamProcessor: :class:`~pymongo.client_session.ClientSession` to use for this operation. """ - self._client._command( - {_Op.DROP_STREAM_PROCESSOR: self.name}, session=session - ) + self._client._command({_Op.DROP_STREAM_PROCESSOR: self.name}, session=session) def stats( self, @@ -461,9 +459,7 @@ class StreamProcessor: options = GetStreamProcessorSamplesOptions() if options.cursor_id == 0: - raise InvalidOperation( - "Sample cursor is exhausted; cursor_id 0 cannot be continued." 
- ) + raise InvalidOperation("Sample cursor is exhausted; cursor_id 0 cannot be continued.") if options.cursor_id is None: cmd: dict[str, Any] = {_Op.START_SAMPLE_STREAM_PROCESSOR: self.name} @@ -472,7 +468,7 @@ class StreamProcessor: resp = self._client._command(cmd, session=session) return GetStreamProcessorSamplesResult( cursor_id=int(resp["cursorId"]), - documents=list(resp["firstBatch"]), + documents=list(resp.get("firstBatch", [])), ) else: cmd = { @@ -484,7 +480,7 @@ class StreamProcessor: resp = self._client._command(cmd, session=session) return GetStreamProcessorSamplesResult( cursor_id=int(resp["cursorId"]), - documents=list(resp["nextBatch"]), + documents=list(resp.get("nextBatch", resp.get("messages", []))), ) def sample( @@ -493,7 +489,7 @@ class StreamProcessor: batch_size: Optional[int] = None, *, session: Optional[ClientSession] = None, - ) -> "SampleCursor": + ) -> SampleCursor: """Open a sample cursor over this stream processor's output. Returns an async iterator that yields sampled documents until the @@ -541,10 +537,10 @@ class SampleCursor: def __init__( self, *, - processor: "StreamProcessor", + processor: StreamProcessor, limit: Optional[int] = None, batch_size: Optional[int] = None, - session: Optional["ClientSession"] = None, + session: Optional[ClientSession] = None, ) -> None: self._processor = processor self._limit = limit @@ -587,9 +583,7 @@ class SampleCursor: batch_size=self._batch_size, ) - result = self._processor.get_stream_processor_samples( - opts, session=self._session - ) + result = self._processor.get_stream_processor_samples(opts, session=self._session) self._cursor_id = result.cursor_id self._buffer.extend(result.documents) @@ -597,7 +591,7 @@ class SampleCursor: if result.cursor_id == 0: self._exhausted = True - def __iter__(self) -> "SampleCursor": + def __iter__(self) -> SampleCursor: return self def __next__(self) -> Mapping[str, Any]: @@ -627,13 +621,13 @@ class SampleCursor: """ self._closed = True - def 
__enter__(self) -> "SampleCursor": + def __enter__(self) -> SampleCursor: return self def __exit__( self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], - exc_tb: Optional["TracebackType"], + exc_tb: Optional[TracebackType], ) -> None: self.close() diff --git a/test/asynchronous/test_stream_processing.py b/test/asynchronous/test_stream_processing.py index 1c4534cc3..f08692f5f 100644 --- a/test/asynchronous/test_stream_processing.py +++ b/test/asynchronous/test_stream_processing.py @@ -18,6 +18,7 @@ All tests run offline — no live workspace is required. The wire layer is stubbed out by replacing ``_command`` on the client with a lightweight spy that records calls and returns pre-configured responses. """ + from __future__ import annotations import inspect @@ -31,6 +32,7 @@ sys.path[0:0] = [""] _IS_SYNC = False +import pymongo.asynchronous.stream_processing from bson import Timestamp from pymongo import ( AsyncSampleCursor, @@ -44,10 +46,8 @@ from pymongo import ( StartStreamProcessorOptions, StreamProcessorInfo, ) -import pymongo.asynchronous.stream_processing from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure - # --------------------------------------------------------------------------- # Spy helper # --------------------------------------------------------------------------- @@ -167,9 +167,7 @@ class TestStreamProcessingClientConfig(unittest.TestCase): def test_workspace_endpoint_detection(self) -> None: from pymongo.asynchronous.stream_processing import _is_workspace_endpoint - self.assertTrue( - _is_workspace_endpoint("atlas-stream-foo.virginia-usa.a.query.mongodb.net") - ) + self.assertTrue(_is_workspace_endpoint("atlas-stream-foo.virginia-usa.a.query.mongodb.net")) self.assertTrue(_is_workspace_endpoint("something.a.query.mongodb.net")) self.assertFalse(_is_workspace_endpoint("cluster0.mongodb.net")) self.assertFalse(_is_workspace_endpoint("localhost")) @@ -495,7 +493,9 @@ class 
AsyncTestSampleCursor(unittest.IsolatedAsyncioTestCase): responses=[{"cursorId": 0, "firstBatch": [{"a": 1}, {"a": 2}, {"a": 3}]}] ) proc = AsyncStreamProcessor(client=client, name="demo") - docs = [doc async for doc in proc.sample()] + docs = [] + async for doc in proc.sample(): + docs.append(doc) self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}]) self.assertEqual(len(calls), 1) @@ -508,7 +508,9 @@ class AsyncTestSampleCursor(unittest.IsolatedAsyncioTestCase): ] ) proc = AsyncStreamProcessor(client=client, name="demo") - docs = [doc async for doc in proc.sample()] + docs = [] + async for doc in proc.sample(): + docs.append(doc) self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}]) self.assertEqual(len(calls), 3) @@ -521,14 +523,14 @@ class AsyncTestSampleCursor(unittest.IsolatedAsyncioTestCase): ] ) proc = AsyncStreamProcessor(client=client, name="demo") - docs = [doc async for doc in proc.sample()] + docs = [] + async for doc in proc.sample(): + docs.append(doc) self.assertEqual(docs, [{"a": 1}]) self.assertEqual(len(calls), 3) async def test_sample_iterator_stops_after_cursor_id_zero(self) -> None: - client, calls = _spy_client( - responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}] - ) + client, calls = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}]) proc = AsyncStreamProcessor(client=client, name="demo") cursor = proc.sample() first = await cursor.__anext__() @@ -548,9 +550,7 @@ class AsyncTestSampleCursor(unittest.IsolatedAsyncioTestCase): self.assertEqual(len(calls), 0) async def test_sample_cursor_alive_property(self) -> None: - client, _ = _spy_client( - responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}] - ) + client, _ = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}]) proc = AsyncStreamProcessor(client=client, name="demo") cursor = proc.sample() self.assertTrue(cursor.alive) @@ -628,9 +628,7 @@ class AsyncTestErrorHandling(unittest.IsolatedAsyncioTestCase): async def 
test_get_info_propagates_operation_failure(self) -> None: client, _ = _spy_client(raises=self._failure(9)) - await self._assert_propagates( - lambda: AsyncStreamProcessors(client).get_info("p"), 9 - ) + await self._assert_propagates(lambda: AsyncStreamProcessors(client).get_info("p"), 9) async def test_stats_propagates_operation_failure(self) -> None: client, _ = _spy_client(raises=self._failure(72)) @@ -653,9 +651,7 @@ class AsyncTestErrorHandling(unittest.IsolatedAsyncioTestCase): def test_no_operation_failure_caught_in_module(self) -> None: """Structural: verify no try/except hides server errors in either module.""" - for mod in ( - pymongo.asynchronous.stream_processing, - ): + for mod in (pymongo.asynchronous.stream_processing,): src = inspect.getsource(mod) self.assertNotIn("except OperationFailure", src, mod.__name__) self.assertNotIn("except PyMongoError", src, mod.__name__) @@ -674,7 +670,11 @@ class AsyncTestSpecCompliance(unittest.IsolatedAsyncioTestCase): """Each command must route as retryable or non-retryable per the spec.""" cases: list[tuple[str, Any, bool]] = [ # (label, coroutine-factory, expected_retryable_read) - ("create", lambda c: AsyncStreamProcessors(c).create("demo", pipeline=[{"$x": 1}]), False), + ( + "create", + lambda c: AsyncStreamProcessors(c).create("demo", pipeline=[{"$x": 1}]), + False, + ), ("start", lambda c: AsyncStreamProcessor(client=c, name="demo").start(), False), ("stop", lambda c: AsyncStreamProcessor(client=c, name="demo").stop(), False), ("drop", lambda c: AsyncStreamProcessor(client=c, name="demo").drop(), False), @@ -682,7 +682,9 @@ class AsyncTestSpecCompliance(unittest.IsolatedAsyncioTestCase): ("stats", lambda c: AsyncStreamProcessor(client=c, name="demo").stats(), True), ( "get_samples_initial", - lambda c: AsyncStreamProcessor(client=c, name="demo").get_stream_processor_samples(), + lambda c: AsyncStreamProcessor( + client=c, name="demo" + ).get_stream_processor_samples(), False, ), ( diff --git 
a/test/test_stream_processing.py b/test/test_stream_processing.py index d630e7d95..b914fd5ae 100644 --- a/test/test_stream_processing.py +++ b/test/test_stream_processing.py @@ -18,6 +18,7 @@ All tests run offline — no live workspace is required. The wire layer is stubbed out by replacing ``_command`` on the client with a lightweight spy that records calls and returns pre-configured responses. """ + from __future__ import annotations import inspect @@ -31,23 +32,22 @@ sys.path[0:0] = [""] _IS_SYNC = True +import pymongo.synchronous.stream_processing from bson import Timestamp from pymongo import ( - SampleCursor, - StreamProcessingClient, - StreamProcessor, - StreamProcessors, CreateStreamProcessorOptions, GetStreamProcessorSamplesOptions, GetStreamProcessorSamplesResult, GetStreamProcessorStatsOptions, + SampleCursor, StartStreamProcessorOptions, + StreamProcessingClient, + StreamProcessor, StreamProcessorInfo, + StreamProcessors, ) -import pymongo.synchronous.stream_processing from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure - # --------------------------------------------------------------------------- # Spy helper # --------------------------------------------------------------------------- @@ -167,9 +167,7 @@ class TestStreamProcessingClientConfig(unittest.TestCase): def test_workspace_endpoint_detection(self) -> None: from pymongo.synchronous.stream_processing import _is_workspace_endpoint - self.assertTrue( - _is_workspace_endpoint("atlas-stream-foo.virginia-usa.a.query.mongodb.net") - ) + self.assertTrue(_is_workspace_endpoint("atlas-stream-foo.virginia-usa.a.query.mongodb.net")) self.assertTrue(_is_workspace_endpoint("something.a.query.mongodb.net")) self.assertFalse(_is_workspace_endpoint("cluster0.mongodb.net")) self.assertFalse(_is_workspace_endpoint("localhost")) @@ -465,9 +463,7 @@ class TestSampleCursor(unittest.TestCase): def test_get_samples_continuation_sends_get_more(self) -> None: client, calls = 
_spy_client(responses=[{"cursorId": 0, "nextBatch": [{"y": 2}]}]) proc = StreamProcessor(client=client, name="demo") - result = proc.get_stream_processor_samples( - GetStreamProcessorSamplesOptions(cursor_id=42) - ) + result = proc.get_stream_processor_samples(GetStreamProcessorSamplesOptions(cursor_id=42)) cmd = calls[0]["cmd"] self.assertIn("getMoreSampleStreamProcessor", cmd) self.assertEqual(cmd.get("cursorId"), 42) @@ -495,7 +491,9 @@ class TestSampleCursor(unittest.TestCase): responses=[{"cursorId": 0, "firstBatch": [{"a": 1}, {"a": 2}, {"a": 3}]}] ) proc = StreamProcessor(client=client, name="demo") - docs = [doc for doc in proc.sample()] + docs = [] + for doc in proc.sample(): + docs.append(doc) self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}]) self.assertEqual(len(calls), 1) @@ -508,7 +506,9 @@ class TestSampleCursor(unittest.TestCase): ] ) proc = StreamProcessor(client=client, name="demo") - docs = [doc for doc in proc.sample()] + docs = [] + for doc in proc.sample(): + docs.append(doc) self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}]) self.assertEqual(len(calls), 3) @@ -521,14 +521,14 @@ class TestSampleCursor(unittest.TestCase): ] ) proc = StreamProcessor(client=client, name="demo") - docs = [doc for doc in proc.sample()] + docs = [] + for doc in proc.sample(): + docs.append(doc) self.assertEqual(docs, [{"a": 1}]) self.assertEqual(len(calls), 3) def test_sample_iterator_stops_after_cursor_id_zero(self) -> None: - client, calls = _spy_client( - responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}] - ) + client, calls = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}]) proc = StreamProcessor(client=client, name="demo") cursor = proc.sample() first = cursor.__next__() @@ -548,9 +548,7 @@ class TestSampleCursor(unittest.TestCase): self.assertEqual(len(calls), 0) def test_sample_cursor_alive_property(self) -> None: - client, _ = _spy_client( - responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}] - ) + client, _ = 
_spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}]) proc = StreamProcessor(client=client, name="demo") cursor = proc.sample() self.assertTrue(cursor.alive) @@ -610,33 +608,23 @@ class TestErrorHandling(unittest.TestCase): def test_start_propagates_operation_failure(self) -> None: client, _ = _spy_client(raises=self._failure(72)) - self._assert_propagates( - lambda: StreamProcessor(client=client, name="p").start(), 72 - ) + self._assert_propagates(lambda: StreamProcessor(client=client, name="p").start(), 72) def test_stop_propagates_operation_failure(self) -> None: client, _ = _spy_client(raises=self._failure(125)) - self._assert_propagates( - lambda: StreamProcessor(client=client, name="p").stop(), 125 - ) + self._assert_propagates(lambda: StreamProcessor(client=client, name="p").stop(), 125) def test_drop_propagates_operation_failure(self) -> None: client, _ = _spy_client(raises=self._failure(1)) - self._assert_propagates( - lambda: StreamProcessor(client=client, name="p").drop(), 1 - ) + self._assert_propagates(lambda: StreamProcessor(client=client, name="p").drop(), 1) def test_get_info_propagates_operation_failure(self) -> None: client, _ = _spy_client(raises=self._failure(9)) - self._assert_propagates( - lambda: StreamProcessors(client).get_info("p"), 9 - ) + self._assert_propagates(lambda: StreamProcessors(client).get_info("p"), 9) def test_stats_propagates_operation_failure(self) -> None: client, _ = _spy_client(raises=self._failure(72)) - self._assert_propagates( - lambda: StreamProcessor(client=client, name="p").stats(), 72 - ) + self._assert_propagates(lambda: StreamProcessor(client=client, name="p").stats(), 72) def test_get_stream_processor_samples_propagates_operation_failure(self) -> None: client, _ = _spy_client(raises=self._failure(9)) @@ -653,9 +641,7 @@ class TestErrorHandling(unittest.TestCase): def test_no_operation_failure_caught_in_module(self) -> None: """Structural: verify no try/except hides server errors in either module.""" 
- for mod in ( - pymongo.synchronous.stream_processing, - ): + for mod in (pymongo.synchronous.stream_processing,): src = inspect.getsource(mod) self.assertNotIn("except OperationFailure", src, mod.__name__) self.assertNotIn("except PyMongoError", src, mod.__name__)