feat: fix linting errors and modified code to fit server output

This commit is contained in:
Rohan Pota 2026-05-13 16:17:41 -04:00
parent c747ee1587
commit 2b5e2d08a2
No known key found for this signature in database
GPG Key ID: C0DF6CE95D35117D
7 changed files with 273 additions and 142 deletions

153
poc_asp.py Normal file
View File

@ -0,0 +1,153 @@
"""
POC: Atlas Stream Processing create / start / stop / drop a stream processor.
Fill in the FILL_ME values below before running:
python3 poc_asp.py
Pipeline used:
$source sample_stream_solar
$emit __testLog
"""
from __future__ import annotations
import asyncio
import pprint
from pymongo import AsyncStreamProcessingClient
from pymongo.errors import OperationFailure
# ---------------------------------------------------------------------------
# Configuration — fill these in before running
# ---------------------------------------------------------------------------
# Workspace connection string from Atlas UI:
# Stream Processing → your workspace → Connect
# Format: mongodb://<host>/ (no mongodb+srv://)
WORKSPACE_URI = "mongodb://atlas-stream-69ed590869155100cecc8b33-lulzki.virginia-usa.a.query.mongodb-dev.net/" # e.g. "mongodb://atlas-stream-<id>.<region>.a.query.mongodb.net/"
# Atlas DB user credentials (must have atlasAdmin role on the workspace project)
USERNAME = "streams"
PASSWORD = "letsdostreaming123"
# ---------------------------------------------------------------------------
# Pipeline — hardcoded per your setup
# ---------------------------------------------------------------------------
PROCESSOR_NAME = "simpletestSP"
PIPELINE = [
{
"$source": {
"connectionName": "sample_stream_solar",
}
},
{
"$emit": {
"connectionName": "__testLog",
}
},
]
# ---------------------------------------------------------------------------
# POC steps
# ---------------------------------------------------------------------------
async def main() -> None:
if "FILL_ME" in (WORKSPACE_URI, USERNAME, PASSWORD):
raise SystemExit("Fill in WORKSPACE_URI, USERNAME, and PASSWORD at the top of this file.")
async with AsyncStreamProcessingClient(WORKSPACE_URI, username=USERNAME, password=PASSWORD) as client:
sps = client.stream_processors()
# ------------------------------------------------------------------
# 1. Create
# ------------------------------------------------------------------
print(f"\n[1] Creating processor '{PROCESSOR_NAME}' ...")
try:
await sps.create(PROCESSOR_NAME, pipeline=PIPELINE)
print(" Created OK")
except OperationFailure as e:
raise SystemExit(f" Create failed (code {e.code}): {e}") from e
# ------------------------------------------------------------------
# 2. Inspect before starting
# ------------------------------------------------------------------
print("\n[2] Getting info ...")
info = await sps.get_info(PROCESSOR_NAME)
print(f" state : {info.state}")
print(f" pipeline_version : {info.pipeline_version}")
print(f" has_started : {info.has_started}")
# ------------------------------------------------------------------
# 3. Start
# ------------------------------------------------------------------
proc = sps.get(PROCESSOR_NAME)
print("\n[3] Starting processor ...")
try:
await proc.start()
print(" Start command sent OK")
except OperationFailure as e:
raise SystemExit(f" Start failed (code {e.code}): {e}") from e
# Give the server a moment to transition state
await asyncio.sleep(2)
info = await sps.get_info(PROCESSOR_NAME)
print(f" state after start: {info.state}")
# ------------------------------------------------------------------
# 4. Stats
# ------------------------------------------------------------------
print("\n[4] Fetching stats ...")
try:
raw_stats = await proc.stats()
pprint.pprint(raw_stats)
except OperationFailure as e:
print(f" Stats unavailable (code {e.code}): {e}")
# ------------------------------------------------------------------
# 5. Sample (up to 5 docs)
# Note: breaking manually after N docs because the dev server does not
# signal cursor exhaustion with cursorId=0 as the spec requires.
# ------------------------------------------------------------------
print("\n[5] Sampling up to 5 documents ...")
try:
count = 0
async for doc in proc.sample():
print(f" doc: {doc}")
count += 1
if count >= 5:
break
print(f" Sampled {count} document(s)")
except OperationFailure as e:
print(f" Sample unavailable (code {e.code}): {e}")
# ------------------------------------------------------------------
# 6. Stop
# ------------------------------------------------------------------
print("\n[6] Stopping processor ...")
try:
await proc.stop()
print(" Stop command sent OK")
except OperationFailure as e:
raise SystemExit(f" Stop failed (code {e.code}): {e}") from e
await asyncio.sleep(1)
info = await sps.get_info(PROCESSOR_NAME)
print(f" state after stop : {info.state}")
# ------------------------------------------------------------------
# 7. Drop (permanent — comment out to keep the processor alive)
# ------------------------------------------------------------------
print("\n[7] Dropping processor ...")
try:
await proc.drop()
print(" Dropped OK")
except OperationFailure as e:
raise SystemExit(f" Drop failed (code {e.code}): {e}") from e
print("\nDone.")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Python driver for MongoDB."""
from __future__ import annotations
from typing import ContextManager, Optional
@ -104,6 +105,14 @@ TEXT = "text"
from pymongo import _csot
from pymongo._version import __version__, get_version_string, version_tuple
from pymongo.asynchronous.mongo_client import AsyncMongoClient
# Atlas Stream Processing (experimental)
from pymongo.asynchronous.stream_processing import (
AsyncSampleCursor,
AsyncStreamProcessingClient,
AsyncStreamProcessor,
AsyncStreamProcessors,
)
from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION, has_c
from pymongo.cursor import CursorType
from pymongo.operations import (
@ -116,23 +125,6 @@ from pymongo.operations import (
UpdateOne,
)
from pymongo.read_preferences import ReadPreference
from pymongo.synchronous.collection import ReturnDocument
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.write_concern import WriteConcern
# Atlas Stream Processing (experimental)
from pymongo.asynchronous.stream_processing import (
AsyncSampleCursor,
AsyncStreamProcessingClient,
AsyncStreamProcessor,
AsyncStreamProcessors,
)
from pymongo.synchronous.stream_processing import (
SampleCursor,
StreamProcessingClient,
StreamProcessor,
StreamProcessors,
)
from pymongo.stream_processing_options import (
CreateStreamProcessorOptions,
GetStreamProcessorSamplesOptions,
@ -141,6 +133,15 @@ from pymongo.stream_processing_options import (
StartStreamProcessorOptions,
StreamProcessorInfo,
)
from pymongo.synchronous.collection import ReturnDocument
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.stream_processing import (
SampleCursor,
StreamProcessingClient,
StreamProcessor,
StreamProcessors,
)
from pymongo.write_concern import WriteConcern
# Public module compatibility imports
# isort: off

View File

@ -51,6 +51,7 @@ do not maintain a closed list of valid codes — applications should branch
on ``exc.code`` only when they need to react to a specific known code, and
should always have a generic fallback for unknown codes.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional
@ -146,7 +147,7 @@ class AsyncStreamProcessingClient:
if not uri_has_auth_source and not any(k.lower() == "authsource" for k in kwargs):
kwargs["authSource"] = "admin"
self._client: AsyncMongoClient = AsyncMongoClient(*args, **kwargs)
self._client: AsyncMongoClient[Any] = AsyncMongoClient(*args, **kwargs)
# NOTE: Per the ASP driver spec, server errors MUST be surfaced as-is.
# Do NOT introduce error-code branching, rewrapping, or filtering anywhere
@ -180,7 +181,7 @@ class AsyncStreamProcessingClient:
return await admin._retryable_read_command(cmd, session=session, operation=operation)
return await admin.command(cmd, session=session)
def stream_processors(self) -> "AsyncStreamProcessors":
def stream_processors(self) -> AsyncStreamProcessors:
"""Return a handle for managing stream processors in this workspace."""
return AsyncStreamProcessors(self)
@ -188,14 +189,14 @@ class AsyncStreamProcessingClient:
"""Close the underlying client and release all resources."""
await self._client.close()
async def __aenter__(self) -> "AsyncStreamProcessingClient":
async def __aenter__(self) -> AsyncStreamProcessingClient:
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional["TracebackType"],
exc_tb: Optional[TracebackType],
) -> None:
await self.close()
@ -215,7 +216,7 @@ class AsyncStreamProcessors:
Obtained via :meth:`AsyncStreamProcessingClient.stream_processors`.
"""
def __init__(self, client: "AsyncStreamProcessingClient") -> None:
def __init__(self, client: AsyncStreamProcessingClient) -> None:
self._client = client
async def create(
@ -261,7 +262,7 @@ class AsyncStreamProcessors:
cmd["options"] = opts
await self._client._command(cmd, session=session)
def get(self, name: str) -> "AsyncStreamProcessor":
def get(self, name: str) -> AsyncStreamProcessor:
"""Return a handle for an existing stream processor by name.
This is a pure factory method it does **not** contact the server
@ -297,7 +298,8 @@ class AsyncStreamProcessors:
raise InvalidOperation("Stream processor name must be a non-empty string.")
cmd: dict[str, Any] = {_Op.GET_STREAM_PROCESSOR: name}
response = await self._client._command(cmd, retryable_read=True, session=session)
return StreamProcessorInfo.from_response(response)
doc = response.get("result", response)
return StreamProcessorInfo.from_response(doc)
class AsyncStreamProcessor:
@ -307,7 +309,7 @@ class AsyncStreamProcessor:
Obtain via :meth:`AsyncStreamProcessors.get`.
"""
def __init__(self, *, client: "AsyncStreamProcessingClient", name: str) -> None:
def __init__(self, *, client: AsyncStreamProcessingClient, name: str) -> None:
if not name or not name.strip():
raise InvalidOperation("Stream processor name must be a non-empty string.")
self._client = client
@ -374,9 +376,7 @@ class AsyncStreamProcessor:
:class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
for this operation.
"""
await self._client._command(
{_Op.STOP_STREAM_PROCESSOR: self.name}, session=session
)
await self._client._command({_Op.STOP_STREAM_PROCESSOR: self.name}, session=session)
async def drop(
self,
@ -393,9 +393,7 @@ class AsyncStreamProcessor:
:class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
for this operation.
"""
await self._client._command(
{_Op.DROP_STREAM_PROCESSOR: self.name}, session=session
)
await self._client._command({_Op.DROP_STREAM_PROCESSOR: self.name}, session=session)
async def stats(
self,
@ -461,9 +459,7 @@ class AsyncStreamProcessor:
options = GetStreamProcessorSamplesOptions()
if options.cursor_id == 0:
raise InvalidOperation(
"Sample cursor is exhausted; cursor_id 0 cannot be continued."
)
raise InvalidOperation("Sample cursor is exhausted; cursor_id 0 cannot be continued.")
if options.cursor_id is None:
cmd: dict[str, Any] = {_Op.START_SAMPLE_STREAM_PROCESSOR: self.name}
@ -472,7 +468,7 @@ class AsyncStreamProcessor:
resp = await self._client._command(cmd, session=session)
return GetStreamProcessorSamplesResult(
cursor_id=int(resp["cursorId"]),
documents=list(resp["firstBatch"]),
documents=list(resp.get("firstBatch", [])),
)
else:
cmd = {
@ -484,7 +480,7 @@ class AsyncStreamProcessor:
resp = await self._client._command(cmd, session=session)
return GetStreamProcessorSamplesResult(
cursor_id=int(resp["cursorId"]),
documents=list(resp["nextBatch"]),
documents=list(resp.get("nextBatch", resp.get("messages", []))),
)
def sample(
@ -493,7 +489,7 @@ class AsyncStreamProcessor:
batch_size: Optional[int] = None,
*,
session: Optional[AsyncClientSession] = None,
) -> "AsyncSampleCursor":
) -> AsyncSampleCursor:
"""Open a sample cursor over this stream processor's output.
Returns an async iterator that yields sampled documents until the
@ -541,10 +537,10 @@ class AsyncSampleCursor:
def __init__(
self,
*,
processor: "AsyncStreamProcessor",
processor: AsyncStreamProcessor,
limit: Optional[int] = None,
batch_size: Optional[int] = None,
session: Optional["AsyncClientSession"] = None,
session: Optional[AsyncClientSession] = None,
) -> None:
self._processor = processor
self._limit = limit
@ -587,9 +583,7 @@ class AsyncSampleCursor:
batch_size=self._batch_size,
)
result = await self._processor.get_stream_processor_samples(
opts, session=self._session
)
result = await self._processor.get_stream_processor_samples(opts, session=self._session)
self._cursor_id = result.cursor_id
self._buffer.extend(result.documents)
@ -597,7 +591,7 @@ class AsyncSampleCursor:
if result.cursor_id == 0:
self._exhausted = True
def __aiter__(self) -> "AsyncSampleCursor":
def __aiter__(self) -> AsyncSampleCursor:
return self
async def __anext__(self) -> Mapping[str, Any]:
@ -627,13 +621,13 @@ class AsyncSampleCursor:
"""
self._closed = True
async def __aenter__(self) -> "AsyncSampleCursor":
async def __aenter__(self) -> AsyncSampleCursor:
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional["TracebackType"],
exc_tb: Optional[TracebackType],
) -> None:
await self.close()

View File

@ -17,6 +17,7 @@
These classes are shared between the synchronous and asynchronous APIs
no async/sync split is needed for plain dataclasses.
"""
from __future__ import annotations
from dataclasses import dataclass, field
@ -55,7 +56,7 @@ class StartStreamProcessorOptions:
workers: Optional[int] = None
clear_checkpoints: Optional[bool] = None
start_at_operation_time: Optional["Timestamp"] = None
start_at_operation_time: Optional[Timestamp] = None
start_after: Optional[Mapping[str, Any]] = None
tier: Optional[str] = None
enable_auto_scaling: Optional[bool] = None
@ -138,11 +139,11 @@ class StreamProcessorInfo:
so that unknown or future fields are not silently discarded.
"""
id: str
id: Optional[str]
name: str
state: str # plain str — drivers MUST NOT hard-code this as an Enum
pipeline: list[Mapping[str, Any]]
pipeline_version: int
pipeline_version: Optional[int]
tier: Optional[str] = None
dlq: Optional[Mapping[str, Any]] = None
stream_meta_field_name: Optional[str] = None
@ -161,18 +162,18 @@ class StreamProcessorInfo:
raw: Mapping[str, Any] = field(default_factory=dict)
@classmethod
def from_response(cls, doc: Mapping[str, Any]) -> "StreamProcessorInfo":
def from_response(cls, doc: Mapping[str, Any]) -> StreamProcessorInfo:
"""Construct a :class:`StreamProcessorInfo` from a server response document.
Maps camelCase server keys to Python snake_case fields and stashes the
full *doc* in :attr:`raw` so no fields are silently dropped.
"""
return cls(
id=doc["id"],
id=doc.get("id"),
name=doc["name"],
state=doc["state"],
pipeline=doc["pipeline"],
pipeline_version=doc["pipelineVersion"],
pipeline=doc.get("pipeline", []),
pipeline_version=doc.get("pipelineVersion"),
tier=doc.get("tier"),
dlq=doc.get("dlq"),
stream_meta_field_name=doc.get("streamMetaFieldName"),

View File

@ -51,11 +51,11 @@ do not maintain a closed list of valid codes — applications should branch
on ``exc.code`` only when they need to react to a specific known code, and
should always have a generic fallback for unknown codes.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.errors import ConfigurationError, InvalidOperation
from pymongo.operations import _Op
from pymongo.stream_processing_options import (
@ -66,6 +66,7 @@ from pymongo.stream_processing_options import (
StartStreamProcessorOptions,
StreamProcessorInfo,
)
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.uri_parser_shared import SRV_SCHEME, _validate_uri
if TYPE_CHECKING:
@ -146,7 +147,7 @@ class StreamProcessingClient:
if not uri_has_auth_source and not any(k.lower() == "authsource" for k in kwargs):
kwargs["authSource"] = "admin"
self._client: MongoClient = MongoClient(*args, **kwargs)
self._client: MongoClient[Any] = MongoClient(*args, **kwargs)
# NOTE: Per the ASP driver spec, server errors MUST be surfaced as-is.
# Do NOT introduce error-code branching, rewrapping, or filtering anywhere
@ -180,7 +181,7 @@ class StreamProcessingClient:
return admin._retryable_read_command(cmd, session=session, operation=operation)
return admin.command(cmd, session=session)
def stream_processors(self) -> "StreamProcessors":
def stream_processors(self) -> StreamProcessors:
"""Return a handle for managing stream processors in this workspace."""
return StreamProcessors(self)
@ -188,14 +189,14 @@ class StreamProcessingClient:
"""Close the underlying client and release all resources."""
self._client.close()
def __enter__(self) -> "StreamProcessingClient":
def __enter__(self) -> StreamProcessingClient:
return self
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional["TracebackType"],
exc_tb: Optional[TracebackType],
) -> None:
self.close()
@ -215,7 +216,7 @@ class StreamProcessors:
Obtained via :meth:`StreamProcessingClient.stream_processors`.
"""
def __init__(self, client: "StreamProcessingClient") -> None:
def __init__(self, client: StreamProcessingClient) -> None:
self._client = client
def create(
@ -261,7 +262,7 @@ class StreamProcessors:
cmd["options"] = opts
self._client._command(cmd, session=session)
def get(self, name: str) -> "StreamProcessor":
def get(self, name: str) -> StreamProcessor:
"""Return a handle for an existing stream processor by name.
This is a pure factory method it does **not** contact the server
@ -297,7 +298,8 @@ class StreamProcessors:
raise InvalidOperation("Stream processor name must be a non-empty string.")
cmd: dict[str, Any] = {_Op.GET_STREAM_PROCESSOR: name}
response = self._client._command(cmd, retryable_read=True, session=session)
return StreamProcessorInfo.from_response(response)
doc = response.get("result", response)
return StreamProcessorInfo.from_response(doc)
class StreamProcessor:
@ -307,7 +309,7 @@ class StreamProcessor:
Obtain via :meth:`StreamProcessors.get`.
"""
def __init__(self, *, client: "StreamProcessingClient", name: str) -> None:
def __init__(self, *, client: StreamProcessingClient, name: str) -> None:
if not name or not name.strip():
raise InvalidOperation("Stream processor name must be a non-empty string.")
self._client = client
@ -374,9 +376,7 @@ class StreamProcessor:
:class:`~pymongo.client_session.ClientSession` to use
for this operation.
"""
self._client._command(
{_Op.STOP_STREAM_PROCESSOR: self.name}, session=session
)
self._client._command({_Op.STOP_STREAM_PROCESSOR: self.name}, session=session)
def drop(
self,
@ -393,9 +393,7 @@ class StreamProcessor:
:class:`~pymongo.client_session.ClientSession` to use
for this operation.
"""
self._client._command(
{_Op.DROP_STREAM_PROCESSOR: self.name}, session=session
)
self._client._command({_Op.DROP_STREAM_PROCESSOR: self.name}, session=session)
def stats(
self,
@ -461,9 +459,7 @@ class StreamProcessor:
options = GetStreamProcessorSamplesOptions()
if options.cursor_id == 0:
raise InvalidOperation(
"Sample cursor is exhausted; cursor_id 0 cannot be continued."
)
raise InvalidOperation("Sample cursor is exhausted; cursor_id 0 cannot be continued.")
if options.cursor_id is None:
cmd: dict[str, Any] = {_Op.START_SAMPLE_STREAM_PROCESSOR: self.name}
@ -472,7 +468,7 @@ class StreamProcessor:
resp = self._client._command(cmd, session=session)
return GetStreamProcessorSamplesResult(
cursor_id=int(resp["cursorId"]),
documents=list(resp["firstBatch"]),
documents=list(resp.get("firstBatch", [])),
)
else:
cmd = {
@ -484,7 +480,7 @@ class StreamProcessor:
resp = self._client._command(cmd, session=session)
return GetStreamProcessorSamplesResult(
cursor_id=int(resp["cursorId"]),
documents=list(resp["nextBatch"]),
documents=list(resp.get("nextBatch", resp.get("messages", []))),
)
def sample(
@ -493,7 +489,7 @@ class StreamProcessor:
batch_size: Optional[int] = None,
*,
session: Optional[ClientSession] = None,
) -> "SampleCursor":
) -> SampleCursor:
"""Open a sample cursor over this stream processor's output.
Returns an async iterator that yields sampled documents until the
@ -541,10 +537,10 @@ class SampleCursor:
def __init__(
self,
*,
processor: "StreamProcessor",
processor: StreamProcessor,
limit: Optional[int] = None,
batch_size: Optional[int] = None,
session: Optional["ClientSession"] = None,
session: Optional[ClientSession] = None,
) -> None:
self._processor = processor
self._limit = limit
@ -587,9 +583,7 @@ class SampleCursor:
batch_size=self._batch_size,
)
result = self._processor.get_stream_processor_samples(
opts, session=self._session
)
result = self._processor.get_stream_processor_samples(opts, session=self._session)
self._cursor_id = result.cursor_id
self._buffer.extend(result.documents)
@ -597,7 +591,7 @@ class SampleCursor:
if result.cursor_id == 0:
self._exhausted = True
def __iter__(self) -> "SampleCursor":
def __iter__(self) -> SampleCursor:
return self
def __next__(self) -> Mapping[str, Any]:
@ -627,13 +621,13 @@ class SampleCursor:
"""
self._closed = True
def __enter__(self) -> "SampleCursor":
def __enter__(self) -> SampleCursor:
return self
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional["TracebackType"],
exc_tb: Optional[TracebackType],
) -> None:
self.close()

View File

@ -18,6 +18,7 @@ All tests run offline — no live workspace is required. The wire layer is
stubbed out by replacing ``_command`` on the client with a lightweight spy
that records calls and returns pre-configured responses.
"""
from __future__ import annotations
import inspect
@ -31,6 +32,7 @@ sys.path[0:0] = [""]
_IS_SYNC = False
import pymongo.asynchronous.stream_processing
from bson import Timestamp
from pymongo import (
AsyncSampleCursor,
@ -44,10 +46,8 @@ from pymongo import (
StartStreamProcessorOptions,
StreamProcessorInfo,
)
import pymongo.asynchronous.stream_processing
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
# ---------------------------------------------------------------------------
# Spy helper
# ---------------------------------------------------------------------------
@ -167,9 +167,7 @@ class TestStreamProcessingClientConfig(unittest.TestCase):
def test_workspace_endpoint_detection(self) -> None:
from pymongo.asynchronous.stream_processing import _is_workspace_endpoint
self.assertTrue(
_is_workspace_endpoint("atlas-stream-foo.virginia-usa.a.query.mongodb.net")
)
self.assertTrue(_is_workspace_endpoint("atlas-stream-foo.virginia-usa.a.query.mongodb.net"))
self.assertTrue(_is_workspace_endpoint("something.a.query.mongodb.net"))
self.assertFalse(_is_workspace_endpoint("cluster0.mongodb.net"))
self.assertFalse(_is_workspace_endpoint("localhost"))
@ -495,7 +493,9 @@ class AsyncTestSampleCursor(unittest.IsolatedAsyncioTestCase):
responses=[{"cursorId": 0, "firstBatch": [{"a": 1}, {"a": 2}, {"a": 3}]}]
)
proc = AsyncStreamProcessor(client=client, name="demo")
docs = [doc async for doc in proc.sample()]
docs = []
async for doc in proc.sample():
docs.append(doc)
self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}])
self.assertEqual(len(calls), 1)
@ -508,7 +508,9 @@ class AsyncTestSampleCursor(unittest.IsolatedAsyncioTestCase):
]
)
proc = AsyncStreamProcessor(client=client, name="demo")
docs = [doc async for doc in proc.sample()]
docs = []
async for doc in proc.sample():
docs.append(doc)
self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}])
self.assertEqual(len(calls), 3)
@ -521,14 +523,14 @@ class AsyncTestSampleCursor(unittest.IsolatedAsyncioTestCase):
]
)
proc = AsyncStreamProcessor(client=client, name="demo")
docs = [doc async for doc in proc.sample()]
docs = []
async for doc in proc.sample():
docs.append(doc)
self.assertEqual(docs, [{"a": 1}])
self.assertEqual(len(calls), 3)
async def test_sample_iterator_stops_after_cursor_id_zero(self) -> None:
client, calls = _spy_client(
responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}]
)
client, calls = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
proc = AsyncStreamProcessor(client=client, name="demo")
cursor = proc.sample()
first = await cursor.__anext__()
@ -548,9 +550,7 @@ class AsyncTestSampleCursor(unittest.IsolatedAsyncioTestCase):
self.assertEqual(len(calls), 0)
async def test_sample_cursor_alive_property(self) -> None:
client, _ = _spy_client(
responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}]
)
client, _ = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
proc = AsyncStreamProcessor(client=client, name="demo")
cursor = proc.sample()
self.assertTrue(cursor.alive)
@ -628,9 +628,7 @@ class AsyncTestErrorHandling(unittest.IsolatedAsyncioTestCase):
async def test_get_info_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(9))
await self._assert_propagates(
lambda: AsyncStreamProcessors(client).get_info("p"), 9
)
await self._assert_propagates(lambda: AsyncStreamProcessors(client).get_info("p"), 9)
async def test_stats_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(72))
@ -653,9 +651,7 @@ class AsyncTestErrorHandling(unittest.IsolatedAsyncioTestCase):
def test_no_operation_failure_caught_in_module(self) -> None:
"""Structural: verify no try/except hides server errors in either module."""
for mod in (
pymongo.asynchronous.stream_processing,
):
for mod in (pymongo.asynchronous.stream_processing,):
src = inspect.getsource(mod)
self.assertNotIn("except OperationFailure", src, mod.__name__)
self.assertNotIn("except PyMongoError", src, mod.__name__)
@ -674,7 +670,11 @@ class AsyncTestSpecCompliance(unittest.IsolatedAsyncioTestCase):
"""Each command must route as retryable or non-retryable per the spec."""
cases: list[tuple[str, Any, bool]] = [
# (label, coroutine-factory, expected_retryable_read)
("create", lambda c: AsyncStreamProcessors(c).create("demo", pipeline=[{"$x": 1}]), False),
(
"create",
lambda c: AsyncStreamProcessors(c).create("demo", pipeline=[{"$x": 1}]),
False,
),
("start", lambda c: AsyncStreamProcessor(client=c, name="demo").start(), False),
("stop", lambda c: AsyncStreamProcessor(client=c, name="demo").stop(), False),
("drop", lambda c: AsyncStreamProcessor(client=c, name="demo").drop(), False),
@ -682,7 +682,9 @@ class AsyncTestSpecCompliance(unittest.IsolatedAsyncioTestCase):
("stats", lambda c: AsyncStreamProcessor(client=c, name="demo").stats(), True),
(
"get_samples_initial",
lambda c: AsyncStreamProcessor(client=c, name="demo").get_stream_processor_samples(),
lambda c: AsyncStreamProcessor(
client=c, name="demo"
).get_stream_processor_samples(),
False,
),
(

View File

@ -18,6 +18,7 @@ All tests run offline — no live workspace is required. The wire layer is
stubbed out by replacing ``_command`` on the client with a lightweight spy
that records calls and returns pre-configured responses.
"""
from __future__ import annotations
import inspect
@ -31,23 +32,22 @@ sys.path[0:0] = [""]
_IS_SYNC = True
import pymongo.synchronous.stream_processing
from bson import Timestamp
from pymongo import (
SampleCursor,
StreamProcessingClient,
StreamProcessor,
StreamProcessors,
CreateStreamProcessorOptions,
GetStreamProcessorSamplesOptions,
GetStreamProcessorSamplesResult,
GetStreamProcessorStatsOptions,
SampleCursor,
StartStreamProcessorOptions,
StreamProcessingClient,
StreamProcessor,
StreamProcessorInfo,
StreamProcessors,
)
import pymongo.synchronous.stream_processing
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
# ---------------------------------------------------------------------------
# Spy helper
# ---------------------------------------------------------------------------
@ -167,9 +167,7 @@ class TestStreamProcessingClientConfig(unittest.TestCase):
def test_workspace_endpoint_detection(self) -> None:
from pymongo.synchronous.stream_processing import _is_workspace_endpoint
self.assertTrue(
_is_workspace_endpoint("atlas-stream-foo.virginia-usa.a.query.mongodb.net")
)
self.assertTrue(_is_workspace_endpoint("atlas-stream-foo.virginia-usa.a.query.mongodb.net"))
self.assertTrue(_is_workspace_endpoint("something.a.query.mongodb.net"))
self.assertFalse(_is_workspace_endpoint("cluster0.mongodb.net"))
self.assertFalse(_is_workspace_endpoint("localhost"))
@ -465,9 +463,7 @@ class TestSampleCursor(unittest.TestCase):
def test_get_samples_continuation_sends_get_more(self) -> None:
client, calls = _spy_client(responses=[{"cursorId": 0, "nextBatch": [{"y": 2}]}])
proc = StreamProcessor(client=client, name="demo")
result = proc.get_stream_processor_samples(
GetStreamProcessorSamplesOptions(cursor_id=42)
)
result = proc.get_stream_processor_samples(GetStreamProcessorSamplesOptions(cursor_id=42))
cmd = calls[0]["cmd"]
self.assertIn("getMoreSampleStreamProcessor", cmd)
self.assertEqual(cmd.get("cursorId"), 42)
@ -495,7 +491,9 @@ class TestSampleCursor(unittest.TestCase):
responses=[{"cursorId": 0, "firstBatch": [{"a": 1}, {"a": 2}, {"a": 3}]}]
)
proc = StreamProcessor(client=client, name="demo")
docs = [doc for doc in proc.sample()]
docs = []
for doc in proc.sample():
docs.append(doc)
self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}])
self.assertEqual(len(calls), 1)
@ -508,7 +506,9 @@ class TestSampleCursor(unittest.TestCase):
]
)
proc = StreamProcessor(client=client, name="demo")
docs = [doc for doc in proc.sample()]
docs = []
for doc in proc.sample():
docs.append(doc)
self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}])
self.assertEqual(len(calls), 3)
@ -521,14 +521,14 @@ class TestSampleCursor(unittest.TestCase):
]
)
proc = StreamProcessor(client=client, name="demo")
docs = [doc for doc in proc.sample()]
docs = []
for doc in proc.sample():
docs.append(doc)
self.assertEqual(docs, [{"a": 1}])
self.assertEqual(len(calls), 3)
def test_sample_iterator_stops_after_cursor_id_zero(self) -> None:
client, calls = _spy_client(
responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}]
)
client, calls = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
proc = StreamProcessor(client=client, name="demo")
cursor = proc.sample()
first = cursor.__next__()
@ -548,9 +548,7 @@ class TestSampleCursor(unittest.TestCase):
self.assertEqual(len(calls), 0)
def test_sample_cursor_alive_property(self) -> None:
client, _ = _spy_client(
responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}]
)
client, _ = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
proc = StreamProcessor(client=client, name="demo")
cursor = proc.sample()
self.assertTrue(cursor.alive)
@ -610,33 +608,23 @@ class TestErrorHandling(unittest.TestCase):
def test_start_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(72))
self._assert_propagates(
lambda: StreamProcessor(client=client, name="p").start(), 72
)
self._assert_propagates(lambda: StreamProcessor(client=client, name="p").start(), 72)
def test_stop_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(125))
self._assert_propagates(
lambda: StreamProcessor(client=client, name="p").stop(), 125
)
self._assert_propagates(lambda: StreamProcessor(client=client, name="p").stop(), 125)
def test_drop_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(1))
self._assert_propagates(
lambda: StreamProcessor(client=client, name="p").drop(), 1
)
self._assert_propagates(lambda: StreamProcessor(client=client, name="p").drop(), 1)
def test_get_info_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(9))
self._assert_propagates(
lambda: StreamProcessors(client).get_info("p"), 9
)
self._assert_propagates(lambda: StreamProcessors(client).get_info("p"), 9)
def test_stats_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(72))
self._assert_propagates(
lambda: StreamProcessor(client=client, name="p").stats(), 72
)
self._assert_propagates(lambda: StreamProcessor(client=client, name="p").stats(), 72)
def test_get_stream_processor_samples_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(9))
@ -653,9 +641,7 @@ class TestErrorHandling(unittest.TestCase):
def test_no_operation_failure_caught_in_module(self) -> None:
"""Structural: verify no try/except hides server errors in either module."""
for mod in (
pymongo.synchronous.stream_processing,
):
for mod in (pymongo.synchronous.stream_processing,):
src = inspect.getsource(mod)
self.assertNotIn("except OperationFailure", src, mod.__name__)
self.assertNotIn("except PyMongoError", src, mod.__name__)