This commit is contained in:
Rohan Pota 2026-05-13 16:25:20 -04:00 committed by GitHub
commit 9ba071f302
12 changed files with 3299 additions and 0 deletions

@ -56,6 +56,7 @@ Sub-modules:
results
server_api
server_description
stream_processing
topology_description
uri_parser
write_concern

View File

@ -0,0 +1,190 @@
:mod:`stream_processing` -- Atlas Stream Processing
====================================================
.. warning::

   The Atlas Stream Processing API is **experimental**. The driver
   specification is in Draft status, and the API surface (including
   retryability behavior, error code mappings, and method signatures)
   may change in a backward-incompatible way before the spec is
   finalized.

Overview
--------
Atlas Stream Processing (ASP) lets you build continuous, stateful pipelines
that process data from one or more sources in real time. A *stream processing
workspace* is a dedicated Atlas endpoint that hosts one or more named stream
processors; it is distinct from a standard MongoDB cluster and is accessed
through :class:`~pymongo.asynchronous.stream_processing.AsyncStreamProcessingClient`
(or its sync twin :class:`~pymongo.synchronous.stream_processing.StreamProcessingClient`)
rather than ``MongoClient``.
Workspace connection strings use the standard ``mongodb://`` scheme — ``mongodb+srv://``
is not supported. TLS is always required for workspace connections and cannot be
disabled. ``authSource`` defaults to ``"admin"`` if not explicitly set. Users must
hold the ``atlasAdmin`` role to execute ASP commands.
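
The connection constraints above can be sketched as a standalone pre-flight check. This is an illustration only, not the driver's implementation: ``check_workspace_uri`` is a hypothetical helper, and it assumes that only the URI scheme and the ``tls``, ``ssl``, and ``authSource`` query options matter.

```python
from urllib.parse import parse_qs, urlsplit


def check_workspace_uri(uri: str) -> dict:
    """Hypothetical helper: validate a workspace URI, return effective options."""
    parts = urlsplit(uri)
    if parts.scheme == "mongodb+srv":
        raise ValueError("mongodb+srv:// is not supported for workspace endpoints")
    if parts.scheme != "mongodb":
        raise ValueError(f"unsupported scheme: {parts.scheme!r}")

    # URI option names are case-insensitive, so normalize the keys.
    query = {k.lower(): v[-1] for k, v in parse_qs(parts.query).items()}
    for key in ("tls", "ssl"):
        if query.get(key, "true").lower() == "false":
            raise ValueError("TLS cannot be disabled for workspace connections")

    # TLS is always on; authSource falls back to "admin" when absent.
    return {"tls": True, "authSource": query.get("authsource", "admin")}


print(check_workspace_uri("mongodb://user:pass@host.example/?authSource=other"))
```

A check like this fails fast on a misconfigured URI before any network connection is attempted, which is the same spirit as the constructor-time validation described above.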
Quickstart
----------
Async
~~~~~
.. code-block:: python

   import asyncio

   from pymongo import AsyncStreamProcessingClient


   async def main():
       uri = (
           "mongodb://user:pass@"
           "atlas-stream-<workspaceId>-<suffix>.<region>.a.query.mongodb.net/"
       )
       async with AsyncStreamProcessingClient(uri) as client:
           sps = client.stream_processors()
           await sps.create("demo", pipeline=[
               {"$source": {"connectionName": "<conn>", "topic": "events"}},
               {"$match": {"value": {"$gt": 100}}},
               {"$emit": {"connectionName": "<conn>", "db": "out", "coll": "high"}},
           ])

           proc = sps.get("demo")
           await proc.start()

           info = await sps.get_info("demo")
           print(info.state, info.pipeline_version)

           async for doc in proc.sample(limit=10):
               print(doc)

           await proc.stop()
           await proc.drop()


   asyncio.run(main())

Sync
~~~~
.. code-block:: python

   from pymongo import StreamProcessingClient

   uri = (
       "mongodb://user:pass@"
       "atlas-stream-<workspaceId>-<suffix>.<region>.a.query.mongodb.net/"
   )
   with StreamProcessingClient(uri) as client:
       sps = client.stream_processors()
       sps.create("demo", pipeline=[
           {"$source": {"connectionName": "<conn>", "topic": "events"}},
           {"$match": {"value": {"$gt": 100}}},
           {"$emit": {"connectionName": "<conn>", "db": "out", "coll": "high"}},
       ])

       proc = sps.get("demo")
       proc.start()

       info = sps.get_info("demo")
       print(info.state, info.pipeline_version)

       for doc in proc.sample(limit=10):
           print(doc)

       proc.stop()
       proc.drop()

Sample cursor semantics
-----------------------
The sample cursor is a custom two-phase protocol distinct from the standard
MongoDB ``getMore`` mechanism. It MUST NOT be confused with standard
:class:`~pymongo.asynchronous.cursor.AsyncCursor` objects.
For most use cases, call :meth:`~pymongo.asynchronous.stream_processing.AsyncStreamProcessor.sample`
to obtain an :class:`~pymongo.asynchronous.stream_processing.AsyncSampleCursor` and iterate it
with ``async for``. The cursor drives the underlying protocol automatically:
- The first iteration sends ``startSampleStreamProcessor``, optionally with a ``limit``.
- Subsequent iterations send ``getMoreSampleStreamProcessor`` with the ``cursorId``
returned by the previous call, optionally with a ``batchSize``.
- When the server returns ``cursorId: 0`` the cursor is exhausted and no further
wire calls are made.
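
The three steps above can be sketched at the command-document level. This is a simplified, synchronous illustration rather than the driver's cursor class: ``iter_samples`` and ``fake_server`` are hypothetical names, and the canned replies stand in for a real workspace. The field names (``limit``, ``cursorId``, ``batchSize``, ``firstBatch``, ``nextBatch``) follow the commands described in this section.

```python
from typing import Any, Callable, Iterator, Mapping, Optional

Command = dict[str, Any]


def iter_samples(
    send: Callable[[Command], Mapping[str, Any]],
    name: str,
    limit: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> Iterator[Mapping[str, Any]]:
    """Yield sampled documents by driving the two-phase protocol by hand."""
    # Phase 1: open the cursor; "limit" is only sent on this first call.
    cmd: Command = {"startSampleStreamProcessor": name}
    if limit is not None:
        cmd["limit"] = limit
    resp = send(cmd)
    yield from resp.get("firstBatch", [])
    cursor_id = int(resp["cursorId"])

    # Phase 2: continue until the server reports cursorId 0.
    while cursor_id != 0:
        cmd = {"getMoreSampleStreamProcessor": name, "cursorId": cursor_id}
        if batch_size is not None:
            cmd["batchSize"] = batch_size
        resp = send(cmd)
        yield from resp.get("nextBatch", [])
        cursor_id = int(resp["cursorId"])


def fake_server() -> tuple[Callable[[Command], Mapping[str, Any]], list[Command]]:
    """Stand-in for a workspace: replays three canned replies."""
    replies = iter([
        {"cursorId": 42, "firstBatch": [{"n": 1}]},
        {"cursorId": 42, "nextBatch": []},  # empty batch, cursor still open
        {"cursorId": 0, "nextBatch": [{"n": 2}, {"n": 3}]},
    ])
    sent: list[Command] = []

    def send(cmd: Command) -> Mapping[str, Any]:
        sent.append(cmd)  # record what would have gone over the wire
        return next(replies)

    return send, sent


send, sent = fake_server()
print(list(iter_samples(send, "demo", limit=10, batch_size=5)))
print(sent)
```

Note that an empty batch with a non-zero ``cursorId`` does not end iteration; only ``cursorId: 0`` does.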
For fine-grained control, such as tracking ``cursorId`` across calls yourself, use
:meth:`~pymongo.asynchronous.stream_processing.AsyncStreamProcessor.get_stream_processor_samples`
directly. Passing options whose ``cursor_id`` is ``0`` raises
:exc:`~pymongo.errors.InvalidOperation` immediately, before any wire call is sent.
Commands not yet supported
--------------------------
The following commands from the ASP server specification are intentionally
deferred and not yet wrapped by this API:
- ``modifyStreamProcessor`` — rename, pipeline replacement, and DLQ reconfiguration
- ``listStreamProcessors`` — enumerate processors in a workspace
- ``listStreamConnections`` — enumerate available connections
- ``processStreamProcessor`` — one-shot ad-hoc pipeline execution
- ``listWorkspaceDefaults`` — fetch workspace tier defaults
Users can still call any of these directly via ``run_command`` on a plain
:class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` connected to the
workspace endpoint.
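
A minimal sketch of that fallback follows. In PyMongo, the counterpart of ``run_command`` is the ``command`` method on a database object, so the helper below accepts any object with an awaitable ``command`` method, such as ``client.admin`` on an ``AsyncMongoClient``. The helper name ``run_raw_asp_command``, the ``FakeAdminDatabase`` stand-in, and the ``{"listStreamProcessors": 1}`` document shape are assumptions made for illustration; consult the server specification for the real argument shape of each command.

```python
import asyncio
from typing import Any, Mapping


async def run_raw_asp_command(admin_db: Any, command: Mapping[str, Any]) -> Mapping[str, Any]:
    """Send a command document unchanged and return the raw server reply."""
    if not command:
        raise ValueError("command document must not be empty")
    # The first key names the command, so insertion order must be preserved.
    return await admin_db.command(dict(command))


class FakeAdminDatabase:
    """Stand-in for ``client.admin`` so the sketch runs without a workspace."""

    async def command(self, cmd: Mapping[str, Any]) -> Mapping[str, Any]:
        # Echo the command name back instead of contacting a server.
        return {"ok": 1, "echo": next(iter(cmd))}


reply = asyncio.run(
    run_raw_asp_command(FakeAdminDatabase(), {"listStreamProcessors": 1})
)
print(reply)
```

Against a real workspace, ``admin_db`` would be the ``admin`` database of a client created with the same URI, credentials, and TLS settings that the stream processing client uses.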
Async classes
-------------
.. autoclass:: pymongo.asynchronous.stream_processing.AsyncStreamProcessingClient
:members:
:show-inheritance:
.. autoclass:: pymongo.asynchronous.stream_processing.AsyncStreamProcessors
:members:
:show-inheritance:
.. autoclass:: pymongo.asynchronous.stream_processing.AsyncStreamProcessor
:members:
:show-inheritance:
.. autoclass:: pymongo.asynchronous.stream_processing.AsyncSampleCursor
:members:
:show-inheritance:
Sync classes
------------
.. autoclass:: pymongo.synchronous.stream_processing.StreamProcessingClient
:members:
:show-inheritance:
.. autoclass:: pymongo.synchronous.stream_processing.StreamProcessors
:members:
:show-inheritance:
.. autoclass:: pymongo.synchronous.stream_processing.StreamProcessor
:members:
:show-inheritance:
.. autoclass:: pymongo.synchronous.stream_processing.SampleCursor
:members:
:show-inheritance:
Options and result types
------------------------
.. autoclass:: pymongo.stream_processing_options.CreateStreamProcessorOptions
:members:
:show-inheritance:
.. autoclass:: pymongo.stream_processing_options.StartStreamProcessorOptions
:members:
:show-inheritance:
.. autoclass:: pymongo.stream_processing_options.GetStreamProcessorStatsOptions
:members:
:show-inheritance:
.. autoclass:: pymongo.stream_processing_options.GetStreamProcessorSamplesOptions
:members:
:show-inheritance:
.. autoclass:: pymongo.stream_processing_options.GetStreamProcessorSamplesResult
:members:
:show-inheritance:
.. autoclass:: pymongo.stream_processing_options.StreamProcessorInfo
:members:
:show-inheritance:

@ -1,6 +1,20 @@
Changelog
=========
Upcoming
--------
- Added experimental support for Atlas Stream Processing (ASP).
:class:`~pymongo.asynchronous.stream_processing.AsyncStreamProcessingClient`
and :class:`~pymongo.synchronous.stream_processing.StreamProcessingClient`
enable native client-side management of stream processors in an ASP workspace,
including ``createStreamProcessor``, ``startStreamProcessor``,
``stopStreamProcessor``, ``dropStreamProcessor``, ``getStreamProcessor``,
``getStreamProcessorStats``, and the two-phase sample cursor protocol
(``startSampleStreamProcessor`` / ``getMoreSampleStreamProcessor``). See the
:mod:`stream_processing` API docs for details. The API is experimental and
may change before the driver specification is finalized.
Changes in Version 4.17.0 (2026/04/20)
--------------------------------------

153
poc_asp.py Normal file
View File

@ -0,0 +1,153 @@
"""
POC: Atlas Stream Processing create / start / stop / drop a stream processor.
Fill in the FILL_ME values below before running:
python3 poc_asp.py
Pipeline used:
$source sample_stream_solar
$emit __testLog
"""
from __future__ import annotations
import asyncio
import pprint
from pymongo import AsyncStreamProcessingClient
from pymongo.errors import OperationFailure
# ---------------------------------------------------------------------------
# Configuration — fill these in before running
# ---------------------------------------------------------------------------
# Workspace connection string from Atlas UI:
# Stream Processing → your workspace → Connect
# Format: mongodb://<host>/ (no mongodb+srv://)
WORKSPACE_URI = "FILL_ME"  # e.g. "mongodb://atlas-stream-<id>.<region>.a.query.mongodb.net/"
# Atlas DB user credentials (must have atlasAdmin role on the workspace project)
USERNAME = "FILL_ME"
PASSWORD = "FILL_ME"
# ---------------------------------------------------------------------------
# Pipeline — hardcoded per your setup
# ---------------------------------------------------------------------------
PROCESSOR_NAME = "simpletestSP"
PIPELINE = [
{
"$source": {
"connectionName": "sample_stream_solar",
}
},
{
"$emit": {
"connectionName": "__testLog",
}
},
]
# ---------------------------------------------------------------------------
# POC steps
# ---------------------------------------------------------------------------
async def main() -> None:
if "FILL_ME" in (WORKSPACE_URI, USERNAME, PASSWORD):
raise SystemExit("Fill in WORKSPACE_URI, USERNAME, and PASSWORD at the top of this file.")
async with AsyncStreamProcessingClient(WORKSPACE_URI, username=USERNAME, password=PASSWORD) as client:
sps = client.stream_processors()
# ------------------------------------------------------------------
# 1. Create
# ------------------------------------------------------------------
print(f"\n[1] Creating processor '{PROCESSOR_NAME}' ...")
try:
await sps.create(PROCESSOR_NAME, pipeline=PIPELINE)
print(" Created OK")
except OperationFailure as e:
raise SystemExit(f" Create failed (code {e.code}): {e}") from e
# ------------------------------------------------------------------
# 2. Inspect before starting
# ------------------------------------------------------------------
print("\n[2] Getting info ...")
info = await sps.get_info(PROCESSOR_NAME)
print(f" state : {info.state}")
print(f" pipeline_version : {info.pipeline_version}")
print(f" has_started : {info.has_started}")
# ------------------------------------------------------------------
# 3. Start
# ------------------------------------------------------------------
proc = sps.get(PROCESSOR_NAME)
print("\n[3] Starting processor ...")
try:
await proc.start()
print(" Start command sent OK")
except OperationFailure as e:
raise SystemExit(f" Start failed (code {e.code}): {e}") from e
# Give the server a moment to transition state
await asyncio.sleep(2)
info = await sps.get_info(PROCESSOR_NAME)
print(f" state after start: {info.state}")
# ------------------------------------------------------------------
# 4. Stats
# ------------------------------------------------------------------
print("\n[4] Fetching stats ...")
try:
raw_stats = await proc.stats()
pprint.pprint(raw_stats)
except OperationFailure as e:
print(f" Stats unavailable (code {e.code}): {e}")
# ------------------------------------------------------------------
# 5. Sample (up to 5 docs)
# Note: breaking manually after N docs because the dev server does not
# signal cursor exhaustion with cursorId=0 as the spec requires.
# ------------------------------------------------------------------
print("\n[5] Sampling up to 5 documents ...")
try:
count = 0
async for doc in proc.sample():
print(f" doc: {doc}")
count += 1
if count >= 5:
break
print(f" Sampled {count} document(s)")
except OperationFailure as e:
print(f" Sample unavailable (code {e.code}): {e}")
# ------------------------------------------------------------------
# 6. Stop
# ------------------------------------------------------------------
print("\n[6] Stopping processor ...")
try:
await proc.stop()
print(" Stop command sent OK")
except OperationFailure as e:
raise SystemExit(f" Stop failed (code {e.code}): {e}") from e
await asyncio.sleep(1)
info = await sps.get_info(PROCESSOR_NAME)
print(f" state after stop : {info.state}")
# ------------------------------------------------------------------
# 7. Drop (permanent — comment out to keep the processor alive)
# ------------------------------------------------------------------
print("\n[7] Dropping processor ...")
try:
await proc.drop()
print(" Dropped OK")
except OperationFailure as e:
raise SystemExit(f" Drop failed (code {e.code}): {e}") from e
print("\nDone.")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Python driver for MongoDB."""
from __future__ import annotations
from typing import ContextManager, Optional
@ -45,6 +46,21 @@ __all__ = [
"WriteConcern",
"has_c",
"timeout",
# Atlas Stream Processing (experimental)
"AsyncSampleCursor",
"AsyncStreamProcessingClient",
"AsyncStreamProcessor",
"AsyncStreamProcessors",
"SampleCursor",
"StreamProcessingClient",
"StreamProcessor",
"StreamProcessors",
"CreateStreamProcessorOptions",
"GetStreamProcessorSamplesOptions",
"GetStreamProcessorSamplesResult",
"GetStreamProcessorStatsOptions",
"StartStreamProcessorOptions",
"StreamProcessorInfo",
]
ASCENDING = 1
@ -89,6 +105,14 @@ TEXT = "text"
from pymongo import _csot
from pymongo._version import __version__, get_version_string, version_tuple
from pymongo.asynchronous.mongo_client import AsyncMongoClient
# Atlas Stream Processing (experimental)
from pymongo.asynchronous.stream_processing import (
AsyncSampleCursor,
AsyncStreamProcessingClient,
AsyncStreamProcessor,
AsyncStreamProcessors,
)
from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION, has_c
from pymongo.cursor import CursorType
from pymongo.operations import (
@ -101,8 +125,22 @@ from pymongo.operations import (
UpdateOne,
)
from pymongo.read_preferences import ReadPreference
from pymongo.stream_processing_options import (
CreateStreamProcessorOptions,
GetStreamProcessorSamplesOptions,
GetStreamProcessorSamplesResult,
GetStreamProcessorStatsOptions,
StartStreamProcessorOptions,
StreamProcessorInfo,
)
from pymongo.synchronous.collection import ReturnDocument
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.stream_processing import (
SampleCursor,
StreamProcessingClient,
StreamProcessor,
StreamProcessors,
)
from pymongo.write_concern import WriteConcern
# Public module compatibility imports

View File

@ -0,0 +1,633 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Atlas Stream Processing client for stream processing workspaces.
A stream processing workspace endpoint uses the ``mongodb://`` URI scheme with
a hostname that follows this pattern::
atlas-stream-<workspaceId>-<suffix>.<region>.a.query.mongodb.net
For example::
mongodb://atlas-stream-699c842ef433fe6001480b17-etif1.virginia-usa.a.query.mongodb.net/
TLS is always required for workspace connections and cannot be disabled.
``authSource=admin`` is applied by default when not explicitly set.
For commands not yet wrapped by this API, users can connect via a plain
:class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` and call
``run_command`` directly; that path remains fully supported.
Error handling
~~~~~~~~~~~~~~
ASP commands raise :class:`pymongo.errors.OperationFailure` on server-side
errors. The following error codes are known to be returned by Atlas Stream
Processing commands:
================  =================  ====================================
Code              Name               When returned
================  =================  ====================================
9                 FailedToParse      Invalid pipeline or command document
72                InvalidOptions     Invalid option values
125               CommandFailed      General command execution failure
1                 InternalError      Unexpected server-side error
================  =================  ====================================
This list is **non-exhaustive** and may grow as the server evolves. Drivers
do not maintain a closed list of valid codes; applications should branch
on ``exc.code`` only when they need to react to a specific known code, and
should always have a generic fallback for unknown codes.
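
The branching advice above can be sketched as follows. This is an illustration, not part of the driver: ``describe_failure``, ``KNOWN_CODES``, and ``FakeOperationFailure`` are hypothetical names, and the stand-in exception only mimics the ``code`` attribute that :class:`pymongo.errors.OperationFailure` exposes.

```python
from typing import Optional

# Codes taken from the table above; the set is deliberately open-ended.
KNOWN_CODES = {
    1: "InternalError",
    9: "FailedToParse",
    72: "InvalidOptions",
    125: "CommandFailed",
}


def describe_failure(exc: Exception) -> str:
    # React only to the codes this application cares about.
    code: Optional[int] = getattr(exc, "code", None)
    if code == 9:
        return "fix the pipeline or command document and resubmit"
    if code == 72:
        return "check the option values passed to the command"
    # Generic fallback: covers known-but-unhandled and unknown codes alike.
    name = KNOWN_CODES.get(code, "unknown")
    return f"unhandled server error (code {code}, {name}): {exc}"


class FakeOperationFailure(Exception):
    # Stand-in that carries a ``code`` attribute like OperationFailure does.
    def __init__(self, message: str, code: int) -> None:
        super().__init__(message)
        self.code = code


print(describe_failure(FakeOperationFailure("bad stage", 9)))
print(describe_failure(FakeOperationFailure("boom", 999)))
```

Keeping the fallback branch last means a new server code never escapes handling, even though no closed list of codes exists.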
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.errors import ConfigurationError, InvalidOperation
from pymongo.operations import _Op
from pymongo.stream_processing_options import (
CreateStreamProcessorOptions,
GetStreamProcessorSamplesOptions,
GetStreamProcessorSamplesResult,
GetStreamProcessorStatsOptions,
StartStreamProcessorOptions,
StreamProcessorInfo,
)
from pymongo.uri_parser_shared import SRV_SCHEME, _validate_uri
if TYPE_CHECKING:
from types import TracebackType
from pymongo.asynchronous.client_session import AsyncClientSession
_IS_SYNC = False
def _is_workspace_endpoint(host: str) -> bool:
"""Return True if *host* looks like an ASP workspace endpoint.
Workspace hostnames begin with ``atlas-stream-`` or end with
``.a.query.mongodb.net``.
"""
return host.startswith("atlas-stream-") or host.endswith(".a.query.mongodb.net")
class AsyncStreamProcessingClient:
"""A client connected to an Atlas Stream Processing workspace.
Wraps :class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` with
Atlas Stream Processing constraints:
* Only the ``mongodb://`` URI scheme is accepted (``mongodb+srv://`` is
not supported for workspace endpoints).
* TLS is always enabled and cannot be disabled.
* ``authSource`` defaults to ``"admin"`` if not explicitly set.
Usage::
async with AsyncStreamProcessingClient(
"mongodb://atlas-stream-<id>.<region>.a.query.mongodb.net/",
username="user",
password="pass",
) as client:
sps = client.stream_processors()
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
host: Any = args[0] if args else kwargs.get("host")
uris: list[str] = []
if isinstance(host, str):
uris = [host]
elif isinstance(host, (list, tuple)):
uris = [h for h in host if isinstance(h, str)]
uri_has_auth_source = False
for uri_str in uris:
if not uri_str.startswith(("mongodb://", SRV_SCHEME)):
# Plain hostname — no URI scheme to check.
continue
if uri_str.startswith(SRV_SCHEME):
raise ConfigurationError(
"StreamProcessingClient does not support mongodb+srv:// URIs; "
"use mongodb:// with a workspace endpoint."
)
parsed = _validate_uri(uri_str, validate=True, warn=False, normalize=True)
uri_opts = parsed["options"]
if uri_opts.get("tls") is False:
raise ConfigurationError(
"TLS cannot be disabled for stream processing workspace connections."
)
if uri_opts.get("authsource") is not None:
uri_has_auth_source = True
# Also reject explicit tls=False / ssl=False in kwargs.
if kwargs.get("tls") is False or kwargs.get("ssl") is False:
raise ConfigurationError(
"TLS cannot be disabled for stream processing workspace connections."
)
kwargs.pop("ssl", None)
kwargs["tls"] = True
if not uri_has_auth_source and not any(k.lower() == "authsource" for k in kwargs):
kwargs["authSource"] = "admin"
self._client: AsyncMongoClient[Any] = AsyncMongoClient(*args, **kwargs)
# NOTE: Per the ASP driver spec, server errors MUST be surfaced as-is.
# Do NOT introduce error-code branching, rewrapping, or filtering anywhere
# in this module. Known codes are documented at the module level for
# reference only — they are not runtime invariants.
async def _command(
self,
cmd: dict[str, Any],
*,
retryable_read: bool = False,
session: Optional[AsyncClientSession] = None,
) -> Mapping[str, Any]:
"""Send a top-level ASP command to the admin database.
Routes through the existing retry machinery: retryable reads use
:meth:`~pymongo.asynchronous.database.AsyncDatabase._retryable_read_command`;
everything else uses the standard
:meth:`~pymongo.asynchronous.database.AsyncDatabase.command` path.
:param cmd: The command document.
:param retryable_read: If ``True``, the command is sent as a retryable read.
:param session: A
:class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
for this operation.
"""
admin = self._client._database_default_options("admin")
if retryable_read:
# The first key of the command document is the operation name,
# which matches the corresponding _Op enum value string.
operation = next(iter(cmd))
return await admin._retryable_read_command(cmd, session=session, operation=operation)
return await admin.command(cmd, session=session)
def stream_processors(self) -> AsyncStreamProcessors:
"""Return a handle for managing stream processors in this workspace."""
return AsyncStreamProcessors(self)
async def close(self) -> None:
"""Close the underlying client and release all resources."""
await self._client.close()
async def __aenter__(self) -> AsyncStreamProcessingClient:
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.close()
@property
def address(self) -> Any:
"""(host, port) of the current workspace endpoint, or None.
Delegates to the underlying
:attr:`~pymongo.asynchronous.mongo_client.AsyncMongoClient.address`.
"""
return self._client.address
class AsyncStreamProcessors:
"""Handle for managing stream processors in a workspace.
Obtained via :meth:`AsyncStreamProcessingClient.stream_processors`.
"""
def __init__(self, client: AsyncStreamProcessingClient) -> None:
self._client = client
async def create(
self,
name: str,
pipeline: list[Mapping[str, Any]],
options: Optional[CreateStreamProcessorOptions] = None,
*,
session: Optional[AsyncClientSession] = None,
) -> None:
"""Create a new stream processor in the workspace.
Sends the ``createStreamProcessor`` command to the ``admin`` database.
This operation is **not** retryable.
:param name: The name of the stream processor to create.
:param pipeline: The aggregation pipeline that defines the processor.
:param options: Optional :class:`~pymongo.stream_processing_options.CreateStreamProcessorOptions`
controlling DLQ, tier, and other settings.
:param session: A
:class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
for this operation.
"""
if not name or not name.strip():
raise InvalidOperation("Stream processor name must be a non-empty string.")
if not pipeline:
raise InvalidOperation("createStreamProcessor requires a non-empty pipeline.")
cmd: dict[str, Any] = {
_Op.CREATE_STREAM_PROCESSOR: name,
"pipeline": list(pipeline),
}
if options is not None:
opts: dict[str, Any] = {}
if options.dlq is not None:
opts["dlq"] = options.dlq
if options.stream_meta_field_name is not None:
opts["streamMetaFieldName"] = options.stream_meta_field_name
if options.tier is not None:
opts["tier"] = options.tier
if options.failover is not None:
opts["failover"] = options.failover
if opts:
cmd["options"] = opts
await self._client._command(cmd, session=session)
def get(self, name: str) -> AsyncStreamProcessor:
"""Return a handle for an existing stream processor by name.
This is a pure factory method: it does **not** contact the server
or verify that the processor exists.
:param name: The name of the stream processor.
:returns: An :class:`AsyncStreamProcessor` handle.
"""
if not name or not name.strip():
raise InvalidOperation("Stream processor name must be a non-empty string.")
return AsyncStreamProcessor(client=self._client, name=name)
async def get_info(
self,
name: str,
*,
session: Optional[AsyncClientSession] = None,
) -> StreamProcessorInfo:
"""Return information about a single stream processor.
Sends the ``getStreamProcessor`` command to the ``admin`` database.
This operation is a **retryable read**.
:param name: The name of the stream processor.
:param session: A
:class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
for this operation.
:returns: A :class:`~pymongo.stream_processing_options.StreamProcessorInfo`
populated from the server response. Unknown server fields are preserved
in :attr:`~pymongo.stream_processing_options.StreamProcessorInfo.raw`.
"""
if not name or not name.strip():
raise InvalidOperation("Stream processor name must be a non-empty string.")
cmd: dict[str, Any] = {_Op.GET_STREAM_PROCESSOR: name}
response = await self._client._command(cmd, retryable_read=True, session=session)
doc = response.get("result", response)
return StreamProcessorInfo.from_response(doc)
class AsyncStreamProcessor:
"""Handle for a specific named stream processor.
Does not imply the processor currently exists on the server.
Obtain via :meth:`AsyncStreamProcessors.get`.
"""
def __init__(self, *, client: AsyncStreamProcessingClient, name: str) -> None:
if not name or not name.strip():
raise InvalidOperation("Stream processor name must be a non-empty string.")
self._client = client
self.name = name
async def start(
self,
options: Optional[StartStreamProcessorOptions] = None,
*,
session: Optional[AsyncClientSession] = None,
) -> None:
"""Start this stream processor.
Sends the ``startStreamProcessor`` command to the ``admin`` database.
This operation is **not** retryable.
The processor must be in the ``STOPPED`` or ``FAILED`` state; starting
an already-``STARTED`` processor returns a server error.
Mutual exclusivity of ``start_after`` / ``start_at_operation_time`` and
tier validation are enforced by
:class:`~pymongo.stream_processing_options.StartStreamProcessorOptions`
at construction time.
:param options: Optional
:class:`~pymongo.stream_processing_options.StartStreamProcessorOptions`.
:param session: A
:class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
for this operation.
"""
cmd: dict[str, Any] = {_Op.START_STREAM_PROCESSOR: self.name}
if options is not None:
if options.workers is not None:
cmd["workers"] = options.workers
opts: dict[str, Any] = {}
if options.clear_checkpoints is not None:
opts["clearCheckpoints"] = options.clear_checkpoints
if options.start_at_operation_time is not None:
opts["startAtOperationTime"] = options.start_at_operation_time
if options.start_after is not None:
opts["startAfter"] = options.start_after
if options.tier is not None:
opts["tier"] = options.tier
if options.enable_auto_scaling is not None:
opts["enableAutoScaling"] = options.enable_auto_scaling
if options.failover is not None:
opts["failover"] = options.failover
if opts:
cmd["options"] = opts
await self._client._command(cmd, session=session)
async def stop(
self,
*,
session: Optional[AsyncClientSession] = None,
) -> None:
"""Stop this stream processor.
Sends the ``stopStreamProcessor`` command to the ``admin`` database.
This operation is **not** retryable. The processor transitions to the
``STOPPED`` state and can be restarted.
:param session: A
:class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
for this operation.
"""
await self._client._command({_Op.STOP_STREAM_PROCESSOR: self.name}, session=session)
async def drop(
self,
*,
session: Optional[AsyncClientSession] = None,
) -> None:
"""Permanently delete this stream processor.
Sends the ``dropStreamProcessor`` command to the ``admin`` database.
This operation is **not** retryable. A dropped processor cannot be
recovered.
:param session: A
:class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
for this operation.
"""
await self._client._command({_Op.DROP_STREAM_PROCESSOR: self.name}, session=session)
async def stats(
self,
options: Optional[GetStreamProcessorStatsOptions] = None,
*,
session: Optional[AsyncClientSession] = None,
) -> Mapping[str, Any]:
"""Return runtime statistics for this stream processor.
Sends the ``getStreamProcessorStats`` command to the ``admin`` database.
This operation is a **retryable read**. The server returns an error if
the processor is not in the ``STARTED`` state.
Unknown fields in the response are preserved: the raw response dict
is returned unchanged, so callers are not affected by server additions.
:param options: Optional
:class:`~pymongo.stream_processing_options.GetStreamProcessorStatsOptions`
controlling scale units and verbosity.
:param session: A
:class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
for this operation.
:returns: The raw server response document.
"""
cmd: dict[str, Any] = {_Op.GET_STREAM_PROCESSOR_STATS: self.name}
if options is not None:
opts: dict[str, Any] = {}
if options.scale is not None:
opts["scale"] = options.scale
if options.verbose is not None:
opts["verbose"] = options.verbose
if opts:
cmd["options"] = opts
return await self._client._command(cmd, retryable_read=True, session=session)
async def get_stream_processor_samples(
self,
options: Optional[GetStreamProcessorSamplesOptions] = None,
*,
session: Optional[AsyncClientSession] = None,
) -> GetStreamProcessorSamplesResult:
"""Fetch one batch of sampled documents from a running stream processor.
Spec-literal entry point. Inspects ``options.cursor_id`` and routes to
``startSampleStreamProcessor`` (initial call) or
``getMoreSampleStreamProcessor`` (continuation). Most users should
prefer :meth:`sample`, which wraps this in an async iterator.
A returned ``cursor_id`` of ``0`` means the cursor is exhausted; callers
MUST NOT call this method again with that cursor id.
Sends to the ``admin`` database. Non-retryable.
:param options: Optional
:class:`~pymongo.stream_processing_options.GetStreamProcessorSamplesOptions`.
:param session: A
:class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
for this operation.
:returns: A :class:`~pymongo.stream_processing_options.GetStreamProcessorSamplesResult`
containing the batch of documents and the cursor id for the next call.
"""
if options is None:
options = GetStreamProcessorSamplesOptions()
if options.cursor_id == 0:
raise InvalidOperation("Sample cursor is exhausted; cursor_id 0 cannot be continued.")
if options.cursor_id is None:
cmd: dict[str, Any] = {_Op.START_SAMPLE_STREAM_PROCESSOR: self.name}
if options.limit is not None:
cmd["limit"] = options.limit
resp = await self._client._command(cmd, session=session)
return GetStreamProcessorSamplesResult(
cursor_id=int(resp["cursorId"]),
documents=list(resp.get("firstBatch", [])),
)
else:
cmd = {
_Op.GET_MORE_SAMPLE_STREAM_PROCESSOR: self.name,
"cursorId": options.cursor_id,
}
if options.batch_size is not None:
cmd["batchSize"] = options.batch_size
resp = await self._client._command(cmd, session=session)
return GetStreamProcessorSamplesResult(
cursor_id=int(resp["cursorId"]),
documents=list(resp.get("nextBatch", resp.get("messages", []))),
)
def sample(
self,
limit: Optional[int] = None,
batch_size: Optional[int] = None,
*,
session: Optional[AsyncClientSession] = None,
) -> AsyncSampleCursor:
"""Open a sample cursor over this stream processor's output.
Returns an async iterator that yields sampled documents until the
server-side cursor is exhausted. Internally drives the two-phase
``startSampleStreamProcessor`` / ``getMoreSampleStreamProcessor``
protocol on the caller's behalf.
Usage::
async for doc in processor.sample(limit=100, batch_size=10):
print(doc)
:param limit: Maximum number of documents to sample (sent only on the
initial call).
:param batch_size: Number of documents per continuation batch (sent
only on subsequent calls).
:param session: Optional :class:`AsyncClientSession` propagated to all
underlying commands.
"""
return AsyncSampleCursor(
processor=self,
limit=limit,
batch_size=batch_size,
session=session,
)
class AsyncSampleCursor:
"""Async iterator over sampled stream processor output.
A custom two-phase cursor used to retrieve sampled documents from a
running stream processor. This cursor MUST NOT be wrapped or re-used
via the standard MongoDB ``Cursor`` types because it does not use the
standard ``getMore`` command; it uses the dedicated
``startSampleStreamProcessor`` / ``getMoreSampleStreamProcessor``
commands instead.
Obtained via :meth:`AsyncStreamProcessor.sample`. Iterate with
``async for``.
The cursor is exhausted when the server returns ``cursorId: 0``;
after that, no further wire calls are issued and iteration ends.
"""
def __init__(
self,
*,
processor: AsyncStreamProcessor,
limit: Optional[int] = None,
batch_size: Optional[int] = None,
session: Optional[AsyncClientSession] = None,
) -> None:
self._processor = processor
self._limit = limit
self._batch_size = batch_size
self._session = session
self._buffer: list[Mapping[str, Any]] = []
self._cursor_id: Optional[int] = None # None = not yet opened
self._exhausted: bool = False
self._closed: bool = False
@property
def cursor_id(self) -> Optional[int]:
"""Current server-side cursor id, or ``None`` if not yet opened.
A value of ``0`` indicates the cursor has been exhausted.
"""
return self._cursor_id
@property
def alive(self) -> bool:
"""``True`` if more documents may be available; ``False`` once exhausted or closed."""
return not self._exhausted and not self._closed
async def _refill(self) -> None:
"""Fetch the next batch from the server. No-op if exhausted or closed."""
if self._exhausted or self._closed:
return
if self._cursor_id is None:
opts = GetStreamProcessorSamplesOptions(
cursor_id=None,
limit=self._limit,
batch_size=None,
)
else:
opts = GetStreamProcessorSamplesOptions(
cursor_id=self._cursor_id,
limit=None,
batch_size=self._batch_size,
)
result = await self._processor.get_stream_processor_samples(opts, session=self._session)
self._cursor_id = result.cursor_id
self._buffer.extend(result.documents)
# Spec: cursorId == 0 means exhausted. MUST NOT call getMore again.
if result.cursor_id == 0:
self._exhausted = True
def __aiter__(self) -> AsyncSampleCursor:
return self
async def __anext__(self) -> Mapping[str, Any]:
if self._buffer:
return self._buffer.pop(0)
if self._closed or self._exhausted:
raise StopAsyncIteration
# Loop guards against an empty batch from the server with a non-zero
# cursor id — keep pulling until we get documents or hit exhaustion.
while not self._buffer and not self._exhausted:
await self._refill()
if self._buffer:
return self._buffer.pop(0)
raise StopAsyncIteration
async def close(self) -> None:
"""Mark the cursor closed locally.
Note: ASP does not currently expose a way to explicitly kill a
sample cursor server-side. ``close`` only stops local iteration;
the server-side cursor will be cleaned up on its own timeout or
when the processor stops.
"""
self._closed = True
async def __aenter__(self) -> AsyncSampleCursor:
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.close()
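The two-phase sample protocol that `AsyncSampleCursor` drives can be sketched in isolation. This is a minimal illustration under stated assumptions, not the driver's implementation: `fake_command`, `_REPLIES`, and `drain` are hypothetical names, and the canned replies stand in for a real workspace.

```python
from __future__ import annotations

import asyncio
from typing import Any, Optional

# Hypothetical canned replies standing in for a workspace; not real ASP output.
_REPLIES: list[dict[str, Any]] = [
    {"cursorId": 42, "firstBatch": [{"n": 1}, {"n": 2}]},
    {"cursorId": 42, "nextBatch": []},  # empty batch, cursor still open
    {"cursorId": 0, "nextBatch": [{"n": 3}]},  # cursorId 0: exhausted
]


async def fake_command(cmd: dict[str, Any], log: list[dict[str, Any]]) -> dict[str, Any]:
    """Record the command and replay the next canned reply (assumed stub)."""
    log.append(cmd)
    return _REPLIES[len(log) - 1]


async def drain(
    name: str, limit: Optional[int] = None, batch_size: Optional[int] = None
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Collect every sampled document, returning (documents, commands sent)."""
    log: list[dict[str, Any]] = []
    docs: list[dict[str, Any]] = []
    cursor_id: Optional[int] = None  # None means the cursor is not yet opened
    while cursor_id != 0:
        if cursor_id is None:
            # Phase 1: open the cursor; only this call carries ``limit``.
            cmd: dict[str, Any] = {"startSampleStreamProcessor": name}
            if limit is not None:
                cmd["limit"] = limit
            batch_key = "firstBatch"
        else:
            # Phase 2: continue; only these calls carry ``batchSize``.
            cmd = {"getMoreSampleStreamProcessor": name, "cursorId": cursor_id}
            if batch_size is not None:
                cmd["batchSize"] = batch_size
            batch_key = "nextBatch"
        resp = await fake_command(cmd, log)
        cursor_id = int(resp["cursorId"])
        docs.extend(resp.get(batch_key, []))
    return docs, log
```

Calling `asyncio.run(drain("demo", limit=3, batch_size=2))` issues one start command and two continuation commands, then stops once the reply carries `cursorId` 0. The empty middle batch shows why the cursor keeps pulling while the cursor id is non-zero.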


@ -59,12 +59,17 @@ class _Op(str, enum.Enum):
CREATE = "create"
CREATE_INDEXES = "createIndexes"
CREATE_SEARCH_INDEXES = "createSearchIndexes"
CREATE_STREAM_PROCESSOR = "createStreamProcessor"
DELETE = "delete"
DISTINCT = "distinct"
DROP = "drop"
DROP_DATABASE = "dropDatabase"
DROP_INDEXES = "dropIndexes"
DROP_SEARCH_INDEXES = "dropSearchIndexes"
DROP_STREAM_PROCESSOR = "dropStreamProcessor"
GET_STREAM_PROCESSOR = "getStreamProcessor"
GET_STREAM_PROCESSOR_STATS = "getStreamProcessorStats"
GET_MORE_SAMPLE_STREAM_PROCESSOR = "getMoreSampleStreamProcessor"
END_SESSIONS = "endSessions"
FIND_AND_MODIFY = "findAndModify"
FIND = "find"
@ -77,6 +82,9 @@ class _Op(str, enum.Enum):
UPDATE_INDEX = "updateIndex"
UPDATE_SEARCH_INDEX = "updateSearchIndex"
RENAME = "rename"
START_STREAM_PROCESSOR = "startStreamProcessor"
START_SAMPLE_STREAM_PROCESSOR = "startSampleStreamProcessor"
STOP_STREAM_PROCESSOR = "stopStreamProcessor"
GETMORE = "getMore"
KILL_CURSORS = "killCursors"
TEST = "testOperation"
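Because `_Op` mixes in `str`, its members compare equal to their plain string values, which is what lets a command document be keyed by an enum member while its first key still serves as the operation name. The sketch below illustrates that property with a hypothetical two-member stand-in (`Op`) and a hypothetical helper (`operation_name`); neither is part of the driver.

```python
import enum
from typing import Any


class Op(str, enum.Enum):
    """Hypothetical two-member stand-in for the driver's private ``_Op`` enum."""

    GET_STREAM_PROCESSOR = "getStreamProcessor"
    STOP_STREAM_PROCESSOR = "stopStreamProcessor"


def operation_name(cmd: dict[Any, Any]) -> str:
    """Return the command name, i.e. the first key, as a plain string."""
    first_key = next(iter(cmd))
    # .value yields the underlying string when the key is an enum member.
    return first_key.value if isinstance(first_key, enum.Enum) else str(first_key)


# A command document keyed by an enum member, as the ASP methods build them.
cmd = {Op.GET_STREAM_PROCESSOR: "demo"}
```

Here `operation_name(cmd)` returns the string `"getStreamProcessor"`, and `Op.GET_STREAM_PROCESSOR == "getStreamProcessor"` holds because of the `str` mixin.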


@ -0,0 +1,193 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Options and result classes for Atlas Stream Processing commands.
These classes are shared between the synchronous and asynchronous APIs;
no async/sync split is needed for plain dataclasses.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Any, Mapping, Optional
if TYPE_CHECKING:
from bson import Timestamp
from pymongo.errors import InvalidOperation
_VALID_TIERS = {"SP2", "SP5", "SP10", "SP30", "SP50"}
@dataclass
class CreateStreamProcessorOptions:
"""Options for :meth:`AsyncStreamProcessors.create`.
All fields are optional.
"""
dlq: Optional[Mapping[str, Any]] = None
stream_meta_field_name: Optional[str] = None
tier: Optional[str] = None
failover: Optional[bool] = None
@dataclass
class StartStreamProcessorOptions:
"""Options for :meth:`AsyncStreamProcessor.start`.
``start_after`` and ``start_at_operation_time`` are mutually exclusive.
``tier``, when provided, must be one of ``"SP2"``, ``"SP5"``, ``"SP10"``,
``"SP30"``, or ``"SP50"``.
"""
workers: Optional[int] = None
clear_checkpoints: Optional[bool] = None
start_at_operation_time: Optional[Timestamp] = None
start_after: Optional[Mapping[str, Any]] = None
tier: Optional[str] = None
enable_auto_scaling: Optional[bool] = None
failover: Optional[Mapping[str, Any]] = None
def __post_init__(self) -> None:
if self.start_after is not None and self.start_at_operation_time is not None:
raise InvalidOperation(
"start_after and start_at_operation_time are mutually exclusive."
)
if self.tier is not None and self.tier not in _VALID_TIERS:
raise InvalidOperation(
f"Invalid tier {self.tier!r}. Must be one of: {sorted(_VALID_TIERS)}."
)
if self.workers is not None and self.workers <= 0:
raise InvalidOperation("workers must be a positive integer.")
@dataclass
class GetStreamProcessorStatsOptions:
"""Options for :meth:`AsyncStreamProcessor.stats`.
:param scale: Size unit for byte-valued fields. ``1`` = bytes (default),
``1024`` = kibibytes.
:param verbose: If ``True``, include per-operator statistics.
"""
scale: Optional[int] = None
verbose: Optional[bool] = None
def __post_init__(self) -> None:
if self.scale is not None and self.scale <= 0:
raise InvalidOperation("scale must be a positive integer.")
@dataclass
class GetStreamProcessorSamplesOptions:
"""Options for :meth:`AsyncStreamProcessor.get_stream_processor_samples`.
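When ``cursor_id`` is ``None`` a new sample cursor is opened via
``startSampleStreamProcessor``; ``limit`` is only sent on that initial
call. When ``cursor_id`` is non-zero the next batch is fetched via
``getMoreSampleStreamProcessor``; ``batch_size`` is only sent on
subsequent calls. A ``cursor_id`` of ``0`` denotes an exhausted cursor and
is rejected by ``get_stream_processor_samples`` with
:class:`~pymongo.errors.InvalidOperation`.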
"""
cursor_id: Optional[int] = None
limit: Optional[int] = None
batch_size: Optional[int] = None
def __post_init__(self) -> None:
if self.cursor_id is not None and self.cursor_id < 0:
raise InvalidOperation("cursor_id must be non-negative (0 means exhausted).")
if self.limit is not None and self.limit < 0:
raise InvalidOperation("limit must be non-negative.")
if self.batch_size is not None and self.batch_size < 0:
raise InvalidOperation("batch_size must be non-negative.")
@dataclass
class GetStreamProcessorSamplesResult:
"""Result from :meth:`AsyncStreamProcessor.get_stream_processor_samples`.
A ``cursor_id`` of ``0`` means the cursor is exhausted; no further calls
should be made.
"""
cursor_id: int
documents: list[Mapping[str, Any]]
@dataclass
class StreamProcessorInfo:
"""Information about a single stream processor.
Returned by :meth:`AsyncStreamProcessors.get_info`.
All fields from the ``getStreamProcessor`` server response are mapped to
snake_case. The complete raw server document is preserved in :attr:`raw`
so that unknown or future fields are not silently discarded.
"""
id: Optional[str]
name: str
state: str # plain str — drivers MUST NOT hard-code this as an Enum
pipeline: list[Mapping[str, Any]]
pipeline_version: Optional[int]
tier: Optional[str] = None
dlq: Optional[Mapping[str, Any]] = None
stream_meta_field_name: Optional[str] = None
enable_auto_scaling: bool = False
failover_enabled: bool = False
active_region: str = ""
last_modified_at: Optional[datetime] = None
modified_by: Optional[str] = None
last_state_change: Optional[datetime] = None
last_heartbeat: Optional[datetime] = None
has_started: bool = False
stats: Optional[Mapping[str, Any]] = None
error_msg: Optional[str] = None
error_code: Optional[int] = None
error_retryable: Optional[bool] = None
raw: Mapping[str, Any] = field(default_factory=dict)
@classmethod
def from_response(cls, doc: Mapping[str, Any]) -> StreamProcessorInfo:
"""Construct a :class:`StreamProcessorInfo` from a server response document.
Maps camelCase server keys to Python snake_case fields and stashes the
full *doc* in :attr:`raw` so no fields are silently dropped.
"""
return cls(
id=doc.get("id"),
name=doc["name"],
state=doc["state"],
pipeline=doc.get("pipeline", []),
pipeline_version=doc.get("pipelineVersion"),
tier=doc.get("tier"),
dlq=doc.get("dlq"),
stream_meta_field_name=doc.get("streamMetaFieldName"),
enable_auto_scaling=doc.get("enableAutoScaling", False),
failover_enabled=doc.get("failoverEnabled", False),
active_region=doc.get("activeRegion", ""),
last_modified_at=doc.get("lastModifiedAt"),
modified_by=doc.get("modifiedBy"),
last_state_change=doc.get("lastStateChange"),
last_heartbeat=doc.get("lastHeartbeat"),
has_started=doc.get("hasStarted", False),
stats=doc.get("stats"),
error_msg=doc.get("errorMsg"),
error_code=doc.get("errorCode"),
error_retryable=doc.get("errorRetryable"),
raw=dict(doc),
)
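The `from_response` pattern above (map known camelCase keys to snake_case fields, keep the whole document in `raw`) can be seen in a reduced form. This is a sketch under assumptions: `MiniInfo` is a hypothetical three-field stand-in for `StreamProcessorInfo`, and `futureField` is an invented key representing a server field the driver does not know about.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Mapping, Optional


@dataclass
class MiniInfo:
    """Hypothetical, reduced stand-in for ``StreamProcessorInfo``."""

    name: str
    state: str  # kept as a plain string, not an Enum
    pipeline_version: Optional[int] = None
    raw: Mapping[str, Any] = field(default_factory=dict)

    @classmethod
    def from_response(cls, doc: Mapping[str, Any]) -> MiniInfo:
        """Map known camelCase keys and keep a copy of the full document."""
        return cls(
            name=doc["name"],  # required: a missing key raises KeyError
            state=doc["state"],
            pipeline_version=doc.get("pipelineVersion"),  # optional
            raw=dict(doc),  # unknown keys survive here
        )


info = MiniInfo.from_response(
    {"name": "demo", "state": "CREATED", "pipelineVersion": 1, "futureField": "x"}
)
```

The unmapped `futureField` key stays reachable as `info.raw["futureField"]`, so a server-side addition does not require a driver change before callers can read it.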


@ -0,0 +1,633 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Atlas Stream Processing client for stream processing workspaces.
A stream processing workspace endpoint uses the ``mongodb://`` URI scheme with
a hostname that follows this pattern::
atlas-stream-<workspaceId>-<suffix>.<region>.a.query.mongodb.net
For example::
mongodb://atlas-stream-699c842ef433fe6001480b17-etif1.virginia-usa.a.query.mongodb.net/
TLS is always required for workspace connections and cannot be disabled.
``authSource=admin`` is applied by default when not explicitly set.
For commands not yet wrapped by this API, users can connect via a plain
:class:`~pymongo.mongo_client.MongoClient` and send them with
:meth:`~pymongo.database.Database.command` directly; that path remains fully supported.
Error handling
~~~~~~~~~~~~~~
ASP commands raise :class:`pymongo.errors.OperationFailure` on server-side
errors. The following error codes are known to be returned by Atlas Stream
Processing commands:
================ ================= ====================================
Code Name When returned
================ ================= ====================================
9 FailedToParse Invalid pipeline or command document
72 InvalidOptions Invalid option values
125 CommandFailed General command execution failure
1 InternalError Unexpected server-side error
================ ================= ====================================
This list is **non-exhaustive** and may grow as the server evolves. Drivers
do not maintain a closed list of valid codes; applications should branch
on ``exc.code`` only when they need to react to a specific known code, and
should always have a generic fallback for unknown codes.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional
from pymongo.errors import ConfigurationError, InvalidOperation
from pymongo.operations import _Op
from pymongo.stream_processing_options import (
CreateStreamProcessorOptions,
GetStreamProcessorSamplesOptions,
GetStreamProcessorSamplesResult,
GetStreamProcessorStatsOptions,
StartStreamProcessorOptions,
StreamProcessorInfo,
)
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.uri_parser_shared import SRV_SCHEME, _validate_uri
if TYPE_CHECKING:
from types import TracebackType
from pymongo.synchronous.client_session import ClientSession
_IS_SYNC = True
def _is_workspace_endpoint(host: str) -> bool:
"""Return True if *host* looks like an ASP workspace endpoint.
Workspace hostnames begin with ``atlas-stream-`` or end with
``.a.query.mongodb.net``.
"""
return host.startswith("atlas-stream-") or host.endswith(".a.query.mongodb.net")
class StreamProcessingClient:
"""A client connected to an Atlas Stream Processing workspace.
Wraps :class:`~pymongo.mongo_client.MongoClient` with
Atlas Stream Processing constraints:
* Only the ``mongodb://`` URI scheme is accepted (``mongodb+srv://`` is
not supported for workspace endpoints).
* TLS is always enabled and cannot be disabled.
* ``authSource`` defaults to ``"admin"`` if not explicitly set.
Usage::
with StreamProcessingClient(
"mongodb://atlas-stream-<workspaceId>-<suffix>.<region>.a.query.mongodb.net/",
username="user",
password="pass",
) as client:
sps = client.stream_processors()
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
host: Any = args[0] if args else kwargs.get("host")
uris: list[str] = []
if isinstance(host, str):
uris = [host]
elif isinstance(host, (list, tuple)):
uris = [h for h in host if isinstance(h, str)]
uri_has_auth_source = False
for uri_str in uris:
if not uri_str.startswith(("mongodb://", SRV_SCHEME)):
# Plain hostname — no URI scheme to check.
continue
if uri_str.startswith(SRV_SCHEME):
raise ConfigurationError(
"StreamProcessingClient does not support mongodb+srv:// URIs; "
"use mongodb:// with a workspace endpoint."
)
parsed = _validate_uri(uri_str, validate=True, warn=False, normalize=True)
uri_opts = parsed["options"]
if uri_opts.get("tls") is False:
raise ConfigurationError(
"TLS cannot be disabled for stream processing workspace connections."
)
if uri_opts.get("authsource") is not None:
uri_has_auth_source = True
# Also reject explicit tls=False / ssl=False in kwargs.
if kwargs.get("tls") is False or kwargs.get("ssl") is False:
raise ConfigurationError(
"TLS cannot be disabled for stream processing workspace connections."
)
kwargs.pop("ssl", None)
kwargs["tls"] = True
if not uri_has_auth_source and not any(k.lower() == "authsource" for k in kwargs):
kwargs["authSource"] = "admin"
self._client: MongoClient[Any] = MongoClient(*args, **kwargs)
# NOTE: Per the ASP driver spec, server errors MUST be surfaced as-is.
# Do NOT introduce error-code branching, rewrapping, or filtering anywhere
# in this module. Known codes are documented at the module level for
# reference only — they are not runtime invariants.
def _command(
self,
cmd: dict[str, Any],
*,
retryable_read: bool = False,
session: Optional[ClientSession] = None,
) -> Mapping[str, Any]:
"""Send a top-level ASP command to the admin database.
Routes through the existing retry machinery: retryable reads use
:meth:`~pymongo.database.Database._retryable_read_command`;
everything else uses the standard
:meth:`~pymongo.database.Database.command` path.
:param cmd: The command document.
:param retryable_read: If ``True``, the command is sent as a retryable read.
:param session: A
:class:`~pymongo.client_session.ClientSession` to use
for this operation.
"""
admin = self._client._database_default_options("admin")
if retryable_read:
# The first key of the command document is the operation name,
# which matches the corresponding _Op enum value string.
operation = next(iter(cmd))
return admin._retryable_read_command(cmd, session=session, operation=operation)
return admin.command(cmd, session=session)
def stream_processors(self) -> StreamProcessors:
"""Return a handle for managing stream processors in this workspace."""
return StreamProcessors(self)
def close(self) -> None:
"""Close the underlying client and release all resources."""
self._client.close()
def __enter__(self) -> StreamProcessingClient:
return self
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.close()
@property
def address(self) -> Any:
"""(host, port) of the current workspace endpoint, or None.
Delegates to the underlying
:attr:`~pymongo.mongo_client.MongoClient.address`.
"""
return self._client.address
class StreamProcessors:
"""Handle for managing stream processors in a workspace.
Obtained via :meth:`StreamProcessingClient.stream_processors`.
"""
def __init__(self, client: StreamProcessingClient) -> None:
self._client = client
def create(
self,
name: str,
pipeline: list[Mapping[str, Any]],
options: Optional[CreateStreamProcessorOptions] = None,
*,
session: Optional[ClientSession] = None,
) -> None:
"""Create a new stream processor in the workspace.
Sends the ``createStreamProcessor`` command to the ``admin`` database.
This operation is **not** retryable.
:param name: The name of the stream processor to create.
:param pipeline: The aggregation pipeline that defines the processor.
:param options: Optional :class:`~pymongo.stream_processing_options.CreateStreamProcessorOptions`
controlling DLQ, tier, and other settings.
:param session: A
:class:`~pymongo.client_session.ClientSession` to use
for this operation.
"""
if not name or not name.strip():
raise InvalidOperation("Stream processor name must be a non-empty string.")
if not pipeline:
raise InvalidOperation("createStreamProcessor requires a non-empty pipeline.")
cmd: dict[str, Any] = {
_Op.CREATE_STREAM_PROCESSOR: name,
"pipeline": list(pipeline),
}
if options is not None:
opts: dict[str, Any] = {}
if options.dlq is not None:
opts["dlq"] = options.dlq
if options.stream_meta_field_name is not None:
opts["streamMetaFieldName"] = options.stream_meta_field_name
if options.tier is not None:
opts["tier"] = options.tier
if options.failover is not None:
opts["failover"] = options.failover
if opts:
cmd["options"] = opts
self._client._command(cmd, session=session)
def get(self, name: str) -> StreamProcessor:
"""Return a handle for an existing stream processor by name.
This is a pure factory method: it does **not** contact the server
or verify that the processor exists.
:param name: The name of the stream processor.
:returns: A :class:`StreamProcessor` handle.
"""
if not name or not name.strip():
raise InvalidOperation("Stream processor name must be a non-empty string.")
return StreamProcessor(client=self._client, name=name)
def get_info(
self,
name: str,
*,
session: Optional[ClientSession] = None,
) -> StreamProcessorInfo:
"""Return information about a single stream processor.
Sends the ``getStreamProcessor`` command to the ``admin`` database.
This operation is a **retryable read**.
:param name: The name of the stream processor.
:param session: A
:class:`~pymongo.client_session.ClientSession` to use
for this operation.
:returns: A :class:`~pymongo.stream_processing_options.StreamProcessorInfo`
populated from the server response. Unknown server fields are preserved
in :attr:`~pymongo.stream_processing_options.StreamProcessorInfo.raw`.
"""
if not name or not name.strip():
raise InvalidOperation("Stream processor name must be a non-empty string.")
cmd: dict[str, Any] = {_Op.GET_STREAM_PROCESSOR: name}
response = self._client._command(cmd, retryable_read=True, session=session)
doc = response.get("result", response)
return StreamProcessorInfo.from_response(doc)
class StreamProcessor:
"""Handle for a specific named stream processor.
Does not imply the processor currently exists on the server.
Obtain via :meth:`StreamProcessors.get`.
"""
def __init__(self, *, client: StreamProcessingClient, name: str) -> None:
if not name or not name.strip():
raise InvalidOperation("Stream processor name must be a non-empty string.")
self._client = client
self.name = name
def start(
self,
options: Optional[StartStreamProcessorOptions] = None,
*,
session: Optional[ClientSession] = None,
) -> None:
"""Start this stream processor.
Sends the ``startStreamProcessor`` command to the ``admin`` database.
This operation is **not** retryable.
The processor must be in the ``STOPPED`` or ``FAILED`` state; starting
an already-``STARTED`` processor returns a server error.
Mutual exclusivity of ``start_after`` / ``start_at_operation_time`` and
tier validation are enforced by
:class:`~pymongo.stream_processing_options.StartStreamProcessorOptions`
at construction time.
:param options: Optional
:class:`~pymongo.stream_processing_options.StartStreamProcessorOptions`.
:param session: A
:class:`~pymongo.client_session.ClientSession` to use
for this operation.
"""
cmd: dict[str, Any] = {_Op.START_STREAM_PROCESSOR: self.name}
if options is not None:
if options.workers is not None:
cmd["workers"] = options.workers
opts: dict[str, Any] = {}
if options.clear_checkpoints is not None:
opts["clearCheckpoints"] = options.clear_checkpoints
if options.start_at_operation_time is not None:
opts["startAtOperationTime"] = options.start_at_operation_time
if options.start_after is not None:
opts["startAfter"] = options.start_after
if options.tier is not None:
opts["tier"] = options.tier
if options.enable_auto_scaling is not None:
opts["enableAutoScaling"] = options.enable_auto_scaling
if options.failover is not None:
opts["failover"] = options.failover
if opts:
cmd["options"] = opts
self._client._command(cmd, session=session)
def stop(
self,
*,
session: Optional[ClientSession] = None,
) -> None:
"""Stop this stream processor.
Sends the ``stopStreamProcessor`` command to the ``admin`` database.
This operation is **not** retryable. The processor transitions to the
``STOPPED`` state and can be restarted.
:param session: A
:class:`~pymongo.client_session.ClientSession` to use
for this operation.
"""
self._client._command({_Op.STOP_STREAM_PROCESSOR: self.name}, session=session)
def drop(
self,
*,
session: Optional[ClientSession] = None,
) -> None:
"""Permanently delete this stream processor.
Sends the ``dropStreamProcessor`` command to the ``admin`` database.
This operation is **not** retryable. A dropped processor cannot be
recovered.
:param session: A
:class:`~pymongo.client_session.ClientSession` to use
for this operation.
"""
self._client._command({_Op.DROP_STREAM_PROCESSOR: self.name}, session=session)
def stats(
self,
options: Optional[GetStreamProcessorStatsOptions] = None,
*,
session: Optional[ClientSession] = None,
) -> Mapping[str, Any]:
"""Return runtime statistics for this stream processor.
Sends the ``getStreamProcessorStats`` command to the ``admin`` database.
This operation is a **retryable read**. The server returns an error if
the processor is not in the ``STARTED`` state.
Unknown fields in the response are preserved: the raw response dict
is returned unchanged so callers are not affected by server additions.
:param options: Optional
:class:`~pymongo.stream_processing_options.GetStreamProcessorStatsOptions`
controlling scale units and verbosity.
:param session: A
:class:`~pymongo.client_session.ClientSession` to use
for this operation.
:returns: The raw server response document.
"""
cmd: dict[str, Any] = {_Op.GET_STREAM_PROCESSOR_STATS: self.name}
if options is not None:
opts: dict[str, Any] = {}
if options.scale is not None:
opts["scale"] = options.scale
if options.verbose is not None:
opts["verbose"] = options.verbose
if opts:
cmd["options"] = opts
return self._client._command(cmd, retryable_read=True, session=session)
def get_stream_processor_samples(
self,
options: Optional[GetStreamProcessorSamplesOptions] = None,
*,
session: Optional[ClientSession] = None,
) -> GetStreamProcessorSamplesResult:
"""Fetch one batch of sampled documents from a running stream processor.
Spec-literal entry point. Inspects ``options.cursor_id`` and routes to
``startSampleStreamProcessor`` (initial call) or
``getMoreSampleStreamProcessor`` (continuation). Most users should
prefer :meth:`sample`, which wraps this in an iterator.
A returned ``cursor_id`` of ``0`` means the cursor is exhausted; callers
MUST NOT call this method again with that cursor id.
Sends to the ``admin`` database. Non-retryable.
:param options: Optional
:class:`~pymongo.stream_processing_options.GetStreamProcessorSamplesOptions`.
:param session: A
:class:`~pymongo.client_session.ClientSession` to use
for this operation.
:returns: A :class:`~pymongo.stream_processing_options.GetStreamProcessorSamplesResult`
containing the batch of documents and the cursor id for the next call.
"""
if options is None:
options = GetStreamProcessorSamplesOptions()
if options.cursor_id == 0:
raise InvalidOperation("Sample cursor is exhausted; cursor_id 0 cannot be continued.")
if options.cursor_id is None:
cmd: dict[str, Any] = {_Op.START_SAMPLE_STREAM_PROCESSOR: self.name}
if options.limit is not None:
cmd["limit"] = options.limit
resp = self._client._command(cmd, session=session)
return GetStreamProcessorSamplesResult(
cursor_id=int(resp["cursorId"]),
documents=list(resp.get("firstBatch", [])),
)
else:
cmd = {
_Op.GET_MORE_SAMPLE_STREAM_PROCESSOR: self.name,
"cursorId": options.cursor_id,
}
if options.batch_size is not None:
cmd["batchSize"] = options.batch_size
resp = self._client._command(cmd, session=session)
return GetStreamProcessorSamplesResult(
cursor_id=int(resp["cursorId"]),
documents=list(resp.get("nextBatch", resp.get("messages", []))),
)
def sample(
self,
limit: Optional[int] = None,
batch_size: Optional[int] = None,
*,
session: Optional[ClientSession] = None,
) -> SampleCursor:
"""Open a sample cursor over this stream processor's output.
Returns an iterator that yields sampled documents until the
server-side cursor is exhausted. Internally drives the two-phase
``startSampleStreamProcessor`` / ``getMoreSampleStreamProcessor``
protocol on the caller's behalf.
Usage::
for doc in processor.sample(limit=100, batch_size=10):
print(doc)
:param limit: Maximum number of documents to sample (sent only on the
initial call).
:param batch_size: Number of documents per continuation batch (sent
only on subsequent calls).
:param session: Optional :class:`ClientSession` propagated to all
underlying commands.
"""
return SampleCursor(
processor=self,
limit=limit,
batch_size=batch_size,
session=session,
)
class SampleCursor:
"""Iterator over sampled stream processor output.
A custom two-phase cursor used to retrieve sampled documents from a
running stream processor. This cursor MUST NOT be wrapped or re-used
via the standard MongoDB ``Cursor`` types because it does not use the
standard ``getMore`` command; it uses the dedicated
``startSampleStreamProcessor`` / ``getMoreSampleStreamProcessor``
commands instead.
Obtained via :meth:`StreamProcessor.sample`. Iterate with
``for``.
The cursor is exhausted when the server returns ``cursorId: 0``;
after that, no further wire calls are issued and iteration ends.
"""
def __init__(
self,
*,
processor: StreamProcessor,
limit: Optional[int] = None,
batch_size: Optional[int] = None,
session: Optional[ClientSession] = None,
) -> None:
self._processor = processor
self._limit = limit
self._batch_size = batch_size
self._session = session
self._buffer: list[Mapping[str, Any]] = []
self._cursor_id: Optional[int] = None # None = not yet opened
self._exhausted: bool = False
self._closed: bool = False
@property
def cursor_id(self) -> Optional[int]:
"""Current server-side cursor id, or ``None`` if not yet opened.
A value of ``0`` indicates the cursor has been exhausted.
"""
return self._cursor_id
@property
def alive(self) -> bool:
"""``True`` if more documents may be available; ``False`` once exhausted or closed."""
return not self._exhausted and not self._closed
def _refill(self) -> None:
"""Fetch the next batch from the server. No-op if exhausted or closed."""
if self._exhausted or self._closed:
return
if self._cursor_id is None:
opts = GetStreamProcessorSamplesOptions(
cursor_id=None,
limit=self._limit,
batch_size=None,
)
else:
opts = GetStreamProcessorSamplesOptions(
cursor_id=self._cursor_id,
limit=None,
batch_size=self._batch_size,
)
result = self._processor.get_stream_processor_samples(opts, session=self._session)
self._cursor_id = result.cursor_id
self._buffer.extend(result.documents)
# Spec: cursorId == 0 means exhausted. MUST NOT call getMore again.
if result.cursor_id == 0:
self._exhausted = True
def __iter__(self) -> SampleCursor:
return self
def __next__(self) -> Mapping[str, Any]:
if self._buffer:
return self._buffer.pop(0)
if self._closed or self._exhausted:
raise StopIteration
# Loop guards against an empty batch from the server with a non-zero
# cursor id — keep pulling until we get documents or hit exhaustion.
while not self._buffer and not self._exhausted:
self._refill()
if self._buffer:
return self._buffer.pop(0)
raise StopIteration
def close(self) -> None:
"""Mark the cursor closed locally.
Note: ASP does not currently expose a way to explicitly kill a
sample cursor server-side. ``close`` only stops local iteration;
the server-side cursor will be cleaned up on its own timeout or
when the processor stops.
"""
self._closed = True
def __enter__(self) -> SampleCursor:
return self
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.close()
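The error-handling guidance in this module's docstring (branch on a known code only when needed, and always keep a generic fallback) might look like the following in application code. This is a sketch, not driver behavior: `StubOperationFailure` is a hypothetical stand-in that exposes only the `code` attribute of `pymongo.errors.OperationFailure`, and `classify` and its return labels are invented for illustration.

```python
from typing import Optional


class StubOperationFailure(Exception):
    """Hypothetical stand-in exposing only ``code``, like OperationFailure."""

    def __init__(self, message: str, code: Optional[int] = None) -> None:
        super().__init__(message)
        self.code = code


# Codes taken from the module docstring's table; the list is non-exhaustive.
FAILED_TO_PARSE = 9
INVALID_OPTIONS = 72


def classify(exc: StubOperationFailure) -> str:
    """Pick a reaction for a server error, with a fallback for unknown codes."""
    if exc.code == FAILED_TO_PARSE:
        return "fix-pipeline"
    if exc.code == INVALID_OPTIONS:
        return "fix-options"
    # Unknown, new, or missing codes all land here instead of being assumed away.
    return "unexpected"
```

A code the application has never seen, or an error with no code at all, falls through to `"unexpected"`, which keeps the application working as the server adds codes.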


@ -0,0 +1,720 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the Atlas Stream Processing driver module.
All tests run offline; no live workspace is required. The wire layer is
stubbed out by replacing ``_command`` on the client with a lightweight spy
that records calls and returns pre-configured responses.
"""
from __future__ import annotations
import inspect
import sys
import unittest
from datetime import datetime, timezone
from typing import Any, Mapping, Optional
from unittest.mock import MagicMock, patch
sys.path[0:0] = [""]
_IS_SYNC = False
import pymongo.asynchronous.stream_processing
from bson import Timestamp
from pymongo import (
AsyncSampleCursor,
AsyncStreamProcessingClient,
AsyncStreamProcessor,
AsyncStreamProcessors,
CreateStreamProcessorOptions,
GetStreamProcessorSamplesOptions,
GetStreamProcessorSamplesResult,
GetStreamProcessorStatsOptions,
StartStreamProcessorOptions,
StreamProcessorInfo,
)
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
# ---------------------------------------------------------------------------
# Spy helper
# ---------------------------------------------------------------------------
def _spy_client(
    responses: Optional[list[Mapping[str, Any]]] = None,
    raises: Optional[Exception] = None,
) -> tuple[AsyncStreamProcessingClient, list[dict[str, Any]]]:
    """Return *(client, calls)* where ``client._command`` records every call.

    *responses* is consumed in order; when exhausted, ``{}`` is returned.
    If *raises* is set, the spy raises that exception on every call instead
    of returning a response.
    """
    calls: list[dict[str, Any]] = []
    resp_list: list[Mapping[str, Any]] = list(responses) if responses is not None else [{}]

    async def _command(
        cmd: dict[str, Any],
        *,
        retryable_read: bool = False,
        session: Any = None,
    ) -> Mapping[str, Any]:
        calls.append({"cmd": dict(cmd), "retryable_read": retryable_read, "session": session})
        if raises is not None:
            raise raises
        return resp_list.pop(0) if resp_list else {}

    client = AsyncStreamProcessingClient.__new__(AsyncStreamProcessingClient)
    client._client = MagicMock()
    client._command = _command
    return client, calls
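
The spy technique used by `_spy_client` can be tried in isolation. The following is a minimal sketch that depends only on the standard library; `make_spy` and `demo` are hypothetical names chosen for illustration and are not part of the commit or of PyMongo.

```python
import asyncio
from typing import Any, Mapping, Optional


def make_spy(responses: Optional[list[Mapping[str, Any]]] = None):
    """Return (command, calls); `command` records each call and replays responses."""
    calls: list[dict[str, Any]] = []
    queue = list(responses or [])

    async def command(cmd: dict[str, Any], *, retryable_read: bool = False) -> Mapping[str, Any]:
        # Copy the command so later mutation by the caller cannot alter the record.
        calls.append({"cmd": dict(cmd), "retryable_read": retryable_read})
        # Replay queued responses in order, then fall back to an empty reply.
        return queue.pop(0) if queue else {}

    return command, calls


async def demo() -> list[dict[str, Any]]:
    command, calls = make_spy([{"ok": 1}])
    first = await command({"startStreamProcessor": "demo"})
    second = await command({"getStreamProcessor": "demo"}, retryable_read=True)
    assert first == {"ok": 1} and second == {}
    return calls


calls = asyncio.run(demo())
```

Recording a copy of each command, rather than the dict itself, keeps assertions stable even if the code under test reuses or mutates its command document.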
# ---------------------------------------------------------------------------
# Minimal info response fixture
# ---------------------------------------------------------------------------
_INFO_RESPONSE: dict[str, Any] = {
    "ok": 1,
    "id": "abc",
    "name": "demo",
    "state": "CREATED",
    "pipeline": [],
    "pipelineVersion": 1,
    "enableAutoScaling": False,
    "failoverEnabled": False,
    "activeRegion": "us-east-1",
    "hasStarted": False,
}
# ===========================================================================
# Test class 1 — client constructor validation
# ===========================================================================
class TestStreamProcessingClientConfig(unittest.TestCase):
    """Constructor-level validation for AsyncStreamProcessingClient."""

    def test_rejects_srv_uri(self) -> None:
        with self.assertRaises(ConfigurationError) as cm:
            AsyncStreamProcessingClient("mongodb+srv://example.com/")
        self.assertIn("mongodb+srv", str(cm.exception).lower())

    def test_rejects_tls_false_kwarg(self) -> None:
        with self.assertRaises(ConfigurationError):
            AsyncStreamProcessingClient("mongodb://example.com/", tls=False)

    def test_rejects_ssl_false_kwarg(self) -> None:
        with self.assertRaises(ConfigurationError):
            AsyncStreamProcessingClient("mongodb://example.com/", ssl=False)

    def test_rejects_tls_false_in_uri(self) -> None:
        with self.assertRaises(ConfigurationError):
            AsyncStreamProcessingClient("mongodb://example.com/?tls=false")

    @patch("pymongo.asynchronous.stream_processing.AsyncMongoClient")
    def test_forces_tls_true(self, mock_client: MagicMock) -> None:
        mock_client.return_value = MagicMock()
        AsyncStreamProcessingClient("mongodb://example.com/")
        kwargs = mock_client.call_args.kwargs
        self.assertTrue(kwargs.get("tls"))

    @patch("pymongo.asynchronous.stream_processing.AsyncMongoClient")
    def test_defaults_authsource_to_admin(self, mock_client: MagicMock) -> None:
        mock_client.return_value = MagicMock()
        AsyncStreamProcessingClient("mongodb://example.com/")
        kwargs = mock_client.call_args.kwargs
        self.assertEqual(kwargs.get("authSource"), "admin")

    @patch("pymongo.asynchronous.stream_processing.AsyncMongoClient")
    def test_preserves_explicit_authsource_in_kwarg(self, mock_client: MagicMock) -> None:
        mock_client.return_value = MagicMock()
        AsyncStreamProcessingClient("mongodb://example.com/", authSource="other")
        kwargs = mock_client.call_args.kwargs
        # Our default injection must not overwrite a user-supplied value.
        self.assertEqual(kwargs.get("authSource"), "other")

    @patch("pymongo.asynchronous.stream_processing.AsyncMongoClient")
    def test_preserves_explicit_authsource_in_uri(self, mock_client: MagicMock) -> None:
        mock_client.return_value = MagicMock()
        AsyncStreamProcessingClient("mongodb://example.com/?authSource=other")
        kwargs = mock_client.call_args.kwargs
        # authSource is already in the URI; we must not inject a duplicate kwarg.
        self.assertNotIn("authSource", kwargs)

    @patch("pymongo.asynchronous.stream_processing.AsyncMongoClient")
    def test_drops_ssl_kwarg_to_avoid_duplicate(self, mock_client: MagicMock) -> None:
        mock_client.return_value = MagicMock()
        AsyncStreamProcessingClient("mongodb://example.com/", ssl=True)
        kwargs = mock_client.call_args.kwargs
        self.assertNotIn("ssl", kwargs)
        self.assertTrue(kwargs.get("tls"))

    def test_workspace_endpoint_detection(self) -> None:
        from pymongo.asynchronous.stream_processing import _is_workspace_endpoint

        self.assertTrue(_is_workspace_endpoint("atlas-stream-foo.virginia-usa.a.query.mongodb.net"))
        self.assertTrue(_is_workspace_endpoint("something.a.query.mongodb.net"))
        self.assertFalse(_is_workspace_endpoint("cluster0.mongodb.net"))
        self.assertFalse(_is_workspace_endpoint("localhost"))
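
The constructor rules exercised above (reject `mongodb+srv://`, refuse to disable TLS, collapse `ssl` into `tls=True`, default `authSource` to `"admin"` unless the caller set it) can be sketched with the standard library alone. This is an illustration under assumptions, not the driver's implementation: `normalize_workspace_options` is a hypothetical helper, and it raises `ValueError` where the tests expect PyMongo's `ConfigurationError`.

```python
from typing import Any
from urllib.parse import parse_qs, urlsplit


def normalize_workspace_options(uri: str, **kwargs: Any) -> dict[str, Any]:
    """Hypothetical helper: validate a workspace URI and return client kwargs."""
    if uri.lower().startswith("mongodb+srv://"):
        raise ValueError("mongodb+srv:// URIs are not supported for workspaces")

    # Option names are compared case-insensitively; the last repeated value wins.
    query = {k.lower(): v[-1].lower() for k, v in parse_qs(urlsplit(uri).query).items()}
    if "false" in (query.get("tls"), query.get("ssl")):
        raise ValueError("TLS cannot be disabled in the URI")
    if kwargs.get("tls") is False or kwargs.get("ssl") is False:
        raise ValueError("TLS cannot be disabled via keyword arguments")

    options = dict(kwargs)
    options.pop("ssl", None)  # collapse the ssl alias into a single tls=True
    options["tls"] = True
    # Inject the default only when neither the URI nor the caller supplied one.
    if "authsource" not in query:
        options.setdefault("authSource", "admin")
    return options


options = normalize_workspace_options("mongodb://example.com/", ssl=True)
```

Using `setdefault` keeps a caller-supplied `authSource` keyword intact, while the query-string check avoids passing a duplicate when the URI already carries the option.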
# ===========================================================================
# Test class 2 — options dataclass validation
# ===========================================================================
class TestStreamProcessingOptions(unittest.TestCase):
    """Dataclass validation in stream_processing_options."""

    def test_start_options_mutex_start_after_and_at_op_time(self) -> None:
        with self.assertRaises(InvalidOperation):
            StartStreamProcessorOptions(
                start_after={"_id": 1},
                start_at_operation_time=Timestamp(0, 0),
            )

    def test_start_options_invalid_tier(self) -> None:
        with self.assertRaises(InvalidOperation):
            StartStreamProcessorOptions(tier="SP1")
        for valid_tier in ("SP2", "SP5", "SP10", "SP30", "SP50"):
            StartStreamProcessorOptions(tier=valid_tier)  # must not raise

    def test_start_options_invalid_workers(self) -> None:
        with self.assertRaises(InvalidOperation):
            StartStreamProcessorOptions(workers=0)
        with self.assertRaises(InvalidOperation):
            StartStreamProcessorOptions(workers=-1)
        StartStreamProcessorOptions(workers=1)  # must not raise

    def test_stats_options_invalid_scale(self) -> None:
        with self.assertRaises(InvalidOperation):
            GetStreamProcessorStatsOptions(scale=0)
        with self.assertRaises(InvalidOperation):
            GetStreamProcessorStatsOptions(scale=-1)
        GetStreamProcessorStatsOptions(scale=1)
        GetStreamProcessorStatsOptions(scale=1024)

    def test_samples_options_invalid_cursor_id(self) -> None:
        with self.assertRaises(InvalidOperation):
            GetStreamProcessorSamplesOptions(cursor_id=-1)
        # cursor_id=0 is valid at construction; the method-level check rejects it.
        GetStreamProcessorSamplesOptions(cursor_id=0)
        GetStreamProcessorSamplesOptions(cursor_id=42)

    def test_samples_options_invalid_limit(self) -> None:
        with self.assertRaises(InvalidOperation):
            GetStreamProcessorSamplesOptions(limit=-1)
        GetStreamProcessorSamplesOptions(limit=0)
        GetStreamProcessorSamplesOptions(limit=10)

    def test_samples_options_invalid_batch_size(self) -> None:
        with self.assertRaises(InvalidOperation):
            GetStreamProcessorSamplesOptions(batch_size=-1)
        GetStreamProcessorSamplesOptions(batch_size=0)
        GetStreamProcessorSamplesOptions(batch_size=5)

    def test_create_options_all_optional(self) -> None:
        opts = CreateStreamProcessorOptions()
        self.assertIsNone(opts.dlq)
        self.assertIsNone(opts.stream_meta_field_name)
        self.assertIsNone(opts.tier)
        self.assertIsNone(opts.failover)

    def test_stream_processor_info_from_response_full(self) -> None:
        now = datetime.now(tz=timezone.utc)
        doc: dict[str, Any] = {
            "ok": 1,
            "id": "abc123",
            "name": "demo",
            "state": "STARTED",
            "pipeline": [{"$source": {}}],
            "pipelineVersion": 3,
            "tier": "SP10",
            "dlq": {"connectionName": "c1"},
            "streamMetaFieldName": "ts",
            "enableAutoScaling": True,
            "failoverEnabled": True,
            "activeRegion": "us-east-1",
            "lastModifiedAt": now,
            "modifiedBy": "user@example.com",
            "lastStateChange": now,
            "lastHeartbeat": now,
            "hasStarted": True,
            "stats": {"inputMessageCount": 42},
            "errorMsg": None,
            "errorCode": None,
            "errorRetryable": None,
        }
        info = StreamProcessorInfo.from_response(doc)
        self.assertEqual(info.id, "abc123")
        self.assertEqual(info.name, "demo")
        self.assertEqual(info.state, "STARTED")
        self.assertEqual(info.pipeline, [{"$source": {}}])
        self.assertEqual(info.pipeline_version, 3)
        self.assertEqual(info.tier, "SP10")
        self.assertEqual(info.dlq, {"connectionName": "c1"})
        self.assertEqual(info.stream_meta_field_name, "ts")
        self.assertTrue(info.enable_auto_scaling)
        self.assertTrue(info.failover_enabled)
        self.assertEqual(info.active_region, "us-east-1")
        self.assertTrue(info.has_started)
        self.assertEqual(info.stats, {"inputMessageCount": 42})
        self.assertEqual(info.raw, doc)

    def test_stream_processor_info_from_response_minimal(self) -> None:
        doc: dict[str, Any] = {
            "id": "x",
            "name": "p",
            "state": "CREATED",
            "pipeline": [],
            "pipelineVersion": 1,
        }
        info = StreamProcessorInfo.from_response(doc)
        self.assertEqual(info.id, "x")
        self.assertIsNone(info.tier)
        self.assertIsNone(info.dlq)
        self.assertFalse(info.enable_auto_scaling)
        self.assertFalse(info.has_started)
        self.assertEqual(info.raw, doc)

    def test_stream_processor_info_preserves_unknown_fields(self) -> None:
        doc: dict[str, Any] = {
            "id": "x",
            "name": "p",
            "state": "CREATED",
            "pipeline": [],
            "pipelineVersion": 1,
            "futureField": "someValue",
        }
        info = StreamProcessorInfo.from_response(doc)
        self.assertEqual(info.raw["futureField"], "someValue")

    def test_stream_processor_info_state_is_string(self) -> None:
        doc: dict[str, Any] = {
            "id": "x",
            "name": "p",
            "state": "FUTURE_UNKNOWN_STATE",
            "pipeline": [],
            "pipelineVersion": 1,
        }
        info = StreamProcessorInfo.from_response(doc)
        self.assertEqual(info.state, "FUTURE_UNKNOWN_STATE")
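
The construction-time checks tested in this class (mutually exclusive start positions, a positive worker count) are the kind of validation a dataclass `__post_init__` hook handles well. The sketch below is an assumption about shape only: `StartOptionsSketch` and `rejects` are hypothetical names, and `ValueError` stands in for PyMongo's `InvalidOperation`.

```python
from dataclasses import dataclass
from typing import Any, Mapping, Optional


@dataclass(frozen=True)
class StartOptionsSketch:
    """Illustrative stand-in for an options dataclass that validates itself."""

    start_after: Optional[Mapping[str, Any]] = None
    start_at_operation_time: Optional[Any] = None
    workers: Optional[int] = None

    def __post_init__(self) -> None:
        # The two start positions are mutually exclusive.
        if self.start_after is not None and self.start_at_operation_time is not None:
            raise ValueError("start_after and start_at_operation_time are mutually exclusive")
        # An explicit worker count must be at least 1.
        if self.workers is not None and self.workers < 1:
            raise ValueError("workers must be a positive integer")


def rejects(**kwargs: Any) -> bool:
    """Return True when constructing the options with *kwargs* raises."""
    try:
        StartOptionsSketch(**kwargs)
    except ValueError:
        return True
    return False
```

Validating in `__post_init__` means an invalid combination fails before any command is built, which is what the "must not reach the wire" style of assertion in the lifecycle tests relies on.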
# ===========================================================================
# Test class 3 — lifecycle commands
# ===========================================================================
class AsyncTestStreamProcessorsCommands(unittest.IsolatedAsyncioTestCase):
    """Lifecycle commands on AsyncStreamProcessors and AsyncStreamProcessor."""

    async def asyncSetUp(self) -> None:
        self.client, self.calls = _spy_client(responses=[{"ok": 1}] * 30)
        self.sps = AsyncStreamProcessors(self.client)
        self.proc = AsyncStreamProcessor(client=self.client, name="demo")

    async def test_create_sends_correct_command(self) -> None:
        await self.sps.create("demo", pipeline=[{"$source": {}}])
        entry = self.calls[-1]
        self.assertEqual(entry["cmd"].get("createStreamProcessor"), "demo")
        self.assertEqual(entry["cmd"].get("pipeline"), [{"$source": {}}])
        self.assertNotIn("options", entry["cmd"])
        self.assertFalse(entry["retryable_read"])

    async def test_create_with_options_includes_them(self) -> None:
        opts = CreateStreamProcessorOptions(dlq={"connectionName": "c"}, tier="SP10")
        await self.sps.create("demo", pipeline=[{"$source": {}}], options=opts)
        cmd = self.calls[-1]["cmd"]
        self.assertEqual(cmd["options"]["dlq"], {"connectionName": "c"})
        self.assertEqual(cmd["options"]["tier"], "SP10")

    async def test_create_rejects_empty_name(self) -> None:
        with self.assertRaises(InvalidOperation):
            await self.sps.create("", pipeline=[{"$x": 1}])
        self.assertEqual(len(self.calls), 0)

    async def test_create_rejects_whitespace_name(self) -> None:
        with self.assertRaises(InvalidOperation):
            await self.sps.create(" ", pipeline=[{"$x": 1}])
        self.assertEqual(len(self.calls), 0)

    async def test_create_rejects_empty_pipeline(self) -> None:
        with self.assertRaises(InvalidOperation):
            await self.sps.create("demo", pipeline=[])
        self.assertEqual(len(self.calls), 0)

    async def test_get_returns_handle_without_calling_server(self) -> None:
        proc = self.sps.get("demo")
        self.assertEqual(len(self.calls), 0)
        self.assertEqual(proc.name, "demo")
        self.assertIsInstance(proc, AsyncStreamProcessor)

    async def test_get_rejects_empty_name(self) -> None:
        with self.assertRaises(InvalidOperation):
            self.sps.get("")

    async def test_get_info_sends_correct_command_and_decodes(self) -> None:
        client, calls = _spy_client(responses=[dict(_INFO_RESPONSE)])
        info = await AsyncStreamProcessors(client).get_info("demo")
        self.assertEqual(calls[0]["cmd"].get("getStreamProcessor"), "demo")
        self.assertTrue(calls[0]["retryable_read"])
        self.assertIsInstance(info, StreamProcessorInfo)
        self.assertEqual(info.id, "abc")
        self.assertEqual(info.state, "CREATED")

    async def test_get_info_preserves_unknown_response_fields(self) -> None:
        resp = dict(_INFO_RESPONSE, futureField="value")
        client, _ = _spy_client(responses=[resp])
        info = await AsyncStreamProcessors(client).get_info("demo")
        self.assertEqual(info.raw["futureField"], "value")

    async def test_start_sends_correct_command(self) -> None:
        await self.proc.start()
        entry = self.calls[-1]
        self.assertEqual(entry["cmd"].get("startStreamProcessor"), "demo")
        self.assertNotIn("workers", entry["cmd"])
        self.assertNotIn("options", entry["cmd"])
        self.assertFalse(entry["retryable_read"])

    async def test_start_with_options_includes_them(self) -> None:
        opts = StartStreamProcessorOptions(workers=2, clear_checkpoints=True, tier="SP10")
        await self.proc.start(options=opts)
        cmd = self.calls[-1]["cmd"]
        self.assertEqual(cmd.get("workers"), 2)
        self.assertTrue(cmd["options"]["clearCheckpoints"])
        self.assertEqual(cmd["options"]["tier"], "SP10")

    async def test_stop_sends_correct_command(self) -> None:
        await self.proc.stop()
        entry = self.calls[-1]
        self.assertEqual(entry["cmd"].get("stopStreamProcessor"), "demo")
        self.assertFalse(entry["retryable_read"])

    async def test_drop_sends_correct_command(self) -> None:
        await self.proc.drop()
        entry = self.calls[-1]
        self.assertEqual(entry["cmd"].get("dropStreamProcessor"), "demo")
        self.assertFalse(entry["retryable_read"])

    async def test_stats_sends_correct_command_and_returns_raw_dict(self) -> None:
        raw_resp: dict[str, Any] = {"ok": 1, "stats": {"inputMessageCount": 5}, "futureField": "x"}
        client, calls = _spy_client(responses=[raw_resp])
        result = await AsyncStreamProcessor(client=client, name="demo").stats()
        self.assertEqual(calls[0]["cmd"].get("getStreamProcessorStats"), "demo")
        self.assertTrue(calls[0]["retryable_read"])
        self.assertEqual(result, raw_resp)

    async def test_stats_with_options(self) -> None:
        client, calls = _spy_client(responses=[{"ok": 1}])
        opts = GetStreamProcessorStatsOptions(scale=1024, verbose=True)
        await AsyncStreamProcessor(client=client, name="demo").stats(options=opts)
        self.assertEqual(calls[0]["cmd"]["options"]["scale"], 1024)
        self.assertTrue(calls[0]["cmd"]["options"]["verbose"])

    async def test_session_propagates(self) -> None:
        fake_session = MagicMock()
        await self.proc.start(session=fake_session)
        self.assertIs(self.calls[-1]["session"], fake_session)
# ===========================================================================
# Test class 4 — sample cursor
# ===========================================================================
class AsyncTestSampleCursor(unittest.IsolatedAsyncioTestCase):
    """Two-phase sample cursor protocol."""

    async def test_get_samples_initial_call_sends_start_command(self) -> None:
        client, calls = _spy_client(responses=[{"cursorId": 42, "firstBatch": [{"x": 1}]}])
        proc = AsyncStreamProcessor(client=client, name="demo")
        result = await proc.get_stream_processor_samples()
        cmd = calls[0]["cmd"]
        self.assertIn("startSampleStreamProcessor", cmd)
        self.assertNotIn("cursorId", cmd)
        self.assertNotIn("batchSize", cmd)
        self.assertFalse(calls[0]["retryable_read"])
        self.assertEqual(result.cursor_id, 42)
        self.assertEqual(result.documents, [{"x": 1}])

    async def test_get_samples_initial_call_includes_limit(self) -> None:
        client, calls = _spy_client(responses=[{"cursorId": 1, "firstBatch": []}])
        proc = AsyncStreamProcessor(client=client, name="demo")
        await proc.get_stream_processor_samples(GetStreamProcessorSamplesOptions(limit=100))
        self.assertEqual(calls[0]["cmd"].get("limit"), 100)
        self.assertNotIn("batchSize", calls[0]["cmd"])

    async def test_get_samples_continuation_sends_get_more(self) -> None:
        client, calls = _spy_client(responses=[{"cursorId": 0, "nextBatch": [{"y": 2}]}])
        proc = AsyncStreamProcessor(client=client, name="demo")
        result = await proc.get_stream_processor_samples(
            GetStreamProcessorSamplesOptions(cursor_id=42)
        )
        cmd = calls[0]["cmd"]
        self.assertIn("getMoreSampleStreamProcessor", cmd)
        self.assertEqual(cmd.get("cursorId"), 42)
        self.assertNotIn("limit", cmd)
        self.assertEqual(result.cursor_id, 0)
        self.assertEqual(result.documents, [{"y": 2}])

    async def test_get_samples_continuation_includes_batch_size(self) -> None:
        client, calls = _spy_client(responses=[{"cursorId": 0, "nextBatch": []}])
        proc = AsyncStreamProcessor(client=client, name="demo")
        await proc.get_stream_processor_samples(
            GetStreamProcessorSamplesOptions(cursor_id=42, batch_size=5)
        )
        self.assertEqual(calls[0]["cmd"].get("batchSize"), 5)

    async def test_get_samples_rejects_cursor_id_zero(self) -> None:
        client, calls = _spy_client()
        proc = AsyncStreamProcessor(client=client, name="demo")
        with self.assertRaises(InvalidOperation):
            await proc.get_stream_processor_samples(GetStreamProcessorSamplesOptions(cursor_id=0))
        self.assertEqual(len(calls), 0)

    async def test_sample_iterator_drains_single_batch(self) -> None:
        client, calls = _spy_client(
            responses=[{"cursorId": 0, "firstBatch": [{"a": 1}, {"a": 2}, {"a": 3}]}]
        )
        proc = AsyncStreamProcessor(client=client, name="demo")
        docs = []
        async for doc in proc.sample():
            docs.append(doc)
        self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}])
        self.assertEqual(len(calls), 1)

    async def test_sample_iterator_drains_multiple_batches(self) -> None:
        client, calls = _spy_client(
            responses=[
                {"cursorId": 11, "firstBatch": [{"a": 1}]},
                {"cursorId": 22, "nextBatch": [{"a": 2}]},
                {"cursorId": 0, "nextBatch": [{"a": 3}]},
            ]
        )
        proc = AsyncStreamProcessor(client=client, name="demo")
        docs = []
        async for doc in proc.sample():
            docs.append(doc)
        self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}])
        self.assertEqual(len(calls), 3)

    async def test_sample_iterator_handles_empty_batch_with_nonzero_cursor(self) -> None:
        client, calls = _spy_client(
            responses=[
                {"cursorId": 11, "firstBatch": []},
                {"cursorId": 22, "nextBatch": []},
                {"cursorId": 0, "nextBatch": [{"a": 1}]},
            ]
        )
        proc = AsyncStreamProcessor(client=client, name="demo")
        docs = []
        async for doc in proc.sample():
            docs.append(doc)
        self.assertEqual(docs, [{"a": 1}])
        self.assertEqual(len(calls), 3)

    async def test_sample_iterator_stops_after_cursor_id_zero(self) -> None:
        client, calls = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
        proc = AsyncStreamProcessor(client=client, name="demo")
        cursor = proc.sample()
        first = await cursor.__anext__()
        self.assertEqual(first, {"a": 1})
        with self.assertRaises(StopAsyncIteration):
            await cursor.__anext__()
        # Only one wire call: the server returned cursorId:0, so we stopped.
        self.assertEqual(len(calls), 1)

    async def test_sample_cursor_close_stops_iteration(self) -> None:
        client, calls = _spy_client()
        proc = AsyncStreamProcessor(client=client, name="demo")
        cursor = proc.sample()
        await cursor.close()
        with self.assertRaises(StopAsyncIteration):
            await cursor.__anext__()
        self.assertEqual(len(calls), 0)

    async def test_sample_cursor_alive_property(self) -> None:
        client, _ = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
        proc = AsyncStreamProcessor(client=client, name="demo")
        cursor = proc.sample()
        self.assertTrue(cursor.alive)
        doc = await cursor.__anext__()
        self.assertEqual(doc, {"a": 1})
        # After exhaustion (cursorId:0 returned during refill), alive is False.
        self.assertFalse(cursor.alive)

    async def test_sample_cursor_id_property(self) -> None:
        client, _ = _spy_client(responses=[{"cursorId": 42, "firstBatch": [{"a": 1}]}])
        proc = AsyncStreamProcessor(client=client, name="demo")
        cursor = proc.sample()
        self.assertIsNone(cursor.cursor_id)
        await cursor.__anext__()
        self.assertEqual(cursor.cursor_id, 42)

    async def test_sample_cursor_async_context_manager(self) -> None:
        client, _ = _spy_client()
        proc = AsyncStreamProcessor(client=client, name="demo")
        async with proc.sample() as cur:
            self.assertIsInstance(cur, AsyncSampleCursor)
        self.assertTrue(cur._closed)

    async def test_sample_cursor_session_propagates_to_refill(self) -> None:
        client, calls = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
        proc = AsyncStreamProcessor(client=client, name="demo")
        fake_session = MagicMock()
        async for _ in proc.sample(session=fake_session):
            pass
        self.assertIs(calls[0]["session"], fake_session)

    def test_sample_cursor_does_not_inherit_from_cursor_base(self) -> None:
        self.assertEqual(AsyncSampleCursor.__bases__, (object,))
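
The two-phase protocol these tests describe (one start command returning `firstBatch`, then get-more commands returning `nextBatch` until `cursorId` is 0) reduces to a short drain loop. The sketch below is synchronous for brevity and rests on assumptions: `drain_samples` and `fake_command` are hypothetical, and using the processor name as the value of each command key is a guess, since the tests only assert that the keys are present.

```python
from typing import Any, Callable, Iterator, Mapping

Command = Callable[[dict[str, Any]], Mapping[str, Any]]


def drain_samples(command: Command, name: str) -> Iterator[Mapping[str, Any]]:
    """Yield sample documents until the server reports cursorId 0."""
    # Phase 1: open the sample cursor and emit the first batch.
    reply = command({"startSampleStreamProcessor": name})
    yield from reply.get("firstBatch", [])
    cursor_id = reply.get("cursorId", 0)
    # Phase 2: keep fetching while the cursor is open; an empty batch with a
    # non-zero cursorId means "nothing yet", not "finished".
    while cursor_id:
        reply = command({"getMoreSampleStreamProcessor": name, "cursorId": cursor_id})
        yield from reply.get("nextBatch", [])
        cursor_id = reply.get("cursorId", 0)


# Scripted replies standing in for a server.
replies = [
    {"cursorId": 11, "firstBatch": []},
    {"cursorId": 22, "nextBatch": []},
    {"cursorId": 0, "nextBatch": [{"a": 1}]},
]
sent: list[dict[str, Any]] = []


def fake_command(cmd: dict[str, Any]) -> Mapping[str, Any]:
    sent.append(cmd)
    return replies.pop(0)


docs = list(drain_samples(fake_command, "demo"))
```

Termination depends only on `cursorId`, never on batch emptiness, which is the behavior the empty-batch test above pins down.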
# ===========================================================================
# Test class 5 — error handling
# ===========================================================================
class AsyncTestErrorHandling(unittest.IsolatedAsyncioTestCase):
    """Server errors MUST propagate unchanged: no rewrapping, no filtering."""

    def _failure(self, code: int = 9) -> OperationFailure:
        return OperationFailure("test error", code=code, details={"errmsg": "test", "code": code})

    async def _assert_propagates(self, fn: Any, expected_code: int) -> None:
        with self.assertRaises(OperationFailure) as cm:
            await fn()
        self.assertEqual(cm.exception.code, expected_code)

    async def test_create_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(9))
        await self._assert_propagates(
            lambda: AsyncStreamProcessors(client).create("p", pipeline=[{"$x": 1}]), 9
        )

    async def test_start_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(72))
        await self._assert_propagates(
            lambda: AsyncStreamProcessor(client=client, name="p").start(), 72
        )

    async def test_stop_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(125))
        await self._assert_propagates(
            lambda: AsyncStreamProcessor(client=client, name="p").stop(), 125
        )

    async def test_drop_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(1))
        await self._assert_propagates(
            lambda: AsyncStreamProcessor(client=client, name="p").drop(), 1
        )

    async def test_get_info_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(9))
        await self._assert_propagates(lambda: AsyncStreamProcessors(client).get_info("p"), 9)

    async def test_stats_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(72))
        await self._assert_propagates(
            lambda: AsyncStreamProcessor(client=client, name="p").stats(), 72
        )

    async def test_get_stream_processor_samples_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(9))
        await self._assert_propagates(
            lambda: AsyncStreamProcessor(client=client, name="p").get_stream_processor_samples(), 9
        )

    async def test_unknown_error_code_propagates(self) -> None:
        # Code 999 is not in the documented list; it must still propagate.
        client, _ = _spy_client(raises=self._failure(999))
        await self._assert_propagates(
            lambda: AsyncStreamProcessors(client).create("p", pipeline=[{"$x": 1}]), 999
        )

    def test_no_operation_failure_caught_in_module(self) -> None:
        """Structural: verify no try/except hides server errors in either module."""
        for mod in (pymongo.asynchronous.stream_processing,):
            src = inspect.getsource(mod)
            self.assertNotIn("except OperationFailure", src, mod.__name__)
            self.assertNotIn("except PyMongoError", src, mod.__name__)
            self.assertNotIn("exc.code in", src, mod.__name__)
# ===========================================================================
# Test class 6 — spec compliance
# ===========================================================================
class AsyncTestSpecCompliance(unittest.IsolatedAsyncioTestCase):
    """Spec MUST-level compliance checks."""

    async def test_retryability_classification(self) -> None:
        """Each command must route as retryable or non-retryable per the spec."""
        cases: list[tuple[str, Any, bool]] = [
            # (label, coroutine-factory, expected_retryable_read)
            (
                "create",
                lambda c: AsyncStreamProcessors(c).create("demo", pipeline=[{"$x": 1}]),
                False,
            ),
            ("start", lambda c: AsyncStreamProcessor(client=c, name="demo").start(), False),
            ("stop", lambda c: AsyncStreamProcessor(client=c, name="demo").stop(), False),
            ("drop", lambda c: AsyncStreamProcessor(client=c, name="demo").drop(), False),
            ("get_info", lambda c: AsyncStreamProcessors(c).get_info("demo"), True),
            ("stats", lambda c: AsyncStreamProcessor(client=c, name="demo").stats(), True),
            (
                "get_samples_initial",
                lambda c: AsyncStreamProcessor(
                    client=c, name="demo"
                ).get_stream_processor_samples(),
                False,
            ),
            (
                "get_samples_continuation",
                lambda c: AsyncStreamProcessor(client=c, name="demo").get_stream_processor_samples(
                    GetStreamProcessorSamplesOptions(cursor_id=42)
                ),
                False,
            ),
        ]
        responses_for: dict[str, list[Mapping[str, Any]]] = {
            "get_info": [dict(_INFO_RESPONSE)],
            "get_samples_initial": [{"cursorId": 0, "firstBatch": []}],
            "get_samples_continuation": [{"cursorId": 0, "nextBatch": []}],
        }
        for label, make_coro, expected in cases:
            resp = responses_for.get(label, [{"ok": 1}])
            client, calls = _spy_client(responses=resp)
            await make_coro(client)
            self.assertEqual(
                calls[-1]["retryable_read"],
                expected,
                f"retryability mismatch for '{label}'",
            )

    def test_async_sample_cursor_class_hierarchy(self) -> None:
        self.assertEqual(AsyncSampleCursor.__bases__, (object,))


if __name__ == "__main__":
    unittest.main()


@ -0,0 +1,704 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the Atlas Stream Processing driver module.
All tests run offline; no live workspace is required. The wire layer is
stubbed out by replacing ``_command`` on the client with a lightweight spy
that records calls and returns pre-configured responses.
"""
from __future__ import annotations
import inspect
import sys
import unittest
from datetime import datetime, timezone
from typing import Any, Mapping, Optional
from unittest.mock import MagicMock, patch
sys.path[0:0] = [""]
_IS_SYNC = True
import pymongo.synchronous.stream_processing
from bson import Timestamp
from pymongo import (
    CreateStreamProcessorOptions,
    GetStreamProcessorSamplesOptions,
    GetStreamProcessorSamplesResult,
    GetStreamProcessorStatsOptions,
    SampleCursor,
    StartStreamProcessorOptions,
    StreamProcessingClient,
    StreamProcessor,
    StreamProcessorInfo,
    StreamProcessors,
)
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
# ---------------------------------------------------------------------------
# Spy helper
# ---------------------------------------------------------------------------
def _spy_client(
    responses: Optional[list[Mapping[str, Any]]] = None,
    raises: Optional[Exception] = None,
) -> tuple[StreamProcessingClient, list[dict[str, Any]]]:
    """Return *(client, calls)* where ``client._command`` records every call.

    *responses* is consumed in order; when exhausted, ``{}`` is returned.
    If *raises* is set, the spy raises that exception on every call instead
    of returning a response.
    """
    calls: list[dict[str, Any]] = []
    resp_list: list[Mapping[str, Any]] = list(responses) if responses is not None else [{}]

    def _command(
        cmd: dict[str, Any],
        *,
        retryable_read: bool = False,
        session: Any = None,
    ) -> Mapping[str, Any]:
        calls.append({"cmd": dict(cmd), "retryable_read": retryable_read, "session": session})
        if raises is not None:
            raise raises
        return resp_list.pop(0) if resp_list else {}

    client = StreamProcessingClient.__new__(StreamProcessingClient)
    client._client = MagicMock()
    client._command = _command
    return client, calls
# ---------------------------------------------------------------------------
# Minimal info response fixture
# ---------------------------------------------------------------------------
_INFO_RESPONSE: dict[str, Any] = {
    "ok": 1,
    "id": "abc",
    "name": "demo",
    "state": "CREATED",
    "pipeline": [],
    "pipelineVersion": 1,
    "enableAutoScaling": False,
    "failoverEnabled": False,
    "activeRegion": "us-east-1",
    "hasStarted": False,
}
# ===========================================================================
# Test class 1 — client constructor validation
# ===========================================================================
class TestStreamProcessingClientConfig(unittest.TestCase):
    """Constructor-level validation for StreamProcessingClient."""

    def test_rejects_srv_uri(self) -> None:
        with self.assertRaises(ConfigurationError) as cm:
            StreamProcessingClient("mongodb+srv://example.com/")
        self.assertIn("mongodb+srv", str(cm.exception).lower())

    def test_rejects_tls_false_kwarg(self) -> None:
        with self.assertRaises(ConfigurationError):
            StreamProcessingClient("mongodb://example.com/", tls=False)

    def test_rejects_ssl_false_kwarg(self) -> None:
        with self.assertRaises(ConfigurationError):
            StreamProcessingClient("mongodb://example.com/", ssl=False)

    def test_rejects_tls_false_in_uri(self) -> None:
        with self.assertRaises(ConfigurationError):
            StreamProcessingClient("mongodb://example.com/?tls=false")

    @patch("pymongo.synchronous.stream_processing.MongoClient")
    def test_forces_tls_true(self, mock_client: MagicMock) -> None:
        mock_client.return_value = MagicMock()
        StreamProcessingClient("mongodb://example.com/")
        kwargs = mock_client.call_args.kwargs
        self.assertTrue(kwargs.get("tls"))

    @patch("pymongo.synchronous.stream_processing.MongoClient")
    def test_defaults_authsource_to_admin(self, mock_client: MagicMock) -> None:
        mock_client.return_value = MagicMock()
        StreamProcessingClient("mongodb://example.com/")
        kwargs = mock_client.call_args.kwargs
        self.assertEqual(kwargs.get("authSource"), "admin")

    @patch("pymongo.synchronous.stream_processing.MongoClient")
    def test_preserves_explicit_authsource_in_kwarg(self, mock_client: MagicMock) -> None:
        mock_client.return_value = MagicMock()
        StreamProcessingClient("mongodb://example.com/", authSource="other")
        kwargs = mock_client.call_args.kwargs
        # Our default injection must not overwrite a user-supplied value.
        self.assertEqual(kwargs.get("authSource"), "other")

    @patch("pymongo.synchronous.stream_processing.MongoClient")
    def test_preserves_explicit_authsource_in_uri(self, mock_client: MagicMock) -> None:
        mock_client.return_value = MagicMock()
        StreamProcessingClient("mongodb://example.com/?authSource=other")
        kwargs = mock_client.call_args.kwargs
        # authSource is already in the URI; we must not inject a duplicate kwarg.
        self.assertNotIn("authSource", kwargs)

    @patch("pymongo.synchronous.stream_processing.MongoClient")
    def test_drops_ssl_kwarg_to_avoid_duplicate(self, mock_client: MagicMock) -> None:
        mock_client.return_value = MagicMock()
        StreamProcessingClient("mongodb://example.com/", ssl=True)
        kwargs = mock_client.call_args.kwargs
        self.assertNotIn("ssl", kwargs)
        self.assertTrue(kwargs.get("tls"))

    def test_workspace_endpoint_detection(self) -> None:
        from pymongo.synchronous.stream_processing import _is_workspace_endpoint

        self.assertTrue(_is_workspace_endpoint("atlas-stream-foo.virginia-usa.a.query.mongodb.net"))
        self.assertTrue(_is_workspace_endpoint("something.a.query.mongodb.net"))
        self.assertFalse(_is_workspace_endpoint("cluster0.mongodb.net"))
        self.assertFalse(_is_workspace_endpoint("localhost"))
# ===========================================================================
# Test class 2 — options dataclass validation
# ===========================================================================
class TestStreamProcessingOptions(unittest.TestCase):
    """Dataclass validation in stream_processing_options."""

    def test_start_options_mutex_start_after_and_at_op_time(self) -> None:
        with self.assertRaises(InvalidOperation):
            StartStreamProcessorOptions(
                start_after={"_id": 1},
                start_at_operation_time=Timestamp(0, 0),
            )

    def test_start_options_invalid_tier(self) -> None:
        with self.assertRaises(InvalidOperation):
            StartStreamProcessorOptions(tier="SP1")
        for valid_tier in ("SP2", "SP5", "SP10", "SP30", "SP50"):
            StartStreamProcessorOptions(tier=valid_tier)  # must not raise

    def test_start_options_invalid_workers(self) -> None:
        with self.assertRaises(InvalidOperation):
            StartStreamProcessorOptions(workers=0)
        with self.assertRaises(InvalidOperation):
            StartStreamProcessorOptions(workers=-1)
        StartStreamProcessorOptions(workers=1)  # must not raise

    def test_stats_options_invalid_scale(self) -> None:
        with self.assertRaises(InvalidOperation):
            GetStreamProcessorStatsOptions(scale=0)
        with self.assertRaises(InvalidOperation):
            GetStreamProcessorStatsOptions(scale=-1)
        GetStreamProcessorStatsOptions(scale=1)
        GetStreamProcessorStatsOptions(scale=1024)

    def test_samples_options_invalid_cursor_id(self) -> None:
        with self.assertRaises(InvalidOperation):
            GetStreamProcessorSamplesOptions(cursor_id=-1)
        # cursor_id=0 is valid at construction; the method-level check rejects it.
        GetStreamProcessorSamplesOptions(cursor_id=0)
        GetStreamProcessorSamplesOptions(cursor_id=42)

    def test_samples_options_invalid_limit(self) -> None:
        with self.assertRaises(InvalidOperation):
            GetStreamProcessorSamplesOptions(limit=-1)
        GetStreamProcessorSamplesOptions(limit=0)
        GetStreamProcessorSamplesOptions(limit=10)

    def test_samples_options_invalid_batch_size(self) -> None:
        with self.assertRaises(InvalidOperation):
            GetStreamProcessorSamplesOptions(batch_size=-1)
        GetStreamProcessorSamplesOptions(batch_size=0)
        GetStreamProcessorSamplesOptions(batch_size=5)

    def test_create_options_all_optional(self) -> None:
        opts = CreateStreamProcessorOptions()
        self.assertIsNone(opts.dlq)
        self.assertIsNone(opts.stream_meta_field_name)
        self.assertIsNone(opts.tier)
        self.assertIsNone(opts.failover)

    def test_stream_processor_info_from_response_full(self) -> None:
        now = datetime.now(tz=timezone.utc)
        doc: dict[str, Any] = {
            "ok": 1,
            "id": "abc123",
            "name": "demo",
            "state": "STARTED",
            "pipeline": [{"$source": {}}],
            "pipelineVersion": 3,
            "tier": "SP10",
            "dlq": {"connectionName": "c1"},
            "streamMetaFieldName": "ts",
            "enableAutoScaling": True,
            "failoverEnabled": True,
            "activeRegion": "us-east-1",
            "lastModifiedAt": now,
            "modifiedBy": "user@example.com",
            "lastStateChange": now,
            "lastHeartbeat": now,
            "hasStarted": True,
            "stats": {"inputMessageCount": 42},
            "errorMsg": None,
            "errorCode": None,
            "errorRetryable": None,
        }
        info = StreamProcessorInfo.from_response(doc)
        self.assertEqual(info.id, "abc123")
        self.assertEqual(info.name, "demo")
        self.assertEqual(info.state, "STARTED")
        self.assertEqual(info.pipeline, [{"$source": {}}])
        self.assertEqual(info.pipeline_version, 3)
        self.assertEqual(info.tier, "SP10")
        self.assertEqual(info.dlq, {"connectionName": "c1"})
        self.assertEqual(info.stream_meta_field_name, "ts")
        self.assertTrue(info.enable_auto_scaling)
        self.assertTrue(info.failover_enabled)
        self.assertEqual(info.active_region, "us-east-1")
        self.assertTrue(info.has_started)
        self.assertEqual(info.stats, {"inputMessageCount": 42})
        self.assertEqual(info.raw, doc)

    def test_stream_processor_info_from_response_minimal(self) -> None:
doc: dict[str, Any] = {
"id": "x",
"name": "p",
"state": "CREATED",
"pipeline": [],
"pipelineVersion": 1,
}
info = StreamProcessorInfo.from_response(doc)
self.assertEqual(info.id, "x")
self.assertIsNone(info.tier)
self.assertIsNone(info.dlq)
self.assertFalse(info.enable_auto_scaling)
self.assertFalse(info.has_started)
self.assertEqual(info.raw, doc)
def test_stream_processor_info_preserves_unknown_fields(self) -> None:
doc: dict[str, Any] = {
"id": "x",
"name": "p",
"state": "CREATED",
"pipeline": [],
"pipelineVersion": 1,
"futureField": "someValue",
}
info = StreamProcessorInfo.from_response(doc)
self.assertEqual(info.raw["futureField"], "someValue")
def test_stream_processor_info_state_is_string(self) -> None:
doc: dict[str, Any] = {
"id": "x",
"name": "p",
"state": "FUTURE_UNKNOWN_STATE",
"pipeline": [],
"pipelineVersion": 1,
}
info = StreamProcessorInfo.from_response(doc)
self.assertEqual(info.state, "FUTURE_UNKNOWN_STATE")
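# A minimal sketch of the construction-time validation these tests expect,
# assuming the options classes are dataclasses that validate in
# `__post_init__`. `StartOptionsSketch` is a hypothetical stand-in, and it
# raises ValueError where the driver raises InvalidOperation.
from dataclasses import dataclass
from typing import Any, Optional

# Tiers the test above accepts; treating this as the full list is an assumption.
_ASSUMED_VALID_TIERS = ("SP2", "SP5", "SP10", "SP30", "SP50")


@dataclass(frozen=True)
class StartOptionsSketch:
    workers: Optional[int] = None
    tier: Optional[str] = None
    start_after: Optional[Any] = None
    start_at_operation_time: Optional[Any] = None

    def __post_init__(self) -> None:
        # The two resume points are mutually exclusive.
        if self.start_after is not None and self.start_at_operation_time is not None:
            raise ValueError("start_after and start_at_operation_time are mutually exclusive")
        if self.tier is not None and self.tier not in _ASSUMED_VALID_TIERS:
            raise ValueError(f"tier must be one of {_ASSUMED_VALID_TIERS}")
        # Zero and negative worker counts are both rejected.
        if self.workers is not None and self.workers < 1:
            raise ValueError("workers must be a positive integer")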
# ===========================================================================
# Test class 3 — lifecycle commands
# ===========================================================================
class TestStreamProcessorsCommands(unittest.TestCase):
"""Lifecycle commands on StreamProcessors and StreamProcessor."""
def setUp(self) -> None:
self.client, self.calls = _spy_client(responses=[{"ok": 1}] * 30)
self.sps = StreamProcessors(self.client)
self.proc = StreamProcessor(client=self.client, name="demo")
def test_create_sends_correct_command(self) -> None:
self.sps.create("demo", pipeline=[{"$source": {}}])
entry = self.calls[-1]
self.assertEqual(entry["cmd"].get("createStreamProcessor"), "demo")
self.assertEqual(entry["cmd"].get("pipeline"), [{"$source": {}}])
self.assertNotIn("options", entry["cmd"])
self.assertFalse(entry["retryable_read"])
def test_create_with_options_includes_them(self) -> None:
opts = CreateStreamProcessorOptions(dlq={"connectionName": "c"}, tier="SP10")
self.sps.create("demo", pipeline=[{"$source": {}}], options=opts)
cmd = self.calls[-1]["cmd"]
self.assertEqual(cmd["options"]["dlq"], {"connectionName": "c"})
self.assertEqual(cmd["options"]["tier"], "SP10")
def test_create_rejects_empty_name(self) -> None:
with self.assertRaises(InvalidOperation):
self.sps.create("", pipeline=[{"$x": 1}])
self.assertEqual(len(self.calls), 0)
def test_create_rejects_whitespace_name(self) -> None:
with self.assertRaises(InvalidOperation):
self.sps.create(" ", pipeline=[{"$x": 1}])
self.assertEqual(len(self.calls), 0)
def test_create_rejects_empty_pipeline(self) -> None:
with self.assertRaises(InvalidOperation):
self.sps.create("demo", pipeline=[])
self.assertEqual(len(self.calls), 0)
def test_get_returns_handle_without_calling_server(self) -> None:
proc = self.sps.get("demo")
self.assertEqual(len(self.calls), 0)
self.assertEqual(proc.name, "demo")
self.assertIsInstance(proc, StreamProcessor)
def test_get_rejects_empty_name(self) -> None:
with self.assertRaises(InvalidOperation):
self.sps.get("")
def test_get_info_sends_correct_command_and_decodes(self) -> None:
client, calls = _spy_client(responses=[dict(_INFO_RESPONSE)])
info = StreamProcessors(client).get_info("demo")
self.assertEqual(calls[0]["cmd"].get("getStreamProcessor"), "demo")
self.assertTrue(calls[0]["retryable_read"])
self.assertIsInstance(info, StreamProcessorInfo)
self.assertEqual(info.id, "abc")
self.assertEqual(info.state, "CREATED")
def test_get_info_preserves_unknown_response_fields(self) -> None:
resp = dict(_INFO_RESPONSE, futureField="value")
client, _ = _spy_client(responses=[resp])
info = StreamProcessors(client).get_info("demo")
self.assertEqual(info.raw["futureField"], "value")
def test_start_sends_correct_command(self) -> None:
self.proc.start()
entry = self.calls[-1]
self.assertEqual(entry["cmd"].get("startStreamProcessor"), "demo")
self.assertNotIn("workers", entry["cmd"])
self.assertNotIn("options", entry["cmd"])
self.assertFalse(entry["retryable_read"])
def test_start_with_options_includes_them(self) -> None:
opts = StartStreamProcessorOptions(workers=2, clear_checkpoints=True, tier="SP10")
self.proc.start(options=opts)
cmd = self.calls[-1]["cmd"]
self.assertEqual(cmd.get("workers"), 2)
self.assertTrue(cmd["options"]["clearCheckpoints"])
self.assertEqual(cmd["options"]["tier"], "SP10")
def test_stop_sends_correct_command(self) -> None:
self.proc.stop()
entry = self.calls[-1]
self.assertEqual(entry["cmd"].get("stopStreamProcessor"), "demo")
self.assertFalse(entry["retryable_read"])
def test_drop_sends_correct_command(self) -> None:
self.proc.drop()
entry = self.calls[-1]
self.assertEqual(entry["cmd"].get("dropStreamProcessor"), "demo")
self.assertFalse(entry["retryable_read"])
def test_stats_sends_correct_command_and_returns_raw_dict(self) -> None:
raw_resp: dict[str, Any] = {"ok": 1, "stats": {"inputMessageCount": 5}, "futureField": "x"}
client, calls = _spy_client(responses=[raw_resp])
result = StreamProcessor(client=client, name="demo").stats()
self.assertEqual(calls[0]["cmd"].get("getStreamProcessorStats"), "demo")
self.assertTrue(calls[0]["retryable_read"])
self.assertEqual(result, raw_resp)
def test_stats_with_options(self) -> None:
client, calls = _spy_client(responses=[{"ok": 1}])
opts = GetStreamProcessorStatsOptions(scale=1024, verbose=True)
StreamProcessor(client=client, name="demo").stats(options=opts)
self.assertEqual(calls[0]["cmd"]["options"]["scale"], 1024)
self.assertTrue(calls[0]["cmd"]["options"]["verbose"])
def test_session_propagates(self) -> None:
fake_session = MagicMock()
self.proc.start(session=fake_session)
self.assertIs(self.calls[-1]["session"], fake_session)
# ===========================================================================
# Test class 4 — sample cursor
# ===========================================================================
class TestSampleCursor(unittest.TestCase):
"""Two-phase sample cursor protocol."""
def test_get_samples_initial_call_sends_start_command(self) -> None:
client, calls = _spy_client(responses=[{"cursorId": 42, "firstBatch": [{"x": 1}]}])
proc = StreamProcessor(client=client, name="demo")
result = proc.get_stream_processor_samples()
cmd = calls[0]["cmd"]
self.assertIn("startSampleStreamProcessor", cmd)
self.assertNotIn("cursorId", cmd)
self.assertNotIn("batchSize", cmd)
self.assertFalse(calls[0]["retryable_read"])
self.assertEqual(result.cursor_id, 42)
self.assertEqual(result.documents, [{"x": 1}])
def test_get_samples_initial_call_includes_limit(self) -> None:
client, calls = _spy_client(responses=[{"cursorId": 1, "firstBatch": []}])
proc = StreamProcessor(client=client, name="demo")
proc.get_stream_processor_samples(GetStreamProcessorSamplesOptions(limit=100))
self.assertEqual(calls[0]["cmd"].get("limit"), 100)
self.assertNotIn("batchSize", calls[0]["cmd"])
def test_get_samples_continuation_sends_get_more(self) -> None:
client, calls = _spy_client(responses=[{"cursorId": 0, "nextBatch": [{"y": 2}]}])
proc = StreamProcessor(client=client, name="demo")
result = proc.get_stream_processor_samples(GetStreamProcessorSamplesOptions(cursor_id=42))
cmd = calls[0]["cmd"]
self.assertIn("getMoreSampleStreamProcessor", cmd)
self.assertEqual(cmd.get("cursorId"), 42)
self.assertNotIn("limit", cmd)
self.assertEqual(result.cursor_id, 0)
self.assertEqual(result.documents, [{"y": 2}])
def test_get_samples_continuation_includes_batch_size(self) -> None:
client, calls = _spy_client(responses=[{"cursorId": 0, "nextBatch": []}])
proc = StreamProcessor(client=client, name="demo")
proc.get_stream_processor_samples(
GetStreamProcessorSamplesOptions(cursor_id=42, batch_size=5)
)
self.assertEqual(calls[0]["cmd"].get("batchSize"), 5)
def test_get_samples_rejects_cursor_id_zero(self) -> None:
client, calls = _spy_client()
proc = StreamProcessor(client=client, name="demo")
with self.assertRaises(InvalidOperation):
proc.get_stream_processor_samples(GetStreamProcessorSamplesOptions(cursor_id=0))
self.assertEqual(len(calls), 0)
def test_sample_iterator_drains_single_batch(self) -> None:
client, calls = _spy_client(
responses=[{"cursorId": 0, "firstBatch": [{"a": 1}, {"a": 2}, {"a": 3}]}]
)
proc = StreamProcessor(client=client, name="demo")
docs = []
for doc in proc.sample():
docs.append(doc)
self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}])
self.assertEqual(len(calls), 1)
def test_sample_iterator_drains_multiple_batches(self) -> None:
client, calls = _spy_client(
responses=[
{"cursorId": 11, "firstBatch": [{"a": 1}]},
{"cursorId": 22, "nextBatch": [{"a": 2}]},
{"cursorId": 0, "nextBatch": [{"a": 3}]},
]
)
proc = StreamProcessor(client=client, name="demo")
docs = []
for doc in proc.sample():
docs.append(doc)
self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}])
self.assertEqual(len(calls), 3)
def test_sample_iterator_handles_empty_batch_with_nonzero_cursor(self) -> None:
client, calls = _spy_client(
responses=[
{"cursorId": 11, "firstBatch": []},
{"cursorId": 22, "nextBatch": []},
{"cursorId": 0, "nextBatch": [{"a": 1}]},
]
)
proc = StreamProcessor(client=client, name="demo")
docs = []
for doc in proc.sample():
docs.append(doc)
self.assertEqual(docs, [{"a": 1}])
self.assertEqual(len(calls), 3)
def test_sample_iterator_stops_after_cursor_id_zero(self) -> None:
client, calls = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
proc = StreamProcessor(client=client, name="demo")
cursor = proc.sample()
first = cursor.__next__()
self.assertEqual(first, {"a": 1})
with self.assertRaises(StopIteration):
cursor.__next__()
# Only one wire call — the server returned cursorId:0 so we stopped.
self.assertEqual(len(calls), 1)
def test_sample_cursor_close_stops_iteration(self) -> None:
client, calls = _spy_client()
proc = StreamProcessor(client=client, name="demo")
cursor = proc.sample()
cursor.close()
with self.assertRaises(StopIteration):
cursor.__next__()
self.assertEqual(len(calls), 0)
def test_sample_cursor_alive_property(self) -> None:
client, _ = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
proc = StreamProcessor(client=client, name="demo")
cursor = proc.sample()
self.assertTrue(cursor.alive)
doc = cursor.__next__()
self.assertEqual(doc, {"a": 1})
# After exhaustion (cursorId:0 returned during refill), alive is False.
self.assertFalse(cursor.alive)
def test_sample_cursor_id_property(self) -> None:
client, _ = _spy_client(responses=[{"cursorId": 42, "firstBatch": [{"a": 1}]}])
proc = StreamProcessor(client=client, name="demo")
cursor = proc.sample()
self.assertIsNone(cursor.cursor_id)
cursor.__next__()
self.assertEqual(cursor.cursor_id, 42)
def test_sample_cursor_context_manager(self) -> None:
client, _ = _spy_client()
proc = StreamProcessor(client=client, name="demo")
with proc.sample() as cur:
self.assertIsInstance(cur, SampleCursor)
self.assertTrue(cur._closed)
def test_sample_cursor_session_propagates_to_refill(self) -> None:
client, calls = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
proc = StreamProcessor(client=client, name="demo")
fake_session = MagicMock()
for _ in proc.sample(session=fake_session):
pass
self.assertIs(calls[0]["session"], fake_session)
def test_sample_cursor_does_not_inherit_from_cursor_base(self) -> None:
self.assertEqual(SampleCursor.__bases__, (object,))
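# The two-phase protocol tested above can be sketched without a server.
# This is an illustrative sketch only: `drain_samples` and `run_command` are
# hypothetical names, the command field names are copied from the assertions
# in this class, and using the processor name as the command value is an
# assumption.
from typing import Any, Callable, Iterator, Mapping


def drain_samples(
    run_command: Callable[[dict[str, Any]], Mapping[str, Any]],
    name: str,
) -> Iterator[Any]:
    """Yield sampled documents until the server reports cursorId 0."""
    # Phase 1: the start command opens the cursor and returns `firstBatch`.
    reply = run_command({"startSampleStreamProcessor": name})
    yield from reply.get("firstBatch", [])
    cursor_id = reply["cursorId"]
    # Phase 2: keep issuing getMore with the latest cursor id. An empty batch
    # with a nonzero id is not exhaustion; only cursorId 0 ends the loop.
    while cursor_id != 0:
        reply = run_command({"getMoreSampleStreamProcessor": name, "cursorId": cursor_id})
        yield from reply.get("nextBatch", [])
        cursor_id = reply["cursorId"]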
# ===========================================================================
# Test class 5 — error handling
# ===========================================================================
class TestErrorHandling(unittest.TestCase):
"""Server errors MUST propagate unchanged — no rewrapping, no filtering."""
def _failure(self, code: int = 9) -> OperationFailure:
return OperationFailure("test error", code=code, details={"errmsg": "test", "code": code})
def _assert_propagates(self, fn: Any, expected_code: int) -> None:
with self.assertRaises(OperationFailure) as cm:
fn()
self.assertEqual(cm.exception.code, expected_code)
def test_create_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(9))
self._assert_propagates(
lambda: StreamProcessors(client).create("p", pipeline=[{"$x": 1}]), 9
)
def test_start_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(72))
self._assert_propagates(lambda: StreamProcessor(client=client, name="p").start(), 72)
def test_stop_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(125))
self._assert_propagates(lambda: StreamProcessor(client=client, name="p").stop(), 125)
def test_drop_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(1))
self._assert_propagates(lambda: StreamProcessor(client=client, name="p").drop(), 1)
def test_get_info_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(9))
self._assert_propagates(lambda: StreamProcessors(client).get_info("p"), 9)
def test_stats_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(72))
self._assert_propagates(lambda: StreamProcessor(client=client, name="p").stats(), 72)
def test_get_stream_processor_samples_propagates_operation_failure(self) -> None:
client, _ = _spy_client(raises=self._failure(9))
self._assert_propagates(
lambda: StreamProcessor(client=client, name="p").get_stream_processor_samples(), 9
)
def test_unknown_error_code_propagates(self) -> None:
# Code 999 is not in the documented list; it must still propagate.
client, _ = _spy_client(raises=self._failure(999))
self._assert_propagates(
lambda: StreamProcessors(client).create("p", pipeline=[{"$x": 1}]), 999
)
def test_no_operation_failure_caught_in_module(self) -> None:
"""Structural: verify no try/except hides server errors in the sync module."""
for mod in (pymongo.synchronous.stream_processing,):
src = inspect.getsource(mod)
self.assertNotIn("except OperationFailure", src, mod.__name__)
self.assertNotIn("except PyMongoError", src, mod.__name__)
self.assertNotIn("exc.code in", src, mod.__name__)
# ===========================================================================
# Test class 6 — spec compliance
# ===========================================================================
class TestSpecCompliance(unittest.TestCase):
"""Spec MUST-level compliance checks."""
def test_retryability_classification(self) -> None:
"""Each command must route as retryable or non-retryable per the spec."""
cases: list[tuple[str, Any, bool]] = [
# (label, callable taking the client, expected_retryable_read)
("create", lambda c: StreamProcessors(c).create("demo", pipeline=[{"$x": 1}]), False),
("start", lambda c: StreamProcessor(client=c, name="demo").start(), False),
("stop", lambda c: StreamProcessor(client=c, name="demo").stop(), False),
("drop", lambda c: StreamProcessor(client=c, name="demo").drop(), False),
("get_info", lambda c: StreamProcessors(c).get_info("demo"), True),
("stats", lambda c: StreamProcessor(client=c, name="demo").stats(), True),
(
"get_samples_initial",
lambda c: StreamProcessor(client=c, name="demo").get_stream_processor_samples(),
False,
),
(
"get_samples_continuation",
lambda c: StreamProcessor(client=c, name="demo").get_stream_processor_samples(
GetStreamProcessorSamplesOptions(cursor_id=42)
),
False,
),
]
responses_for: dict[str, list[Mapping[str, Any]]] = {
"get_info": [dict(_INFO_RESPONSE)],
"get_samples_initial": [{"cursorId": 0, "firstBatch": []}],
"get_samples_continuation": [{"cursorId": 0, "nextBatch": []}],
}
for label, make_coro, expected in cases:
resp = responses_for.get(label, [{"ok": 1}])
client, calls = _spy_client(responses=resp)
make_coro(client)
self.assertEqual(
calls[-1]["retryable_read"],
expected,
f"retryability mismatch for '{label}'",
)
def test_async_sample_cursor_class_hierarchy(self) -> None:
self.assertEqual(SampleCursor.__bases__, (object,))
if __name__ == "__main__":
unittest.main()

View File

@ -32,6 +32,7 @@ replacements = {
"AsyncDatabase": "Database",
"_AsyncCursorBase": "_CursorBase",
"AsyncCursor": "Cursor",
"pymongo.asynchronous.stream_processing.AsyncMongoClient": "pymongo.synchronous.stream_processing.MongoClient",
"AsyncMongoClient": "MongoClient",
"AsyncCommandCursor": "CommandCursor",
"AsyncRawBatchCursor": "RawBatchCursor",
@ -67,6 +68,10 @@ replacements = {
"_AsyncGridOutChunkIterator": "GridOutChunkIterator",
"_a_grid_in_property": "_grid_in_property",
"_a_grid_out_property": "_grid_out_property",
"AsyncStreamProcessingClient": "StreamProcessingClient",
"AsyncStreamProcessors": "StreamProcessors",
"AsyncStreamProcessor": "StreamProcessor",
"AsyncSampleCursor": "SampleCursor",
"AsyncClientEncryption": "ClientEncryption",
"AsyncMongoCryptCallback": "MongoCryptCallback",
"AsyncExplicitEncrypter": "ExplicitEncrypter",
@ -128,6 +133,12 @@ replacements = {
"SpecRunnerTask": "SpecRunnerThread",
"AsyncMockConnection": "MockConnection",
"AsyncMockPool": "MockPool",
"AsyncMockMonitor": "SyncMockMonitor",
"AsyncMock": "MagicMock",
"AsyncTestStreamProcessorsCommands": "TestStreamProcessorsCommands",
"AsyncTestSampleCursor": "TestSampleCursor",
"AsyncTestErrorHandling": "TestErrorHandling",
"AsyncTestSpecCompliance": "TestSpecCompliance",
"StopAsyncIteration": "StopIteration",
"create_async_event": "create_event",
"async_create_barrier": "create_barrier",
@ -272,6 +283,7 @@ converted_tests = [
"test_srv_polling.py",
"test_ssl.py",
"test_streaming_protocol.py",
"test_stream_processing.py",
"test_transactions.py",
"test_transactions_unified.py",
"test_unified_format.py",