Merge d78f204ad6 into 552b7bf47b
This commit is contained in:
commit
9ba071f302
@ -56,6 +56,7 @@ Sub-modules:
|
||||
results
|
||||
server_api
|
||||
server_description
|
||||
stream_processing
|
||||
topology_description
|
||||
uri_parser
|
||||
write_concern
|
||||
|
||||
190
doc/api/pymongo/stream_processing.rst
Normal file
190
doc/api/pymongo/stream_processing.rst
Normal file
@ -0,0 +1,190 @@
|
||||
:mod:`stream_processing` -- Atlas Stream Processing
|
||||
====================================================
|
||||
|
||||
.. warning::
|
||||
|
||||
The Atlas Stream Processing API is **experimental**. The driver
|
||||
specification is in Draft status and the API surface — including
|
||||
retryability behavior, error code mappings, and method signatures —
|
||||
may change in a backward-incompatible way before the spec is
|
||||
finalized.
|
||||
|
||||
Overview
|
||||
--------
|
||||
|
||||
Atlas Stream Processing (ASP) lets you build continuous, stateful pipelines
|
||||
that process data from one or more sources in real time. A *stream processing
|
||||
workspace* is a dedicated Atlas endpoint that hosts one or more named stream
|
||||
processors; it is distinct from a standard MongoDB cluster and is accessed
|
||||
through :class:`~pymongo.asynchronous.stream_processing.AsyncStreamProcessingClient`
|
||||
(or its sync twin :class:`~pymongo.synchronous.stream_processing.StreamProcessingClient`)
|
||||
rather than ``MongoClient``.
|
||||
|
||||
Workspace connection strings use the standard ``mongodb://`` scheme — ``mongodb+srv://``
|
||||
is not supported. TLS is always required for workspace connections and cannot be
|
||||
disabled. ``authSource`` defaults to ``"admin"`` if not explicitly set. Users must
|
||||
hold the ``atlasAdmin`` role to execute ASP commands.
|
||||
|
||||
Quickstart
|
||||
----------
|
||||
|
||||
Async
|
||||
~~~~~
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from pymongo import AsyncStreamProcessingClient
|
||||
|
||||
async def main():
|
||||
uri = (
|
||||
"mongodb://user:pass@"
|
||||
"atlas-stream-<workspaceId>-<suffix>.<region>.a.query.mongodb.net/"
|
||||
)
|
||||
async with AsyncStreamProcessingClient(uri) as client:
|
||||
sps = client.stream_processors()
|
||||
await sps.create("demo", pipeline=[
|
||||
{"$source": {"connectionName": "<conn>", "topic": "events"}},
|
||||
{"$match": {"value": {"$gt": 100}}},
|
||||
{"$emit": {"connectionName": "<conn>", "db": "out", "coll": "high"}},
|
||||
])
|
||||
proc = sps.get("demo")
|
||||
await proc.start()
|
||||
info = await sps.get_info("demo")
|
||||
print(info.state, info.pipeline_version)
|
||||
async for doc in proc.sample(limit=10):
|
||||
print(doc)
|
||||
await proc.stop()
|
||||
await proc.drop()
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Sync
|
||||
~~~~
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pymongo import StreamProcessingClient
|
||||
|
||||
uri = (
|
||||
"mongodb://user:pass@"
|
||||
"atlas-stream-<workspaceId>-<suffix>.<region>.a.query.mongodb.net/"
|
||||
)
|
||||
with StreamProcessingClient(uri) as client:
|
||||
sps = client.stream_processors()
|
||||
sps.create("demo", pipeline=[
|
||||
{"$source": {"connectionName": "<conn>", "topic": "events"}},
|
||||
{"$match": {"value": {"$gt": 100}}},
|
||||
{"$emit": {"connectionName": "<conn>", "db": "out", "coll": "high"}},
|
||||
])
|
||||
proc = sps.get("demo")
|
||||
proc.start()
|
||||
info = sps.get_info("demo")
|
||||
print(info.state, info.pipeline_version)
|
||||
for doc in proc.sample(limit=10):
|
||||
print(doc)
|
||||
proc.stop()
|
||||
proc.drop()
|
||||
|
||||
Sample cursor semantics
|
||||
-----------------------
|
||||
|
||||
The sample cursor is a custom two-phase protocol distinct from the standard
|
||||
MongoDB ``getMore`` mechanism. It MUST NOT be confused with standard
|
||||
:class:`~pymongo.asynchronous.cursor.AsyncCursor` objects.
|
||||
|
||||
For most use cases, call :meth:`~pymongo.asynchronous.stream_processing.AsyncStreamProcessor.sample`
|
||||
to obtain an :class:`~pymongo.asynchronous.stream_processing.AsyncSampleCursor` and iterate it
|
||||
with ``async for``. The cursor drives the underlying protocol automatically:
|
||||
|
||||
- The first iteration sends ``startSampleStreamProcessor``, optionally with a ``limit``.
|
||||
- Subsequent iterations send ``getMoreSampleStreamProcessor`` with the ``cursorId``
|
||||
returned by the previous call, optionally with a ``batchSize``.
|
||||
- When the server returns ``cursorId: 0`` the cursor is exhausted and no further
|
||||
wire calls are made.
|
||||
|
||||
For fine-grained control — tracking ``cursorId`` across calls yourself — use
|
||||
:meth:`~pymongo.asynchronous.stream_processing.AsyncStreamProcessor.get_stream_processor_samples`
|
||||
directly. Passing ``cursor_id=0`` raises :exc:`~pymongo.errors.InvalidOperation`
|
||||
immediately, before any wire call is sent.
|
||||
|
||||
Commands not yet supported
|
||||
--------------------------
|
||||
|
||||
The following commands from the ASP server specification are intentionally
|
||||
deferred and not yet wrapped by this API:
|
||||
|
||||
- ``modifyStreamProcessor`` — rename, pipeline replacement, and DLQ reconfiguration
|
||||
- ``listStreamProcessors`` — enumerate processors in a workspace
|
||||
- ``listStreamConnections`` — enumerate available connections
|
||||
- ``processStreamProcessor`` — one-shot ad-hoc pipeline execution
|
||||
- ``listWorkspaceDefaults`` — fetch workspace tier defaults
|
||||
|
||||
Users can still send any of these directly with ``AsyncDatabase.command`` via a plain
|
||||
:class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` connected to the
|
||||
workspace endpoint.
|
||||
|
||||
Async classes
|
||||
-------------
|
||||
|
||||
.. autoclass:: pymongo.asynchronous.stream_processing.AsyncStreamProcessingClient
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: pymongo.asynchronous.stream_processing.AsyncStreamProcessors
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: pymongo.asynchronous.stream_processing.AsyncStreamProcessor
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: pymongo.asynchronous.stream_processing.AsyncSampleCursor
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
Sync classes
|
||||
------------
|
||||
|
||||
.. autoclass:: pymongo.synchronous.stream_processing.StreamProcessingClient
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: pymongo.synchronous.stream_processing.StreamProcessors
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: pymongo.synchronous.stream_processing.StreamProcessor
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: pymongo.synchronous.stream_processing.SampleCursor
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
Options and result types
|
||||
------------------------
|
||||
|
||||
.. autoclass:: pymongo.stream_processing_options.CreateStreamProcessorOptions
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: pymongo.stream_processing_options.StartStreamProcessorOptions
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: pymongo.stream_processing_options.GetStreamProcessorStatsOptions
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: pymongo.stream_processing_options.GetStreamProcessorSamplesOptions
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: pymongo.stream_processing_options.GetStreamProcessorSamplesResult
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: pymongo.stream_processing_options.StreamProcessorInfo
|
||||
:members:
|
||||
:show-inheritance:
|
||||
@ -1,6 +1,20 @@
|
||||
Changelog
|
||||
=========
|
||||
|
||||
Upcoming
|
||||
--------
|
||||
|
||||
- Added experimental support for Atlas Stream Processing (ASP).
|
||||
:class:`~pymongo.asynchronous.stream_processing.AsyncStreamProcessingClient`
|
||||
and :class:`~pymongo.synchronous.stream_processing.StreamProcessingClient`
|
||||
enable native client-side management of stream processors in an ASP workspace,
|
||||
including ``createStreamProcessor``, ``startStreamProcessor``,
|
||||
``stopStreamProcessor``, ``dropStreamProcessor``, ``getStreamProcessor``,
|
||||
``getStreamProcessorStats``, and the two-phase sample cursor protocol
|
||||
(``startSampleStreamProcessor`` / ``getMoreSampleStreamProcessor``). See the
|
||||
:mod:`stream_processing` API docs for details. The API is experimental and
|
||||
may change before the driver specification is finalized.
|
||||
|
||||
Changes in Version 4.17.0 (2026/04/20)
|
||||
--------------------------------------
|
||||
|
||||
|
||||
153
poc_asp.py
Normal file
153
poc_asp.py
Normal file
@ -0,0 +1,153 @@
|
||||
"""
|
||||
POC: Atlas Stream Processing — create / start / stop / drop a stream processor.
|
||||
|
||||
Fill in the FILL_ME values below before running:
|
||||
python3 poc_asp.py
|
||||
|
||||
Pipeline used:
|
||||
$source → sample_stream_solar
|
||||
$emit → __testLog
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import pprint
|
||||
|
||||
from pymongo import AsyncStreamProcessingClient
|
||||
from pymongo.errors import OperationFailure
|
||||
|
||||
# ---------------------------------------------------------------------------
# Configuration — fill these in before running
# ---------------------------------------------------------------------------

# Workspace connection string from Atlas UI:
#   Stream Processing → your workspace → Connect
# Format: mongodb://<host>/  (no mongodb+srv://)
# e.g. "mongodb://atlas-stream-<id>.<region>.a.query.mongodb.net/"
#
# NOTE: keep the "FILL_ME" placeholders in the checked-in copy. Real
# endpoints and credentials must never be committed to version control;
# main() refuses to run while any of these three values is still "FILL_ME".
WORKSPACE_URI = "FILL_ME"

# Atlas DB user credentials (must have atlasAdmin role on the workspace project)
USERNAME = "FILL_ME"
PASSWORD = "FILL_ME"

# ---------------------------------------------------------------------------
# Pipeline — hardcoded per your setup
# ---------------------------------------------------------------------------

PROCESSOR_NAME = "simpletestSP"

# Two-stage pipeline: read from the sample solar stream, emit to the test log.
PIPELINE = [
    {
        "$source": {
            "connectionName": "sample_stream_solar",
        }
    },
    {
        "$emit": {
            "connectionName": "__testLog",
        }
    },
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POC steps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def main() -> None:
    """Walk one stream processor through its full lifecycle.

    Steps: create, inspect, start, fetch stats, sample up to five documents,
    stop, and drop. Failures of create/start/stop/drop abort the script via
    ``SystemExit``; failures of stats/sample are only reported.
    """
    # Refuse to run while any configuration value is still the placeholder.
    if any(value == "FILL_ME" for value in (WORKSPACE_URI, USERNAME, PASSWORD)):
        raise SystemExit("Fill in WORKSPACE_URI, USERNAME, and PASSWORD at the top of this file.")

    async with AsyncStreamProcessingClient(
        WORKSPACE_URI, username=USERNAME, password=PASSWORD
    ) as client:
        processors = client.stream_processors()

        # --- Step 1: create -------------------------------------------------
        print(f"\n[1] Creating processor '{PROCESSOR_NAME}' ...")
        try:
            await processors.create(PROCESSOR_NAME, pipeline=PIPELINE)
        except OperationFailure as exc:
            raise SystemExit(f" Create failed (code {exc.code}): {exc}") from exc
        else:
            print(" Created OK")

        # --- Step 2: inspect before starting --------------------------------
        print("\n[2] Getting info ...")
        details = await processors.get_info(PROCESSOR_NAME)
        print(f" state : {details.state}")
        print(f" pipeline_version : {details.pipeline_version}")
        print(f" has_started : {details.has_started}")

        # --- Step 3: start --------------------------------------------------
        processor = processors.get(PROCESSOR_NAME)
        print("\n[3] Starting processor ...")
        try:
            await processor.start()
        except OperationFailure as exc:
            raise SystemExit(f" Start failed (code {exc.code}): {exc}") from exc
        else:
            print(" Start command sent OK")

        # Short pause so the server can transition state before we re-query.
        await asyncio.sleep(2)

        details = await processors.get_info(PROCESSOR_NAME)
        print(f" state after start: {details.state}")

        # --- Step 4: stats (best effort) ------------------------------------
        print("\n[4] Fetching stats ...")
        try:
            pprint.pprint(await processor.stats())
        except OperationFailure as exc:
            print(f" Stats unavailable (code {exc.code}): {exc}")

        # --- Step 5: sample up to 5 docs (best effort) ----------------------
        # The loop is cut off by hand after N docs because the dev server does
        # not signal cursor exhaustion with cursorId=0 as the spec requires.
        print("\n[5] Sampling up to 5 documents ...")
        try:
            sampled = 0
            async for document in processor.sample():
                print(f" doc: {document}")
                sampled += 1
                if sampled == 5:
                    break
            print(f" Sampled {sampled} document(s)")
        except OperationFailure as exc:
            print(f" Sample unavailable (code {exc.code}): {exc}")

        # --- Step 6: stop ---------------------------------------------------
        print("\n[6] Stopping processor ...")
        try:
            await processor.stop()
        except OperationFailure as exc:
            raise SystemExit(f" Stop failed (code {exc.code}): {exc}") from exc
        else:
            print(" Stop command sent OK")

        await asyncio.sleep(1)
        details = await processors.get_info(PROCESSOR_NAME)
        print(f" state after stop : {details.state}")

        # --- Step 7: drop (permanent — comment out to keep the processor) ---
        print("\n[7] Dropping processor ...")
        try:
            await processor.drop()
        except OperationFailure as exc:
            raise SystemExit(f" Drop failed (code {exc.code}): {exc}") from exc
        else:
            print(" Dropped OK")

        print("\nDone.")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Python driver for MongoDB."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ContextManager, Optional
|
||||
@ -45,6 +46,21 @@ __all__ = [
|
||||
"WriteConcern",
|
||||
"has_c",
|
||||
"timeout",
|
||||
# Atlas Stream Processing (experimental)
|
||||
"AsyncSampleCursor",
|
||||
"AsyncStreamProcessingClient",
|
||||
"AsyncStreamProcessor",
|
||||
"AsyncStreamProcessors",
|
||||
"SampleCursor",
|
||||
"StreamProcessingClient",
|
||||
"StreamProcessor",
|
||||
"StreamProcessors",
|
||||
"CreateStreamProcessorOptions",
|
||||
"GetStreamProcessorSamplesOptions",
|
||||
"GetStreamProcessorSamplesResult",
|
||||
"GetStreamProcessorStatsOptions",
|
||||
"StartStreamProcessorOptions",
|
||||
"StreamProcessorInfo",
|
||||
]
|
||||
|
||||
ASCENDING = 1
|
||||
@ -89,6 +105,14 @@ TEXT = "text"
|
||||
from pymongo import _csot
|
||||
from pymongo._version import __version__, get_version_string, version_tuple
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
|
||||
# Atlas Stream Processing (experimental)
|
||||
from pymongo.asynchronous.stream_processing import (
|
||||
AsyncSampleCursor,
|
||||
AsyncStreamProcessingClient,
|
||||
AsyncStreamProcessor,
|
||||
AsyncStreamProcessors,
|
||||
)
|
||||
from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION, has_c
|
||||
from pymongo.cursor import CursorType
|
||||
from pymongo.operations import (
|
||||
@ -101,8 +125,22 @@ from pymongo.operations import (
|
||||
UpdateOne,
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.stream_processing_options import (
|
||||
CreateStreamProcessorOptions,
|
||||
GetStreamProcessorSamplesOptions,
|
||||
GetStreamProcessorSamplesResult,
|
||||
GetStreamProcessorStatsOptions,
|
||||
StartStreamProcessorOptions,
|
||||
StreamProcessorInfo,
|
||||
)
|
||||
from pymongo.synchronous.collection import ReturnDocument
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.synchronous.stream_processing import (
|
||||
SampleCursor,
|
||||
StreamProcessingClient,
|
||||
StreamProcessor,
|
||||
StreamProcessors,
|
||||
)
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
# Public module compatibility imports
|
||||
|
||||
633
pymongo/asynchronous/stream_processing.py
Normal file
633
pymongo/asynchronous/stream_processing.py
Normal file
@ -0,0 +1,633 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Atlas Stream Processing client for stream processing workspaces.
|
||||
|
||||
A stream processing workspace endpoint uses the ``mongodb://`` URI scheme with
|
||||
a hostname that follows this pattern::
|
||||
|
||||
atlas-stream-<workspaceId>-<suffix>.<region>.a.query.mongodb.net
|
||||
|
||||
For example::
|
||||
|
||||
mongodb://atlas-stream-699c842ef433fe6001480b17-etif1.virginia-usa.a.query.mongodb.net/
|
||||
|
||||
TLS is always required for workspace connections and cannot be disabled.
|
||||
``authSource=admin`` is applied by default when not explicitly set.
|
||||
|
||||
For commands not yet wrapped by this API, users can connect via a plain
|
||||
:class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` and call
|
||||
``AsyncDatabase.command`` directly — that path remains fully supported.
|
||||
|
||||
Error handling
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
ASP commands raise :class:`pymongo.errors.OperationFailure` on server-side
|
||||
errors. The following error codes are known to be returned by Atlas Stream
|
||||
Processing commands:
|
||||
|
||||
================ ================= ====================================
|
||||
Code Name When returned
|
||||
================ ================= ====================================
|
||||
9 FailedToParse Invalid pipeline or command document
|
||||
72 InvalidOptions Invalid option values
|
||||
125 CommandFailed General command execution failure
|
||||
1 InternalError Unexpected server-side error
|
||||
================ ================= ====================================
|
||||
|
||||
This list is **non-exhaustive** and may grow as the server evolves. Drivers
|
||||
do not maintain a closed list of valid codes — applications should branch
|
||||
on ``exc.code`` only when they need to react to a specific known code, and
|
||||
should always have a generic fallback for unknown codes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional
|
||||
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.errors import ConfigurationError, InvalidOperation
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.stream_processing_options import (
|
||||
CreateStreamProcessorOptions,
|
||||
GetStreamProcessorSamplesOptions,
|
||||
GetStreamProcessorSamplesResult,
|
||||
GetStreamProcessorStatsOptions,
|
||||
StartStreamProcessorOptions,
|
||||
StreamProcessorInfo,
|
||||
)
|
||||
from pymongo.uri_parser_shared import SRV_SCHEME, _validate_uri
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _is_workspace_endpoint(host: str) -> bool:
|
||||
"""Return True if *host* looks like an ASP workspace endpoint.
|
||||
|
||||
Workspace hostnames begin with ``atlas-stream-`` or end with
|
||||
``.a.query.mongodb.net``.
|
||||
"""
|
||||
return host.startswith("atlas-stream-") or host.endswith(".a.query.mongodb.net")
|
||||
|
||||
|
||||
class AsyncStreamProcessingClient:
    """A client connected to an Atlas Stream Processing workspace.

    Wraps :class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` with
    Atlas Stream Processing constraints:

    * Only the ``mongodb://`` URI scheme is accepted (``mongodb+srv://`` is
      not supported for workspace endpoints).
    * TLS is always enabled and cannot be disabled.
    * ``authSource`` defaults to ``"admin"`` if not explicitly set.

    Usage::

        async with AsyncStreamProcessingClient(
            "mongodb://atlas-stream-<workspaceId>-<suffix>.<region>.a.query.mongodb.net/",
            username="user",
            password="pass",
        ) as client:
            sps = client.stream_processors()
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Validate the workspace constraints, then build the wrapped client.

        All positional and keyword arguments are forwarded to
        :class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` after
        ``tls`` has been forced to ``True`` and ``authSource`` defaulted.

        :param args: Positional arguments; the first one, if present, is
            treated as the host (a URI string or a list/tuple of them).
        :param kwargs: Keyword options; ``host`` is used when no positional
            argument is given.
        :raises ConfigurationError: If a ``mongodb+srv://`` URI is supplied,
            or if TLS is explicitly disabled in a URI or via a ``tls``/``ssl``
            keyword (any letter case).
        """
        host: Any = args[0] if args else kwargs.get("host")

        uris: list[str] = []
        if isinstance(host, str):
            uris = [host]
        elif isinstance(host, (list, tuple)):
            uris = [h for h in host if isinstance(h, str)]

        uri_has_auth_source = False
        for uri_str in uris:
            if not uri_str.startswith(("mongodb://", SRV_SCHEME)):
                # Plain hostname — no URI scheme to check.
                continue
            if uri_str.startswith(SRV_SCHEME):
                raise ConfigurationError(
                    "StreamProcessingClient does not support mongodb+srv:// URIs; "
                    "use mongodb:// with a workspace endpoint."
                )
            parsed = _validate_uri(uri_str, validate=True, warn=False, normalize=True)
            uri_opts = parsed["options"]
            if uri_opts.get("tls") is False:
                raise ConfigurationError(
                    "TLS cannot be disabled for stream processing workspace connections."
                )
            if uri_opts.get("authsource") is not None:
                uri_has_auth_source = True

        # Also reject explicit tls=False / ssl=False in kwargs. Option names
        # are matched case-insensitively (as the authSource check below
        # already does) so that e.g. ``TLS=False`` cannot slip past this check.
        tls_keys = [key for key in kwargs if key.lower() in ("tls", "ssl")]
        if any(kwargs[key] is False for key in tls_keys):
            raise ConfigurationError(
                "TLS cannot be disabled for stream processing workspace connections."
            )

        # Drop every spelling of tls/ssl and forward one canonical tls=True.
        for key in tls_keys:
            del kwargs[key]
        kwargs["tls"] = True

        if not uri_has_auth_source and not any(k.lower() == "authsource" for k in kwargs):
            kwargs["authSource"] = "admin"

        self._client: AsyncMongoClient[Any] = AsyncMongoClient(*args, **kwargs)

    # NOTE: Per the ASP driver spec, server errors MUST be surfaced as-is.
    # Do NOT introduce error-code branching, rewrapping, or filtering anywhere
    # in this module. Known codes are documented at the module level for
    # reference only — they are not runtime invariants.
    async def _command(
        self,
        cmd: dict[str, Any],
        *,
        retryable_read: bool = False,
        session: Optional[AsyncClientSession] = None,
    ) -> Mapping[str, Any]:
        """Send a top-level ASP command to the admin database.

        Routes through the existing retry machinery: retryable reads use
        :meth:`~pymongo.asynchronous.database.AsyncDatabase._retryable_read_command`;
        everything else uses the standard
        :meth:`~pymongo.asynchronous.database.AsyncDatabase.command` path.

        :param cmd: The command document. Its first key must be the command
            name, because that key is passed on as the operation name for
            retryable reads.
        :param retryable_read: If ``True``, the command is sent as a retryable read.
        :param session: A
            :class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
            for this operation.
        :returns: The server's reply document.
        """
        admin = self._client._database_default_options("admin")
        if retryable_read:
            # The first key of the command document is the operation name,
            # which matches the corresponding _Op enum value string.
            operation = next(iter(cmd))
            return await admin._retryable_read_command(cmd, session=session, operation=operation)
        return await admin.command(cmd, session=session)

    def stream_processors(self) -> AsyncStreamProcessors:
        """Return a handle for managing stream processors in this workspace."""
        return AsyncStreamProcessors(self)

    async def close(self) -> None:
        """Close the underlying client and release all resources."""
        await self._client.close()

    async def __aenter__(self) -> AsyncStreamProcessingClient:
        """Enter the async context; returns this client unchanged."""
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Close the client on context exit; exceptions are not suppressed."""
        await self.close()

    @property
    def address(self) -> Any:
        """(host, port) of the current workspace endpoint, or None.

        Delegates to the underlying
        :attr:`~pymongo.asynchronous.mongo_client.AsyncMongoClient.address`
        and returns whatever that attribute produces, unmodified.

        NOTE(review): if the underlying attribute is awaitable on the async
        client, callers must ``await`` this value — confirm.
        """
        return self._client.address
|
||||
|
||||
|
||||
class AsyncStreamProcessors:
    """Workspace-level entry point for creating and looking up stream processors.

    Instances come from :meth:`AsyncStreamProcessingClient.stream_processors`.
    """

    def __init__(self, client: AsyncStreamProcessingClient) -> None:
        self._client = client

    @staticmethod
    def _require_name(name: str) -> None:
        """Raise ``InvalidOperation`` when *name* is empty or only whitespace."""
        if name and name.strip():
            return
        raise InvalidOperation("Stream processor name must be a non-empty string.")

    async def create(
        self,
        name: str,
        pipeline: list[Mapping[str, Any]],
        options: Optional[CreateStreamProcessorOptions] = None,
        *,
        session: Optional[AsyncClientSession] = None,
    ) -> None:
        """Create a new stream processor in the workspace.

        Issues ``createStreamProcessor`` against the ``admin`` database.
        This operation is **not** retryable.

        :param name: Name for the new stream processor.
        :param pipeline: Aggregation pipeline defining the processor; a copy
            is sent, so the caller's list is not shared with the command.
        :param options: Optional
            :class:`~pymongo.stream_processing_options.CreateStreamProcessorOptions`;
            only fields that are not ``None`` are sent.
        :param session: A
            :class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
            for this operation.
        :raises InvalidOperation: If *name* is blank or *pipeline* is empty.
        """
        self._require_name(name)
        if not pipeline:
            raise InvalidOperation("createStreamProcessor requires a non-empty pipeline.")

        command: dict[str, Any] = {_Op.CREATE_STREAM_PROCESSOR: name}
        command["pipeline"] = list(pipeline)

        if options is not None:
            # Wire-name / value pairs, in the order they appear on the wire.
            candidates = (
                ("dlq", options.dlq),
                ("streamMetaFieldName", options.stream_meta_field_name),
                ("tier", options.tier),
                ("failover", options.failover),
            )
            wire_options: dict[str, Any] = {
                key: value for key, value in candidates if value is not None
            }
            # An all-None options object adds nothing to the command.
            if wire_options:
                command["options"] = wire_options

        await self._client._command(command, session=session)

    def get(self, name: str) -> AsyncStreamProcessor:
        """Build a handle for the stream processor called *name*.

        Purely local: no command is sent, and the processor is not checked
        for existence on the server.

        :param name: The name of the stream processor.
        :returns: An :class:`AsyncStreamProcessor` handle.
        :raises InvalidOperation: If *name* is blank.
        """
        self._require_name(name)
        return AsyncStreamProcessor(client=self._client, name=name)

    async def get_info(
        self,
        name: str,
        *,
        session: Optional[AsyncClientSession] = None,
    ) -> StreamProcessorInfo:
        """Fetch information about a single stream processor.

        Issues ``getStreamProcessor`` against the ``admin`` database as a
        **retryable read**.

        :param name: The name of the stream processor.
        :param session: A
            :class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
            for this operation.
        :returns: A :class:`~pymongo.stream_processing_options.StreamProcessorInfo`
            built from the reply's ``result`` sub-document when present,
            otherwise from the whole reply.
        :raises InvalidOperation: If *name* is blank.
        """
        self._require_name(name)
        reply = await self._client._command(
            {_Op.GET_STREAM_PROCESSOR: name},
            retryable_read=True,
            session=session,
        )
        info_doc = reply.get("result", reply)
        return StreamProcessorInfo.from_response(info_doc)
|
||||
|
||||
|
||||
class AsyncStreamProcessor:
|
||||
"""Handle for a specific named stream processor.
|
||||
|
||||
Does not imply the processor currently exists on the server.
|
||||
Obtain via :meth:`AsyncStreamProcessors.get`.
|
||||
"""
|
||||
|
||||
def __init__(self, *, client: AsyncStreamProcessingClient, name: str) -> None:
|
||||
if not name or not name.strip():
|
||||
raise InvalidOperation("Stream processor name must be a non-empty string.")
|
||||
self._client = client
|
||||
self.name = name
|
||||
|
||||
async def start(
|
||||
self,
|
||||
options: Optional[StartStreamProcessorOptions] = None,
|
||||
*,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
) -> None:
|
||||
"""Start this stream processor.
|
||||
|
||||
Sends the ``startStreamProcessor`` command to the ``admin`` database.
|
||||
This operation is **not** retryable.
|
||||
|
||||
The processor must be in the ``STOPPED`` or ``FAILED`` state; starting
|
||||
an already-``STARTED`` processor returns a server error.
|
||||
|
||||
Mutual exclusivity of ``start_after`` / ``start_at_operation_time`` and
|
||||
tier validation are enforced by
|
||||
:class:`~pymongo.stream_processing_options.StartStreamProcessorOptions`
|
||||
at construction time.
|
||||
|
||||
:param options: Optional
|
||||
:class:`~pymongo.stream_processing_options.StartStreamProcessorOptions`.
|
||||
:param session: A
|
||||
:class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
|
||||
for this operation.
|
||||
"""
|
||||
cmd: dict[str, Any] = {_Op.START_STREAM_PROCESSOR: self.name}
|
||||
if options is not None:
|
||||
if options.workers is not None:
|
||||
cmd["workers"] = options.workers
|
||||
opts: dict[str, Any] = {}
|
||||
if options.clear_checkpoints is not None:
|
||||
opts["clearCheckpoints"] = options.clear_checkpoints
|
||||
if options.start_at_operation_time is not None:
|
||||
opts["startAtOperationTime"] = options.start_at_operation_time
|
||||
if options.start_after is not None:
|
||||
opts["startAfter"] = options.start_after
|
||||
if options.tier is not None:
|
||||
opts["tier"] = options.tier
|
||||
if options.enable_auto_scaling is not None:
|
||||
opts["enableAutoScaling"] = options.enable_auto_scaling
|
||||
if options.failover is not None:
|
||||
opts["failover"] = options.failover
|
||||
if opts:
|
||||
cmd["options"] = opts
|
||||
await self._client._command(cmd, session=session)
|
||||
|
||||
async def stop(
    self,
    *,
    session: Optional[AsyncClientSession] = None,
) -> None:
    """Stop this stream processor.

    Issues ``stopStreamProcessor`` for :attr:`name` through the client's
    ``_command`` helper, without the retryable-read flag (the operation is
    **not** retried). A stopped processor moves to the ``STOPPED`` state
    and may be started again.

    :param session: A
        :class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
        for this operation.
    """
    command = {_Op.STOP_STREAM_PROCESSOR: self.name}
    await self._client._command(command, session=session)
|
||||
|
||||
async def drop(
    self,
    *,
    session: Optional[AsyncClientSession] = None,
) -> None:
    """Permanently delete this stream processor.

    Issues ``dropStreamProcessor`` for :attr:`name` through the client's
    ``_command`` helper, without the retryable-read flag (the operation is
    **not** retried). Dropping is irreversible: the processor cannot be
    recovered afterwards.

    :param session: A
        :class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
        for this operation.
    """
    command = {_Op.DROP_STREAM_PROCESSOR: self.name}
    await self._client._command(command, session=session)
|
||||
|
||||
async def stats(
    self,
    options: Optional[GetStreamProcessorStatsOptions] = None,
    *,
    session: Optional[AsyncClientSession] = None,
) -> Mapping[str, Any]:
    """Return runtime statistics for this stream processor.

    Issues ``getStreamProcessorStats`` through the client's ``_command``
    helper with ``retryable_read=True``. The server reports an error when
    the processor is not in the ``STARTED`` state.

    The reply is handed back exactly as received, so fields added by newer
    servers reach the caller untouched.

    :param options: Optional
        :class:`~pymongo.stream_processing_options.GetStreamProcessorStatsOptions`;
        its non-``None`` ``scale`` / ``verbose`` values are sent in the
        ``options`` sub-document.
    :param session: A
        :class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
        for this operation.
    :returns: The raw server response document.
    """
    command: dict[str, Any] = {_Op.GET_STREAM_PROCESSOR_STATS: self.name}

    if options is not None:
        candidates = (("scale", options.scale), ("verbose", options.verbose))
        nested: dict[str, Any] = {
            wire_name: value for wire_name, value in candidates if value is not None
        }
        if nested:
            command["options"] = nested

    return await self._client._command(command, retryable_read=True, session=session)
|
||||
|
||||
async def get_stream_processor_samples(
    self,
    options: Optional[GetStreamProcessorSamplesOptions] = None,
    *,
    session: Optional[AsyncClientSession] = None,
) -> GetStreamProcessorSamplesResult:
    """Fetch one batch of sampled documents from a running stream processor.

    Spec-literal entry point; most callers should prefer :meth:`sample`,
    which drives this method from an async iterator.

    Routing is decided by ``options.cursor_id``:

    * ``None`` (or no options at all): ``startSampleStreamProcessor`` is
      sent, with ``limit`` when set; documents are read from
      ``firstBatch``.
    * non-zero: ``getMoreSampleStreamProcessor`` is sent with ``cursorId``
      and, when set, ``batchSize``; documents are read from ``nextBatch``,
      falling back to ``messages`` when that key is absent.
    * ``0``: rejected locally, because a zero id marks an exhausted cursor.

    The command is sent without the retryable-read flag.

    :param options: Optional
        :class:`~pymongo.stream_processing_options.GetStreamProcessorSamplesOptions`.
    :param session: A
        :class:`~pymongo.asynchronous.client_session.AsyncClientSession` to use
        for this operation.
    :returns: A :class:`~pymongo.stream_processing_options.GetStreamProcessorSamplesResult`
        holding the batch and the cursor id to pass to the next call
        (``0`` once the cursor is exhausted).
    :raises InvalidOperation: if ``options.cursor_id`` is ``0``.
    """
    opts = GetStreamProcessorSamplesOptions() if options is None else options
    requested_id = opts.cursor_id

    if requested_id == 0:
        raise InvalidOperation("Sample cursor is exhausted; cursor_id 0 cannot be continued.")

    opening = requested_id is None
    command: dict[str, Any]
    if opening:
        command = {_Op.START_SAMPLE_STREAM_PROCESSOR: self.name}
        if opts.limit is not None:
            command["limit"] = opts.limit
    else:
        command = {
            _Op.GET_MORE_SAMPLE_STREAM_PROCESSOR: self.name,
            "cursorId": requested_id,
        }
        if opts.batch_size is not None:
            command["batchSize"] = opts.batch_size

    reply = await self._client._command(command, session=session)

    # Read the cursor id before the batch, matching the original evaluation order.
    next_id = int(reply["cursorId"])
    if opening:
        batch = reply.get("firstBatch", [])
    else:
        batch = reply.get("nextBatch", reply.get("messages", []))
    return GetStreamProcessorSamplesResult(cursor_id=next_id, documents=list(batch))
|
||||
|
||||
def sample(
    self,
    limit: Optional[int] = None,
    batch_size: Optional[int] = None,
    *,
    session: Optional[AsyncClientSession] = None,
) -> AsyncSampleCursor:
    """Open a sample cursor over this stream processor's output.

    Builds an :class:`AsyncSampleCursor` bound to this processor. Nothing
    is sent to the server here; the cursor issues the
    ``startSampleStreamProcessor`` / ``getMoreSampleStreamProcessor``
    commands itself as it is iterated.

    Usage::

        async for doc in processor.sample(limit=100, batch_size=10):
            print(doc)

    :param limit: Maximum number of documents to sample (sent only on the
        initial call).
    :param batch_size: Number of documents per continuation batch (sent
        only on subsequent calls).
    :param session: Optional :class:`AsyncClientSession` propagated to all
        underlying commands.
    """
    cursor = AsyncSampleCursor(
        processor=self,
        limit=limit,
        batch_size=batch_size,
        session=session,
    )
    return cursor
|
||||
|
||||
|
||||
class AsyncSampleCursor:
    """Async iterator over sampled stream processor output.

    A custom two-phase cursor used to retrieve sampled documents from a
    running stream processor. This cursor MUST NOT be wrapped or re-used
    via the standard MongoDB ``Cursor`` types because it does not use the
    standard ``getMore`` command — it uses the dedicated
    ``startSampleStreamProcessor`` / ``getMoreSampleStreamProcessor``
    commands instead.

    Obtained via :meth:`AsyncStreamProcessor.sample`. Iterate with
    ``async for``.

    The cursor is exhausted when the server returns ``cursorId: 0``;
    after that, no further wire calls are issued and iteration ends once
    the locally buffered documents have been consumed.
    """

    def __init__(
        self,
        *,
        processor: AsyncStreamProcessor,
        limit: Optional[int] = None,
        batch_size: Optional[int] = None,
        session: Optional[AsyncClientSession] = None,
    ) -> None:
        """Bind the cursor to *processor*; no command is sent until iteration.

        :param processor: Object whose ``get_stream_processor_samples``
            coroutine is awaited to fetch each batch.
        :param limit: Passed as ``limit`` on the first fetch only.
        :param batch_size: Passed as ``batch_size`` on later fetches only.
        :param session: Forwarded to every fetch.
        """
        self._processor = processor
        self._limit = limit
        self._batch_size = batch_size
        self._session = session

        self._buffer: list[Mapping[str, Any]] = []
        self._cursor_id: Optional[int] = None  # None = not yet opened
        self._exhausted: bool = False
        self._closed: bool = False

    @property
    def cursor_id(self) -> Optional[int]:
        """Current server-side cursor id, or ``None`` if not yet opened.

        A value of ``0`` indicates the cursor has been exhausted.
        """
        return self._cursor_id

    @property
    def alive(self) -> bool:
        """``True`` while this cursor may still yield documents.

        A closed cursor is never alive. Otherwise the cursor is alive while
        documents remain in the local buffer or the server-side cursor has
        not reported exhaustion. Counting the buffer matters: the final
        batch arrives together with ``cursorId: 0``, and those documents
        must still be reported as available.
        """
        if self._closed:
            return False
        return bool(self._buffer) or not self._exhausted

    async def _refill(self) -> None:
        """Fetch the next batch from the server. No-op if exhausted or closed."""
        if self._exhausted or self._closed:
            return

        if self._cursor_id is None:
            # First fetch: open the cursor; only ``limit`` applies.
            opts = GetStreamProcessorSamplesOptions(
                cursor_id=None,
                limit=self._limit,
                batch_size=None,
            )
        else:
            # Continuation: only ``batch_size`` applies.
            opts = GetStreamProcessorSamplesOptions(
                cursor_id=self._cursor_id,
                limit=None,
                batch_size=self._batch_size,
            )

        result = await self._processor.get_stream_processor_samples(opts, session=self._session)
        self._cursor_id = result.cursor_id

        # Spec: cursorId == 0 means exhausted. MUST NOT call getMore again.
        if result.cursor_id == 0:
            self._exhausted = True

        # close() may have been called while the fetch was awaiting; a closed
        # cursor must not start buffering documents again.
        if self._closed:
            return
        self._buffer.extend(result.documents)

    def __aiter__(self) -> AsyncSampleCursor:
        return self

    async def __anext__(self) -> Mapping[str, Any]:
        """Return the next sampled document.

        :raises StopAsyncIteration: when the cursor is closed, or when it is
            exhausted and the local buffer is empty.
        """
        if self._closed:
            raise StopAsyncIteration

        # Loop guards against an empty batch from the server with a non-zero
        # cursor id — keep pulling until we get documents or hit exhaustion.
        # The ``_closed`` test is essential: once closed, ``_refill`` is a
        # no-op, so without it this loop would spin forever if the cursor
        # were closed while a fetch was in flight.
        while not self._buffer and not self._exhausted and not self._closed:
            await self._refill()

        if self._closed or not self._buffer:
            raise StopAsyncIteration
        return self._buffer.pop(0)

    async def close(self) -> None:
        """Mark the cursor closed locally and discard buffered documents.

        Note: ASP does not currently expose a way to explicitly kill a
        sample cursor server-side. ``close`` only stops local iteration;
        the server-side cursor will be cleaned up on its own timeout or
        when the processor stops.
        """
        self._closed = True
        self._buffer.clear()

    async def __aenter__(self) -> AsyncSampleCursor:
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        await self.close()
|
||||
@ -59,12 +59,17 @@ class _Op(str, enum.Enum):
|
||||
CREATE = "create"
|
||||
CREATE_INDEXES = "createIndexes"
|
||||
CREATE_SEARCH_INDEXES = "createSearchIndexes"
|
||||
CREATE_STREAM_PROCESSOR = "createStreamProcessor"
|
||||
DELETE = "delete"
|
||||
DISTINCT = "distinct"
|
||||
DROP = "drop"
|
||||
DROP_DATABASE = "dropDatabase"
|
||||
DROP_INDEXES = "dropIndexes"
|
||||
DROP_SEARCH_INDEXES = "dropSearchIndexes"
|
||||
DROP_STREAM_PROCESSOR = "dropStreamProcessor"
|
||||
GET_STREAM_PROCESSOR = "getStreamProcessor"
|
||||
GET_STREAM_PROCESSOR_STATS = "getStreamProcessorStats"
|
||||
GET_MORE_SAMPLE_STREAM_PROCESSOR = "getMoreSampleStreamProcessor"
|
||||
END_SESSIONS = "endSessions"
|
||||
FIND_AND_MODIFY = "findAndModify"
|
||||
FIND = "find"
|
||||
@ -77,6 +82,9 @@ class _Op(str, enum.Enum):
|
||||
UPDATE_INDEX = "updateIndex"
|
||||
UPDATE_SEARCH_INDEX = "updateSearchIndex"
|
||||
RENAME = "rename"
|
||||
START_STREAM_PROCESSOR = "startStreamProcessor"
|
||||
START_SAMPLE_STREAM_PROCESSOR = "startSampleStreamProcessor"
|
||||
STOP_STREAM_PROCESSOR = "stopStreamProcessor"
|
||||
GETMORE = "getMore"
|
||||
KILL_CURSORS = "killCursors"
|
||||
TEST = "testOperation"
|
||||
|
||||
193
pymongo/stream_processing_options.py
Normal file
193
pymongo/stream_processing_options.py
Normal file
@ -0,0 +1,193 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Options and result classes for Atlas Stream Processing commands.
|
||||
|
||||
These classes are shared between the synchronous and asynchronous APIs —
|
||||
no async/sync split is needed for plain dataclasses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson import Timestamp
|
||||
|
||||
from pymongo.errors import InvalidOperation
|
||||
|
||||
# Tier names accepted by StartStreamProcessorOptions: any other non-None
# ``tier`` value is rejected with InvalidOperation when the options object
# is constructed.
_VALID_TIERS = {"SP2", "SP5", "SP10", "SP30", "SP50"}
|
||||
|
||||
|
||||
@dataclass
class CreateStreamProcessorOptions:
    """Options for :meth:`AsyncStreamProcessors.create`.

    All fields are optional. A field left as ``None`` is not sent; the
    remaining fields are placed in the ``options`` sub-document of the
    ``createStreamProcessor`` command. This class performs no validation
    of its own.
    """

    # Sent as ``options.dlq`` (dead-letter-queue settings — presumably a
    # server-defined document; its schema is not checked here).
    dlq: Optional[Mapping[str, Any]] = None
    # Sent as ``options.streamMetaFieldName``.
    stream_meta_field_name: Optional[str] = None
    # Sent as ``options.tier``; unlike StartStreamProcessorOptions, the value
    # is not checked against a list of known tiers.
    tier: Optional[str] = None
    # Sent as ``options.failover``.
    # NOTE(review): typed ``bool`` here but ``Mapping`` on
    # StartStreamProcessorOptions.failover — confirm which the server expects.
    failover: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
class StartStreamProcessorOptions:
    """Options for :meth:`AsyncStreamProcessor.start`.

    Constraints checked when an instance is created:

    * ``start_after`` and ``start_at_operation_time`` cannot both be set.
    * ``tier``, when set, must be one of ``"SP2"``, ``"SP5"``, ``"SP10"``,
      ``"SP30"``, or ``"SP50"``.
    * ``workers``, when set, must be greater than zero.

    A violation raises :class:`~pymongo.errors.InvalidOperation`.
    """

    workers: Optional[int] = None
    clear_checkpoints: Optional[bool] = None
    start_at_operation_time: Optional[Timestamp] = None
    start_after: Optional[Mapping[str, Any]] = None
    tier: Optional[str] = None
    enable_auto_scaling: Optional[bool] = None
    failover: Optional[Mapping[str, Any]] = None

    def __post_init__(self) -> None:
        """Validate field combinations right after dataclass initialisation."""
        resume_token_given = self.start_after is not None
        operation_time_given = self.start_at_operation_time is not None
        if resume_token_given and operation_time_given:
            raise InvalidOperation(
                "start_after and start_at_operation_time are mutually exclusive."
            )

        if not (self.tier is None or self.tier in _VALID_TIERS):
            raise InvalidOperation(
                f"Invalid tier {self.tier!r}. Must be one of: {sorted(_VALID_TIERS)}."
            )

        worker_count = self.workers
        if worker_count is not None and worker_count <= 0:
            raise InvalidOperation("workers must be a positive integer.")
|
||||
|
||||
|
||||
@dataclass
class GetStreamProcessorStatsOptions:
    """Options for :meth:`AsyncStreamProcessor.stats`.

    :param scale: Size unit for byte-valued fields. ``1`` = bytes (default),
        ``1024`` = kibibytes. Must be greater than zero when given.
    :param verbose: If ``True``, include per-operator statistics.
    :raises InvalidOperation: if ``scale`` is zero or negative.
    """

    scale: Optional[int] = None
    verbose: Optional[bool] = None

    def __post_init__(self) -> None:
        """Reject a non-positive ``scale``; ``None`` means "not specified"."""
        requested_scale = self.scale
        if requested_scale is None:
            return
        if requested_scale <= 0:
            raise InvalidOperation("scale must be a positive integer.")
|
||||
|
||||
|
||||
@dataclass
class GetStreamProcessorSamplesOptions:
    """Options for :meth:`AsyncStreamProcessor.get_stream_processor_samples`.

    ``cursor_id`` selects the phase of the sample protocol:

    * ``None`` — a new sample cursor is opened via
      ``startSampleStreamProcessor``; ``limit`` is only sent on that
      initial call.
    * non-zero — the next batch is fetched via
      ``getMoreSampleStreamProcessor``; ``batch_size`` is only sent on
      these subsequent calls.
    * ``0`` — denotes an exhausted cursor. This class accepts the value,
      but ``get_stream_processor_samples`` refuses to continue from it.

    :raises InvalidOperation: if ``cursor_id``, ``limit`` or ``batch_size``
        is negative.
    """

    cursor_id: Optional[int] = None
    limit: Optional[int] = None
    batch_size: Optional[int] = None

    def __post_init__(self) -> None:
        """Reject negative values; ``None`` means "not specified"."""
        # Checked in declaration order so the first offending field wins.
        checks = (
            (self.cursor_id, "cursor_id must be non-negative (0 means exhausted)."),
            (self.limit, "limit must be non-negative."),
            (self.batch_size, "batch_size must be non-negative."),
        )
        for value, message in checks:
            if value is not None and value < 0:
                raise InvalidOperation(message)
|
||||
|
||||
|
||||
@dataclass
class GetStreamProcessorSamplesResult:
    """Result from :meth:`AsyncStreamProcessor.get_stream_processor_samples`.

    A ``cursor_id`` of ``0`` means the cursor is exhausted; no further calls
    should be made. Any other value is the id to pass as
    ``GetStreamProcessorSamplesOptions.cursor_id`` on the next call.
    """

    # Cursor id taken from the reply's ``cursorId`` field.
    cursor_id: int
    # Documents in this batch (from ``firstBatch`` on the opening call,
    # ``nextBatch`` / ``messages`` on continuation calls); may be empty.
    documents: list[Mapping[str, Any]]
|
||||
|
||||
|
||||
@dataclass
class StreamProcessorInfo:
    """Information about a single stream processor.

    Returned by :meth:`AsyncStreamProcessors.get_info`.

    Known keys of the ``getStreamProcessor`` reply are exposed as
    snake_case attributes. The whole reply is additionally kept in
    :attr:`raw`, so keys this class does not know about remain accessible.
    """

    id: Optional[str]
    name: str
    state: str  # plain str — drivers MUST NOT hard-code this as an Enum
    pipeline: list[Mapping[str, Any]]
    pipeline_version: Optional[int]
    tier: Optional[str] = None
    dlq: Optional[Mapping[str, Any]] = None
    stream_meta_field_name: Optional[str] = None
    enable_auto_scaling: bool = False
    failover_enabled: bool = False
    active_region: str = ""
    last_modified_at: Optional[datetime] = None
    modified_by: Optional[str] = None
    last_state_change: Optional[datetime] = None
    last_heartbeat: Optional[datetime] = None
    has_started: bool = False
    stats: Optional[Mapping[str, Any]] = None
    error_msg: Optional[str] = None
    error_code: Optional[int] = None
    error_retryable: Optional[bool] = None
    raw: Mapping[str, Any] = field(default_factory=dict)

    @classmethod
    def from_response(cls, doc: Mapping[str, Any]) -> StreamProcessorInfo:
        """Build a :class:`StreamProcessorInfo` from a server reply document.

        ``name`` and ``state`` are mandatory (a missing key raises
        :class:`KeyError`); every other key falls back to a default when
        absent. A shallow ``dict`` copy of *doc* is stored in :attr:`raw`.

        :param doc: The ``getStreamProcessor`` reply (camelCase keys).
        :returns: The populated info object.
        """
        # attribute name -> (server key, value used when the key is absent)
        optional_keys: dict[str, tuple[str, Any]] = {
            "id": ("id", None),
            "pipeline": ("pipeline", []),
            "pipeline_version": ("pipelineVersion", None),
            "tier": ("tier", None),
            "dlq": ("dlq", None),
            "stream_meta_field_name": ("streamMetaFieldName", None),
            "enable_auto_scaling": ("enableAutoScaling", False),
            "failover_enabled": ("failoverEnabled", False),
            "active_region": ("activeRegion", ""),
            "last_modified_at": ("lastModifiedAt", None),
            "modified_by": ("modifiedBy", None),
            "last_state_change": ("lastStateChange", None),
            "last_heartbeat": ("lastHeartbeat", None),
            "has_started": ("hasStarted", False),
            "stats": ("stats", None),
            "error_msg": ("errorMsg", None),
            "error_code": ("errorCode", None),
            "error_retryable": ("errorRetryable", None),
        }

        values: dict[str, Any] = {"name": doc["name"], "state": doc["state"]}
        for attribute, (server_key, fallback) in optional_keys.items():
            values[attribute] = doc.get(server_key, fallback)
        values["raw"] = dict(doc)
        return cls(**values)
|
||||
633
pymongo/synchronous/stream_processing.py
Normal file
633
pymongo/synchronous/stream_processing.py
Normal file
@ -0,0 +1,633 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Atlas Stream Processing client for stream processing workspaces.
|
||||
|
||||
A stream processing workspace endpoint uses the ``mongodb://`` URI scheme with
|
||||
a hostname that follows this pattern::
|
||||
|
||||
atlas-stream-<workspaceId>-<suffix>.<region>.a.query.mongodb.net
|
||||
|
||||
For example::
|
||||
|
||||
mongodb://atlas-stream-699c842ef433fe6001480b17-etif1.virginia-usa.a.query.mongodb.net/
|
||||
|
||||
TLS is always required for workspace connections and cannot be disabled.
|
||||
``authSource=admin`` is applied by default when not explicitly set.
|
||||
|
||||
For commands not yet wrapped by this API, users can connect via a plain
|
||||
:class:`~pymongo.mongo_client.MongoClient` and call
|
||||
``run_command`` directly — that path remains fully supported.
|
||||
|
||||
Error handling
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
ASP commands raise :class:`pymongo.errors.OperationFailure` on server-side
|
||||
errors. The following error codes are known to be returned by Atlas Stream
|
||||
Processing commands:
|
||||
|
||||
================ ================= ====================================
|
||||
Code Name When returned
|
||||
================ ================= ====================================
|
||||
9 FailedToParse Invalid pipeline or command document
|
||||
72 InvalidOptions Invalid option values
|
||||
125 CommandFailed General command execution failure
|
||||
1 InternalError Unexpected server-side error
|
||||
================ ================= ====================================
|
||||
|
||||
This list is **non-exhaustive** and may grow as the server evolves. Drivers
|
||||
do not maintain a closed list of valid codes — applications should branch
|
||||
on ``exc.code`` only when they need to react to a specific known code, and
|
||||
should always have a generic fallback for unknown codes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional
|
||||
|
||||
from pymongo.errors import ConfigurationError, InvalidOperation
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.stream_processing_options import (
|
||||
CreateStreamProcessorOptions,
|
||||
GetStreamProcessorSamplesOptions,
|
||||
GetStreamProcessorSamplesResult,
|
||||
GetStreamProcessorStatsOptions,
|
||||
StartStreamProcessorOptions,
|
||||
StreamProcessorInfo,
|
||||
)
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.uri_parser_shared import SRV_SCHEME, _validate_uri
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
from pymongo.synchronous.client_session import ClientSession
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
def _is_workspace_endpoint(host: str) -> bool:
    """Tell whether *host* looks like an ASP workspace endpoint.

    A host qualifies when it starts with ``atlas-stream-`` or ends with
    ``.a.query.mongodb.net``. The comparison is a plain, case-sensitive
    string match on *host* as given.
    """
    if host.startswith("atlas-stream-"):
        return True
    return host.endswith(".a.query.mongodb.net")
|
||||
|
||||
|
||||
class StreamProcessingClient:
    """A client connected to an Atlas Stream Processing workspace.

    Wraps :class:`~pymongo.mongo_client.MongoClient` with
    Atlas Stream Processing constraints:

    * Only the ``mongodb://`` URI scheme is accepted (``mongodb+srv://`` is
      not supported for workspace endpoints).
    * TLS is always enabled and cannot be disabled.
    * ``authSource`` defaults to ``"admin"`` if not explicitly set.

    Usage::

        with StreamProcessingClient(
            "mongodb://atlas-stream-<id>.<region>.a.query.mongodb.net/",
            username="user",
            password="pass",
        ) as client:
            sps = client.stream_processors()
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Validate the ASP constraints, then build the wrapped ``MongoClient``.

        Accepts the same positional and keyword arguments as
        :class:`~pymongo.mongo_client.MongoClient`; they are forwarded after
        ``tls=True`` (and, when absent, ``authSource="admin"``) have been
        injected.

        :raises ConfigurationError: if a ``mongodb+srv://`` URI is given, or
            if TLS is disabled in a URI or through a ``tls`` / ``ssl``
            keyword (matched case-insensitively).
        """
        host: Any = args[0] if args else kwargs.get("host")

        # ``host`` may be a single string or a list/tuple of strings.
        uris: list[str] = []
        if isinstance(host, str):
            uris = [host]
        elif isinstance(host, (list, tuple)):
            uris = [h for h in host if isinstance(h, str)]

        uri_has_auth_source = False
        for uri_str in uris:
            if not uri_str.startswith(("mongodb://", SRV_SCHEME)):
                # Plain hostname — no URI scheme to check.
                continue
            if uri_str.startswith(SRV_SCHEME):
                raise ConfigurationError(
                    "StreamProcessingClient does not support mongodb+srv:// URIs; "
                    "use mongodb:// with a workspace endpoint."
                )
            parsed = _validate_uri(uri_str, validate=True, warn=False, normalize=True)
            uri_opts = parsed["options"]
            if uri_opts.get("tls") is False:
                raise ConfigurationError(
                    "TLS cannot be disabled for stream processing workspace connections."
                )
            if uri_opts.get("authsource") is not None:
                uri_has_auth_source = True

        # Also reject explicit tls=False / ssl=False in kwargs. Option names
        # are matched case-insensitively, like the authSource check below, so
        # spellings such as ``TLS=False`` cannot slip past the guard.
        tls_keys = [key for key in kwargs if key.lower() in ("tls", "ssl")]
        if any(kwargs[key] is False for key in tls_keys):
            raise ConfigurationError(
                "TLS cannot be disabled for stream processing workspace connections."
            )

        # Collapse every tls/ssl spelling into a single canonical ``tls=True``.
        for key in tls_keys:
            del kwargs[key]
        kwargs["tls"] = True

        if not uri_has_auth_source and not any(k.lower() == "authsource" for k in kwargs):
            kwargs["authSource"] = "admin"

        self._client: MongoClient[Any] = MongoClient(*args, **kwargs)

    # NOTE: Per the ASP driver spec, server errors MUST be surfaced as-is.
    # Do NOT introduce error-code branching, rewrapping, or filtering anywhere
    # in this module. Known codes are documented at the module level for
    # reference only — they are not runtime invariants.
    def _command(
        self,
        cmd: dict[str, Any],
        *,
        retryable_read: bool = False,
        session: Optional[ClientSession] = None,
    ) -> Mapping[str, Any]:
        """Send a top-level ASP command to the admin database.

        Routes through the existing retry machinery: retryable reads use
        :meth:`~pymongo.database.Database._retryable_read_command`;
        everything else uses the standard
        :meth:`~pymongo.database.Database.command` path.

        :param cmd: The command document. Its first key must be the command
            name, because that key is reported as the operation name on the
            retryable-read path.
        :param retryable_read: If ``True``, the command is sent as a retryable read.
        :param session: A
            :class:`~pymongo.client_session.ClientSession` to use
            for this operation.
        :returns: The server's reply document.
        """
        admin = self._client._database_default_options("admin")
        if retryable_read:
            # The first key of the command document is the operation name,
            # which matches the corresponding _Op enum value string.
            operation = next(iter(cmd))
            return admin._retryable_read_command(cmd, session=session, operation=operation)
        return admin.command(cmd, session=session)

    def stream_processors(self) -> StreamProcessors:
        """Return a handle for managing stream processors in this workspace."""
        return StreamProcessors(self)

    def close(self) -> None:
        """Close the underlying client and release all resources."""
        self._client.close()

    def __enter__(self) -> StreamProcessingClient:
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        self.close()

    @property
    def address(self) -> Any:
        """(host, port) of the current workspace endpoint, or None.

        Delegates to the underlying
        :attr:`~pymongo.mongo_client.MongoClient.address`.
        """
        return self._client.address
|
||||
|
||||
|
||||
class StreamProcessors:
    """Handle for managing stream processors in a workspace.

    Obtained via :meth:`StreamProcessingClient.stream_processors`. Every
    command is issued through the owning client's ``_command`` helper.
    """

    def __init__(self, client: StreamProcessingClient) -> None:
        self._client = client

    def create(
        self,
        name: str,
        pipeline: list[Mapping[str, Any]],
        options: Optional[CreateStreamProcessorOptions] = None,
        *,
        session: Optional[ClientSession] = None,
    ) -> None:
        """Create a new stream processor in the workspace.

        Issues ``createStreamProcessor`` without the retryable-read flag
        (the operation is **not** retried). Option fields that are ``None``
        are left out; the ``options`` sub-document is omitted when no
        option is set.

        :param name: The name of the stream processor to create.
        :param pipeline: The aggregation pipeline that defines the processor;
            a list copy of it is sent.
        :param options: Optional :class:`~pymongo.stream_processing_options.CreateStreamProcessorOptions`
            controlling DLQ, tier, and other settings.
        :param session: A
            :class:`~pymongo.client_session.ClientSession` to use
            for this operation.
        :raises InvalidOperation: if *name* is empty or whitespace-only, or
            if *pipeline* is empty.
        """
        if not (name and name.strip()):
            raise InvalidOperation("Stream processor name must be a non-empty string.")
        if not pipeline:
            raise InvalidOperation("createStreamProcessor requires a non-empty pipeline.")

        command: dict[str, Any] = {
            _Op.CREATE_STREAM_PROCESSOR: name,
            "pipeline": list(pipeline),
        }
        if options is not None:
            candidates = (
                ("dlq", options.dlq),
                ("streamMetaFieldName", options.stream_meta_field_name),
                ("tier", options.tier),
                ("failover", options.failover),
            )
            settings: dict[str, Any] = {
                wire_name: value for wire_name, value in candidates if value is not None
            }
            if settings:
                command["options"] = settings

        self._client._command(command, session=session)

    def get(self, name: str) -> StreamProcessor:
        """Return a handle for an existing stream processor by name.

        Pure factory: the server is **not** contacted and the processor's
        existence is not verified.

        :param name: The name of the stream processor.
        :returns: A :class:`StreamProcessor` handle.
        :raises InvalidOperation: if *name* is empty or whitespace-only.
        """
        if not (name and name.strip()):
            raise InvalidOperation("Stream processor name must be a non-empty string.")
        handle = StreamProcessor(client=self._client, name=name)
        return handle

    def get_info(
        self,
        name: str,
        *,
        session: Optional[ClientSession] = None,
    ) -> StreamProcessorInfo:
        """Return information about a single stream processor.

        Issues ``getStreamProcessor`` as a **retryable read**. When the
        reply carries a ``result`` key its value is parsed; otherwise the
        reply itself is parsed.

        :param name: The name of the stream processor.
        :param session: A
            :class:`~pymongo.client_session.ClientSession` to use
            for this operation.
        :returns: A :class:`~pymongo.stream_processing_options.StreamProcessorInfo`
            populated from the server response. Unknown server fields are preserved
            in :attr:`~pymongo.stream_processing_options.StreamProcessorInfo.raw`.
        :raises InvalidOperation: if *name* is empty or whitespace-only.
        """
        if not (name and name.strip()):
            raise InvalidOperation("Stream processor name must be a non-empty string.")

        command: dict[str, Any] = {_Op.GET_STREAM_PROCESSOR: name}
        reply = self._client._command(command, retryable_read=True, session=session)
        payload = reply.get("result", reply)
        return StreamProcessorInfo.from_response(payload)
|
||||
|
||||
|
||||
class StreamProcessor:
    """Client-side handle for one named stream processor.

    Constructing a handle sends nothing to the server, so holding one does
    not mean the processor currently exists. Handles are normally obtained
    from :meth:`StreamProcessors.get`.
    """

    def __init__(self, *, client: StreamProcessingClient, name: str) -> None:
        # Reject None, "" and whitespace-only names before storing anything.
        name_is_blank = not name or not name.strip()
        if name_is_blank:
            raise InvalidOperation("Stream processor name must be a non-empty string.")
        self.name = name
        self._client = client

    def start(
        self,
        options: Optional[StartStreamProcessorOptions] = None,
        *,
        session: Optional[ClientSession] = None,
    ) -> None:
        """Start this stream processor.

        Issues ``startStreamProcessor`` against the ``admin`` database.
        The command is **not** retried.

        The server expects the processor to be ``STOPPED`` or ``FAILED``;
        starting one that is already ``STARTED`` yields a server error.

        ``start_after`` / ``start_at_operation_time`` exclusivity and tier
        validation are the responsibility of
        :class:`~pymongo.stream_processing_options.StartStreamProcessorOptions`
        when it is constructed; this method only serializes the values.

        :param options: Optional
            :class:`~pymongo.stream_processing_options.StartStreamProcessorOptions`.
        :param session: Optional
            :class:`~pymongo.client_session.ClientSession` passed through
            to the underlying command.
        """
        command: dict[str, Any] = {_Op.START_STREAM_PROCESSOR: self.name}
        if options is not None:
            # ``workers`` is a top-level command field; every other setting
            # is nested under ``options``.
            if options.workers is not None:
                command["workers"] = options.workers

            # Wire name -> value, in the order the keys are emitted.
            wire_fields = (
                ("clearCheckpoints", options.clear_checkpoints),
                ("startAtOperationTime", options.start_at_operation_time),
                ("startAfter", options.start_after),
                ("tier", options.tier),
                ("enableAutoScaling", options.enable_auto_scaling),
                ("failover", options.failover),
            )
            nested = {key: value for key, value in wire_fields if value is not None}
            if nested:
                command["options"] = nested

        self._client._command(command, session=session)

    def stop(
        self,
        *,
        session: Optional[ClientSession] = None,
    ) -> None:
        """Stop this stream processor.

        Issues ``stopStreamProcessor`` against the ``admin`` database. The
        command is **not** retried. A stopped processor moves to the
        ``STOPPED`` state and may be started again.

        :param session: Optional
            :class:`~pymongo.client_session.ClientSession` passed through
            to the underlying command.
        """
        command = {_Op.STOP_STREAM_PROCESSOR: self.name}
        self._client._command(command, session=session)

    def drop(
        self,
        *,
        session: Optional[ClientSession] = None,
    ) -> None:
        """Permanently delete this stream processor.

        Issues ``dropStreamProcessor`` against the ``admin`` database. The
        command is **not** retried. A dropped processor cannot be
        recovered.

        :param session: Optional
            :class:`~pymongo.client_session.ClientSession` passed through
            to the underlying command.
        """
        command = {_Op.DROP_STREAM_PROCESSOR: self.name}
        self._client._command(command, session=session)

    def stats(
        self,
        options: Optional[GetStreamProcessorStatsOptions] = None,
        *,
        session: Optional[ClientSession] = None,
    ) -> Mapping[str, Any]:
        """Return runtime statistics for this stream processor.

        Issues ``getStreamProcessorStats`` against the ``admin`` database
        as a **retryable read**. The server reports an error when the
        processor is not in the ``STARTED`` state.

        The reply is handed back untouched, so fields added by newer
        servers reach the caller unchanged.

        :param options: Optional
            :class:`~pymongo.stream_processing_options.GetStreamProcessorStatsOptions`
            selecting the scale factor and verbosity.
        :param session: Optional
            :class:`~pymongo.client_session.ClientSession` passed through
            to the underlying command.
        :returns: The raw server response document.
        """
        command: dict[str, Any] = {_Op.GET_STREAM_PROCESSOR_STATS: self.name}
        if options is not None:
            wire_fields = (
                ("scale", options.scale),
                ("verbose", options.verbose),
            )
            nested = {key: value for key, value in wire_fields if value is not None}
            if nested:
                command["options"] = nested

        return self._client._command(command, retryable_read=True, session=session)

    def get_stream_processor_samples(
        self,
        options: Optional[GetStreamProcessorSamplesOptions] = None,
        *,
        session: Optional[ClientSession] = None,
    ) -> GetStreamProcessorSamplesResult:
        """Fetch a single batch of sampled documents from a running processor.

        Spec-literal entry point. ``options.cursor_id`` selects the command:
        ``None`` opens a new sample with ``startSampleStreamProcessor``;
        any other non-zero id continues it with
        ``getMoreSampleStreamProcessor``. Most callers should prefer
        :meth:`sample`, which wraps this method in an iterator.

        A returned ``cursor_id`` of ``0`` means the cursor is exhausted;
        callers MUST NOT call this method again with that cursor id.

        Commands go to the ``admin`` database and are not retried.

        :param options: Optional
            :class:`~pymongo.stream_processing_options.GetStreamProcessorSamplesOptions`.
            ``limit`` is only sent on the opening call and ``batch_size``
            only on continuation calls.
        :param session: Optional
            :class:`~pymongo.client_session.ClientSession` passed through
            to the underlying command.
        :returns: A
            :class:`~pymongo.stream_processing_options.GetStreamProcessorSamplesResult`
            holding the batch and the cursor id to use for the next call.
        :raises InvalidOperation: If ``options.cursor_id`` is ``0``; no
            command is sent in that case.
        """
        request = GetStreamProcessorSamplesOptions() if options is None else options
        cursor_id = request.cursor_id

        if cursor_id == 0:
            raise InvalidOperation("Sample cursor is exhausted; cursor_id 0 cannot be continued.")

        if cursor_id is not None:
            # Continuation: identify the open cursor, optionally size the batch.
            command: dict[str, Any] = {
                _Op.GET_MORE_SAMPLE_STREAM_PROCESSOR: self.name,
                "cursorId": cursor_id,
            }
            if request.batch_size is not None:
                command["batchSize"] = request.batch_size
            reply = self._client._command(command, session=session)
            # Documents are read from "nextBatch", falling back to "messages".
            batch = reply.get("nextBatch", reply.get("messages", []))
            return GetStreamProcessorSamplesResult(
                cursor_id=int(reply["cursorId"]),
                documents=list(batch),
            )

        # Opening call: only ``limit`` is meaningful here.
        command = {_Op.START_SAMPLE_STREAM_PROCESSOR: self.name}
        if request.limit is not None:
            command["limit"] = request.limit
        reply = self._client._command(command, session=session)
        batch = reply.get("firstBatch", [])
        return GetStreamProcessorSamplesResult(
            cursor_id=int(reply["cursorId"]),
            documents=list(batch),
        )

    def sample(
        self,
        limit: Optional[int] = None,
        batch_size: Optional[int] = None,
        *,
        session: Optional[ClientSession] = None,
    ) -> SampleCursor:
        """Open a sample cursor over this stream processor's output.

        Returns an iterator that yields sampled documents until the
        server-side cursor is exhausted. The returned cursor drives the
        two-phase ``startSampleStreamProcessor`` /
        ``getMoreSampleStreamProcessor`` protocol for the caller. Creating
        the cursor does not contact the server.

        Usage::

            for doc in processor.sample(limit=100, batch_size=10):
                print(doc)

        :param limit: Maximum number of documents to sample (sent only on
            the initial call).
        :param batch_size: Number of documents per continuation batch
            (sent only on subsequent calls).
        :param session: Optional :class:`ClientSession` propagated to all
            underlying commands.
        """
        cursor = SampleCursor(
            processor=self,
            limit=limit,
            batch_size=batch_size,
            session=session,
        )
        return cursor
|
||||
|
||||
|
||||
class SampleCursor:
    """Iterator over sampled stream processor output.

    A custom two-phase cursor used to retrieve sampled documents from a
    running stream processor. This cursor MUST NOT be wrapped or re-used
    via the standard MongoDB ``Cursor`` types because it does not use the
    standard ``getMore`` command — it uses the dedicated
    ``startSampleStreamProcessor`` / ``getMoreSampleStreamProcessor``
    commands instead.

    Obtained via :meth:`StreamProcessor.sample`. Iterate with ``for``;
    the cursor can also be used as a context manager, which closes it on
    exit.

    The cursor is exhausted when the server returns ``cursorId: 0``;
    after that, no further wire calls are issued and iteration ends once
    the already-buffered documents have been yielded.
    """

    def __init__(
        self,
        *,
        processor: StreamProcessor,
        limit: Optional[int] = None,
        batch_size: Optional[int] = None,
        session: Optional[ClientSession] = None,
    ) -> None:
        self._processor = processor
        self._limit = limit
        self._batch_size = batch_size
        self._session = session

        # Documents fetched from the server but not yet yielded.
        self._buffer: list[Mapping[str, Any]] = []
        self._cursor_id: Optional[int] = None  # None = not yet opened
        # Set once the server has returned cursorId 0.
        self._exhausted: bool = False
        # Set by close(); stops iteration regardless of server state.
        self._closed: bool = False

    @property
    def cursor_id(self) -> Optional[int]:
        """Current server-side cursor id, or ``None`` if not yet opened.

        A value of ``0`` indicates the cursor has been exhausted.
        """
        return self._cursor_id

    @property
    def alive(self) -> bool:
        """``True`` if iterating may still yield a document.

        That is the case while buffered documents remain or the server
        cursor has not reported exhaustion. Always ``False`` after
        :meth:`close`.
        """
        if self._closed:
            return False
        # Buffered documents are still deliverable even after the server
        # cursor itself has been exhausted.
        return bool(self._buffer) or not self._exhausted

    def _refill(self) -> None:
        """Fetch the next batch from the server. No-op if exhausted or closed."""
        if self._exhausted or self._closed:
            return

        if self._cursor_id is None:
            # First round trip: open the sample; only ``limit`` applies.
            opts = GetStreamProcessorSamplesOptions(
                cursor_id=None,
                limit=self._limit,
                batch_size=None,
            )
        else:
            # Continuation: only ``batch_size`` applies.
            opts = GetStreamProcessorSamplesOptions(
                cursor_id=self._cursor_id,
                limit=None,
                batch_size=self._batch_size,
            )

        result = self._processor.get_stream_processor_samples(opts, session=self._session)
        self._cursor_id = result.cursor_id
        self._buffer.extend(result.documents)

        # Spec: cursorId == 0 means exhausted. MUST NOT call getMore again.
        if result.cursor_id == 0:
            self._exhausted = True

    def __iter__(self) -> SampleCursor:
        return self

    def __next__(self) -> Mapping[str, Any]:
        # A closed cursor stops immediately, even if documents were
        # buffered before close() was called.
        if self._closed:
            raise StopIteration

        # Loop guards against an empty batch from the server with a non-zero
        # cursor id — keep pulling until we get documents or hit exhaustion.
        while not self._buffer and not self._exhausted:
            self._refill()

        if not self._buffer:
            raise StopIteration
        return self._buffer.pop(0)

    def close(self) -> None:
        """Mark the cursor closed locally and discard buffered documents.

        Note: ASP does not currently expose a way to explicitly kill a
        sample cursor server-side. ``close`` only stops local iteration;
        the server-side cursor will be cleaned up on its own timeout or
        when the processor stops.
        """
        self._closed = True
        # Nothing will be yielded after close, so release the batch.
        self._buffer.clear()

    def __enter__(self) -> SampleCursor:
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        self.close()
|
||||
720
test/asynchronous/test_stream_processing.py
Normal file
720
test/asynchronous/test_stream_processing.py
Normal file
@ -0,0 +1,720 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the Atlas Stream Processing driver module.
|
||||
|
||||
All tests run offline — no live workspace is required. The wire layer is
|
||||
stubbed out by replacing ``_command`` on the client with a lightweight spy
|
||||
that records calls and returns pre-configured responses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import sys
|
||||
import unittest
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Mapping, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
import pymongo.asynchronous.stream_processing
|
||||
from bson import Timestamp
|
||||
from pymongo import (
|
||||
AsyncSampleCursor,
|
||||
AsyncStreamProcessingClient,
|
||||
AsyncStreamProcessor,
|
||||
AsyncStreamProcessors,
|
||||
CreateStreamProcessorOptions,
|
||||
GetStreamProcessorSamplesOptions,
|
||||
GetStreamProcessorSamplesResult,
|
||||
GetStreamProcessorStatsOptions,
|
||||
StartStreamProcessorOptions,
|
||||
StreamProcessorInfo,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Spy helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _spy_client(
    responses: Optional[list[Mapping[str, Any]]] = None,
    raises: Optional[Exception] = None,
) -> tuple[AsyncStreamProcessingClient, list[dict[str, Any]]]:
    """Build a client whose ``_command`` is replaced by a recording stub.

    Returns *(client, calls)*. Every invocation of ``client._command``
    appends a ``{"cmd", "retryable_read", "session"}`` entry to *calls*.

    *responses* are handed out first-in, first-out; once they run out the
    stub answers with ``{}``. Passing no *responses* queues a single ``{}``.
    When *raises* is given, each call is still recorded but then raises that
    exception instead of returning.
    """
    recorded: list[dict[str, Any]] = []
    # Copy the caller's list so popping replies never mutates their fixture.
    pending: list[Mapping[str, Any]] = [{}] if responses is None else list(responses)

    async def _command(
        cmd: dict[str, Any],
        *,
        retryable_read: bool = False,
        session: Any = None,
    ) -> Mapping[str, Any]:
        # Snapshot the command so later mutation by the caller is not observed.
        recorded.append({"cmd": dict(cmd), "retryable_read": retryable_read, "session": session})
        if raises is not None:
            raise raises
        if pending:
            return pending.pop(0)
        return {}

    # Bypass __init__ so no URI parsing or real client construction happens.
    stub = AsyncStreamProcessingClient.__new__(AsyncStreamProcessingClient)
    stub._client = MagicMock()
    stub._command = _command
    return stub, recorded
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal info response fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_INFO_RESPONSE: dict[str, Any] = {
|
||||
"ok": 1,
|
||||
"id": "abc",
|
||||
"name": "demo",
|
||||
"state": "CREATED",
|
||||
"pipeline": [],
|
||||
"pipelineVersion": 1,
|
||||
"enableAutoScaling": False,
|
||||
"failoverEnabled": False,
|
||||
"activeRegion": "us-east-1",
|
||||
"hasStarted": False,
|
||||
}
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test class 1 — client constructor validation
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestStreamProcessingClientConfig(unittest.TestCase):
    """Argument validation and defaulting done by the AsyncStreamProcessingClient constructor."""

    _URI = "mongodb://example.com/"

    @staticmethod
    def _forwarded_kwargs(mock_client: MagicMock, uri: str, **options: Any) -> Mapping[str, Any]:
        """Construct a client against the patched ``AsyncMongoClient`` and return the kwargs it got."""
        mock_client.return_value = MagicMock()
        AsyncStreamProcessingClient(uri, **options)
        return mock_client.call_args.kwargs

    def test_rejects_srv_uri(self) -> None:
        with self.assertRaises(ConfigurationError) as caught:
            AsyncStreamProcessingClient("mongodb+srv://example.com/")
        message = str(caught.exception).lower()
        self.assertIn("mongodb+srv", message)

    def test_rejects_tls_false_kwarg(self) -> None:
        self.assertRaises(ConfigurationError, AsyncStreamProcessingClient, self._URI, tls=False)

    def test_rejects_ssl_false_kwarg(self) -> None:
        self.assertRaises(ConfigurationError, AsyncStreamProcessingClient, self._URI, ssl=False)

    def test_rejects_tls_false_in_uri(self) -> None:
        self.assertRaises(
            ConfigurationError,
            AsyncStreamProcessingClient,
            "mongodb://example.com/?tls=false",
        )

    @patch("pymongo.asynchronous.stream_processing.AsyncMongoClient")
    def test_forces_tls_true(self, mock_client: MagicMock) -> None:
        forwarded = self._forwarded_kwargs(mock_client, self._URI)
        self.assertTrue(forwarded.get("tls"))

    @patch("pymongo.asynchronous.stream_processing.AsyncMongoClient")
    def test_defaults_authsource_to_admin(self, mock_client: MagicMock) -> None:
        forwarded = self._forwarded_kwargs(mock_client, self._URI)
        self.assertEqual(forwarded.get("authSource"), "admin")

    @patch("pymongo.asynchronous.stream_processing.AsyncMongoClient")
    def test_preserves_explicit_authsource_in_kwarg(self, mock_client: MagicMock) -> None:
        forwarded = self._forwarded_kwargs(mock_client, self._URI, authSource="other")
        # The injected default must never replace a value the user supplied.
        self.assertEqual(forwarded.get("authSource"), "other")

    @patch("pymongo.asynchronous.stream_processing.AsyncMongoClient")
    def test_preserves_explicit_authsource_in_uri(self, mock_client: MagicMock) -> None:
        forwarded = self._forwarded_kwargs(mock_client, "mongodb://example.com/?authSource=other")
        # The URI already carries authSource, so no duplicate kwarg may be added.
        self.assertNotIn("authSource", forwarded)

    @patch("pymongo.asynchronous.stream_processing.AsyncMongoClient")
    def test_drops_ssl_kwarg_to_avoid_duplicate(self, mock_client: MagicMock) -> None:
        forwarded = self._forwarded_kwargs(mock_client, self._URI, ssl=True)
        self.assertNotIn("ssl", forwarded)
        self.assertTrue(forwarded.get("tls"))

    def test_workspace_endpoint_detection(self) -> None:
        from pymongo.asynchronous.stream_processing import _is_workspace_endpoint

        workspace_hosts = (
            "atlas-stream-foo.virginia-usa.a.query.mongodb.net",
            "something.a.query.mongodb.net",
        )
        for host in workspace_hosts:
            self.assertTrue(_is_workspace_endpoint(host))

        other_hosts = ("cluster0.mongodb.net", "localhost")
        for host in other_hosts:
            self.assertFalse(_is_workspace_endpoint(host))
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test class 2 — options dataclass validation
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestStreamProcessingOptions(unittest.TestCase):
    """Validation and decoding behavior of the stream_processing_options types."""

    @staticmethod
    def _minimal_info_doc(**overrides: Any) -> dict[str, Any]:
        """Return the smallest info document these tests decode, with *overrides* applied."""
        doc: dict[str, Any] = {
            "id": "x",
            "name": "p",
            "state": "CREATED",
            "pipeline": [],
            "pipelineVersion": 1,
        }
        doc.update(overrides)
        return doc

    def test_start_options_mutex_start_after_and_at_op_time(self) -> None:
        with self.assertRaises(InvalidOperation):
            StartStreamProcessorOptions(
                start_at_operation_time=Timestamp(0, 0),
                start_after={"_id": 1},
            )

    def test_start_options_invalid_tier(self) -> None:
        self.assertRaises(InvalidOperation, StartStreamProcessorOptions, tier="SP1")
        for accepted in ("SP2", "SP5", "SP10", "SP30", "SP50"):
            StartStreamProcessorOptions(tier=accepted)  # must not raise

    def test_start_options_invalid_workers(self) -> None:
        for rejected in (0, -1):
            with self.assertRaises(InvalidOperation):
                StartStreamProcessorOptions(workers=rejected)
        StartStreamProcessorOptions(workers=1)  # must not raise

    def test_stats_options_invalid_scale(self) -> None:
        for rejected in (0, -1):
            with self.assertRaises(InvalidOperation):
                GetStreamProcessorStatsOptions(scale=rejected)
        for accepted in (1, 1024):
            GetStreamProcessorStatsOptions(scale=accepted)

    def test_samples_options_invalid_cursor_id(self) -> None:
        self.assertRaises(InvalidOperation, GetStreamProcessorSamplesOptions, cursor_id=-1)
        # cursor_id=0 is valid at construction; the method-level check rejects it.
        for accepted in (0, 42):
            GetStreamProcessorSamplesOptions(cursor_id=accepted)

    def test_samples_options_invalid_limit(self) -> None:
        self.assertRaises(InvalidOperation, GetStreamProcessorSamplesOptions, limit=-1)
        for accepted in (0, 10):
            GetStreamProcessorSamplesOptions(limit=accepted)

    def test_samples_options_invalid_batch_size(self) -> None:
        self.assertRaises(InvalidOperation, GetStreamProcessorSamplesOptions, batch_size=-1)
        for accepted in (0, 5):
            GetStreamProcessorSamplesOptions(batch_size=accepted)

    def test_create_options_all_optional(self) -> None:
        defaults = CreateStreamProcessorOptions()
        self.assertIsNone(defaults.dlq)
        self.assertIsNone(defaults.stream_meta_field_name)
        self.assertIsNone(defaults.tier)
        self.assertIsNone(defaults.failover)

    def test_stream_processor_info_from_response_full(self) -> None:
        stamp = datetime.now(tz=timezone.utc)
        doc: dict[str, Any] = {
            "ok": 1,
            "id": "abc123",
            "name": "demo",
            "state": "STARTED",
            "pipeline": [{"$source": {}}],
            "pipelineVersion": 3,
            "tier": "SP10",
            "dlq": {"connectionName": "c1"},
            "streamMetaFieldName": "ts",
            "enableAutoScaling": True,
            "failoverEnabled": True,
            "activeRegion": "us-east-1",
            "lastModifiedAt": stamp,
            "modifiedBy": "user@example.com",
            "lastStateChange": stamp,
            "lastHeartbeat": stamp,
            "hasStarted": True,
            "stats": {"inputMessageCount": 42},
            "errorMsg": None,
            "errorCode": None,
            "errorRetryable": None,
        }
        decoded = StreamProcessorInfo.from_response(doc)

        # Identity and state.
        self.assertEqual(decoded.id, "abc123")
        self.assertEqual(decoded.name, "demo")
        self.assertEqual(decoded.state, "STARTED")
        # Pipeline and configuration.
        self.assertEqual(decoded.pipeline, [{"$source": {}}])
        self.assertEqual(decoded.pipeline_version, 3)
        self.assertEqual(decoded.tier, "SP10")
        self.assertEqual(decoded.dlq, {"connectionName": "c1"})
        self.assertEqual(decoded.stream_meta_field_name, "ts")
        # Flags, region and stats.
        self.assertTrue(decoded.enable_auto_scaling)
        self.assertTrue(decoded.failover_enabled)
        self.assertEqual(decoded.active_region, "us-east-1")
        self.assertTrue(decoded.has_started)
        self.assertEqual(decoded.stats, {"inputMessageCount": 42})
        # The untouched document stays reachable.
        self.assertEqual(decoded.raw, doc)

    def test_stream_processor_info_from_response_minimal(self) -> None:
        doc = self._minimal_info_doc()
        decoded = StreamProcessorInfo.from_response(doc)
        self.assertEqual(decoded.id, "x")
        self.assertIsNone(decoded.tier)
        self.assertIsNone(decoded.dlq)
        self.assertFalse(decoded.enable_auto_scaling)
        self.assertFalse(decoded.has_started)
        self.assertEqual(decoded.raw, doc)

    def test_stream_processor_info_preserves_unknown_fields(self) -> None:
        doc = self._minimal_info_doc(futureField="someValue")
        decoded = StreamProcessorInfo.from_response(doc)
        self.assertEqual(decoded.raw["futureField"], "someValue")

    def test_stream_processor_info_state_is_string(self) -> None:
        doc = self._minimal_info_doc(state="FUTURE_UNKNOWN_STATE")
        decoded = StreamProcessorInfo.from_response(doc)
        self.assertEqual(decoded.state, "FUTURE_UNKNOWN_STATE")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test class 3 — lifecycle commands
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class AsyncTestStreamProcessorsCommands(unittest.IsolatedAsyncioTestCase):
    """Wire commands issued by AsyncStreamProcessors and AsyncStreamProcessor lifecycle methods."""

    async def asyncSetUp(self) -> None:
        canned = [{"ok": 1}] * 30
        self.client, self.calls = _spy_client(responses=canned)
        self.proc = AsyncStreamProcessor(client=self.client, name="demo")
        self.sps = AsyncStreamProcessors(self.client)

    async def test_create_sends_correct_command(self) -> None:
        source_stage = [{"$source": {}}]
        await self.sps.create("demo", pipeline=source_stage)
        last_call = self.calls[-1]
        sent = last_call["cmd"]
        self.assertEqual(sent.get("createStreamProcessor"), "demo")
        self.assertEqual(sent.get("pipeline"), [{"$source": {}}])
        self.assertNotIn("options", sent)
        self.assertFalse(last_call["retryable_read"])

    async def test_create_with_options_includes_them(self) -> None:
        create_opts = CreateStreamProcessorOptions(dlq={"connectionName": "c"}, tier="SP10")
        await self.sps.create("demo", pipeline=[{"$source": {}}], options=create_opts)
        nested = self.calls[-1]["cmd"]["options"]
        self.assertEqual(nested["dlq"], {"connectionName": "c"})
        self.assertEqual(nested["tier"], "SP10")

    async def test_create_rejects_empty_name(self) -> None:
        with self.assertRaises(InvalidOperation):
            await self.sps.create("", pipeline=[{"$x": 1}])
        # Validation must fail before anything reaches the wire.
        self.assertEqual(self.calls, [])

    async def test_create_rejects_whitespace_name(self) -> None:
        with self.assertRaises(InvalidOperation):
            await self.sps.create(" ", pipeline=[{"$x": 1}])
        self.assertEqual(self.calls, [])

    async def test_create_rejects_empty_pipeline(self) -> None:
        with self.assertRaises(InvalidOperation):
            await self.sps.create("demo", pipeline=[])
        self.assertEqual(self.calls, [])

    async def test_get_returns_handle_without_calling_server(self) -> None:
        handle = self.sps.get("demo")
        self.assertEqual(self.calls, [])
        self.assertEqual(handle.name, "demo")
        self.assertIsInstance(handle, AsyncStreamProcessor)

    async def test_get_rejects_empty_name(self) -> None:
        self.assertRaises(InvalidOperation, self.sps.get, "")

    async def test_get_info_sends_correct_command_and_decodes(self) -> None:
        client, calls = _spy_client(responses=[dict(_INFO_RESPONSE)])
        processors = AsyncStreamProcessors(client)
        info = await processors.get_info("demo")
        first_call = calls[0]
        self.assertEqual(first_call["cmd"].get("getStreamProcessor"), "demo")
        self.assertTrue(first_call["retryable_read"])
        self.assertIsInstance(info, StreamProcessorInfo)
        self.assertEqual(info.id, "abc")
        self.assertEqual(info.state, "CREATED")

    async def test_get_info_preserves_unknown_response_fields(self) -> None:
        reply = dict(_INFO_RESPONSE, futureField="value")
        client, _ = _spy_client(responses=[reply])
        info = await AsyncStreamProcessors(client).get_info("demo")
        self.assertEqual(info.raw["futureField"], "value")

    async def test_start_sends_correct_command(self) -> None:
        await self.proc.start()
        last_call = self.calls[-1]
        sent = last_call["cmd"]
        self.assertEqual(sent.get("startStreamProcessor"), "demo")
        self.assertNotIn("workers", sent)
        self.assertNotIn("options", sent)
        self.assertFalse(last_call["retryable_read"])

    async def test_start_with_options_includes_them(self) -> None:
        start_opts = StartStreamProcessorOptions(workers=2, clear_checkpoints=True, tier="SP10")
        await self.proc.start(options=start_opts)
        sent = self.calls[-1]["cmd"]
        self.assertEqual(sent.get("workers"), 2)
        self.assertTrue(sent["options"]["clearCheckpoints"])
        self.assertEqual(sent["options"]["tier"], "SP10")

    async def test_stop_sends_correct_command(self) -> None:
        await self.proc.stop()
        last_call = self.calls[-1]
        self.assertEqual(last_call["cmd"].get("stopStreamProcessor"), "demo")
        self.assertFalse(last_call["retryable_read"])

    async def test_drop_sends_correct_command(self) -> None:
        await self.proc.drop()
        last_call = self.calls[-1]
        self.assertEqual(last_call["cmd"].get("dropStreamProcessor"), "demo")
        self.assertFalse(last_call["retryable_read"])

    async def test_stats_sends_correct_command_and_returns_raw_dict(self) -> None:
        reply: dict[str, Any] = {"ok": 1, "stats": {"inputMessageCount": 5}, "futureField": "x"}
        client, calls = _spy_client(responses=[reply])
        processor = AsyncStreamProcessor(client=client, name="demo")
        returned = await processor.stats()
        first_call = calls[0]
        self.assertEqual(first_call["cmd"].get("getStreamProcessorStats"), "demo")
        self.assertTrue(first_call["retryable_read"])
        self.assertEqual(returned, reply)

    async def test_stats_with_options(self) -> None:
        client, calls = _spy_client(responses=[{"ok": 1}])
        stats_opts = GetStreamProcessorStatsOptions(scale=1024, verbose=True)
        processor = AsyncStreamProcessor(client=client, name="demo")
        await processor.stats(options=stats_opts)
        nested = calls[0]["cmd"]["options"]
        self.assertEqual(nested["scale"], 1024)
        self.assertTrue(nested["verbose"])

    async def test_session_propagates(self) -> None:
        session_stub = MagicMock()
        await self.proc.start(session=session_stub)
        self.assertIs(self.calls[-1]["session"], session_stub)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test class 4 — sample cursor
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class AsyncTestSampleCursor(unittest.IsolatedAsyncioTestCase):
|
||||
"""Two-phase sample cursor protocol."""
|
||||
|
||||
async def test_get_samples_initial_call_sends_start_command(self) -> None:
    # Without options the request must be the "start" command and carry
    # none of the continuation-only fields.
    opening_reply = {"cursorId": 42, "firstBatch": [{"x": 1}]}
    client, calls = _spy_client(responses=[opening_reply])
    processor = AsyncStreamProcessor(client=client, name="demo")
    result = await processor.get_stream_processor_samples()

    first_call = calls[0]
    sent = first_call["cmd"]
    self.assertIn("startSampleStreamProcessor", sent)
    self.assertNotIn("cursorId", sent)
    self.assertNotIn("batchSize", sent)
    self.assertFalse(first_call["retryable_read"])
    self.assertEqual(result.cursor_id, 42)
    self.assertEqual(result.documents, [{"x": 1}])
|
||||
|
||||
async def test_get_samples_initial_call_includes_limit(self) -> None:
    # ``limit`` belongs to the opening command; ``batchSize`` does not.
    client, calls = _spy_client(responses=[{"cursorId": 1, "firstBatch": []}])
    processor = AsyncStreamProcessor(client=client, name="demo")
    request = GetStreamProcessorSamplesOptions(limit=100)
    await processor.get_stream_processor_samples(request)

    sent = calls[0]["cmd"]
    self.assertEqual(sent.get("limit"), 100)
    self.assertNotIn("batchSize", sent)
|
||||
|
||||
async def test_get_samples_continuation_sends_get_more(self) -> None:
|
||||
client, calls = _spy_client(responses=[{"cursorId": 0, "nextBatch": [{"y": 2}]}])
|
||||
proc = AsyncStreamProcessor(client=client, name="demo")
|
||||
result = await proc.get_stream_processor_samples(
|
||||
GetStreamProcessorSamplesOptions(cursor_id=42)
|
||||
)
|
||||
cmd = calls[0]["cmd"]
|
||||
self.assertIn("getMoreSampleStreamProcessor", cmd)
|
||||
self.assertEqual(cmd.get("cursorId"), 42)
|
||||
self.assertNotIn("limit", cmd)
|
||||
self.assertEqual(result.cursor_id, 0)
|
||||
self.assertEqual(result.documents, [{"y": 2}])
|
||||
|
||||
async def test_get_samples_continuation_includes_batch_size(self) -> None:
|
||||
client, calls = _spy_client(responses=[{"cursorId": 0, "nextBatch": []}])
|
||||
proc = AsyncStreamProcessor(client=client, name="demo")
|
||||
await proc.get_stream_processor_samples(
|
||||
GetStreamProcessorSamplesOptions(cursor_id=42, batch_size=5)
|
||||
)
|
||||
self.assertEqual(calls[0]["cmd"].get("batchSize"), 5)
|
||||
|
||||
async def test_get_samples_rejects_cursor_id_zero(self) -> None:
|
||||
client, calls = _spy_client()
|
||||
proc = AsyncStreamProcessor(client=client, name="demo")
|
||||
with self.assertRaises(InvalidOperation):
|
||||
await proc.get_stream_processor_samples(GetStreamProcessorSamplesOptions(cursor_id=0))
|
||||
self.assertEqual(len(calls), 0)
|
||||
|
||||
async def test_sample_iterator_drains_single_batch(self) -> None:
|
||||
client, calls = _spy_client(
|
||||
responses=[{"cursorId": 0, "firstBatch": [{"a": 1}, {"a": 2}, {"a": 3}]}]
|
||||
)
|
||||
proc = AsyncStreamProcessor(client=client, name="demo")
|
||||
docs = []
|
||||
async for doc in proc.sample():
|
||||
docs.append(doc)
|
||||
self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}])
|
||||
self.assertEqual(len(calls), 1)
|
||||
|
||||
async def test_sample_iterator_drains_multiple_batches(self) -> None:
|
||||
client, calls = _spy_client(
|
||||
responses=[
|
||||
{"cursorId": 11, "firstBatch": [{"a": 1}]},
|
||||
{"cursorId": 22, "nextBatch": [{"a": 2}]},
|
||||
{"cursorId": 0, "nextBatch": [{"a": 3}]},
|
||||
]
|
||||
)
|
||||
proc = AsyncStreamProcessor(client=client, name="demo")
|
||||
docs = []
|
||||
async for doc in proc.sample():
|
||||
docs.append(doc)
|
||||
self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}])
|
||||
self.assertEqual(len(calls), 3)
|
||||
|
||||
async def test_sample_iterator_handles_empty_batch_with_nonzero_cursor(self) -> None:
|
||||
client, calls = _spy_client(
|
||||
responses=[
|
||||
{"cursorId": 11, "firstBatch": []},
|
||||
{"cursorId": 22, "nextBatch": []},
|
||||
{"cursorId": 0, "nextBatch": [{"a": 1}]},
|
||||
]
|
||||
)
|
||||
proc = AsyncStreamProcessor(client=client, name="demo")
|
||||
docs = []
|
||||
async for doc in proc.sample():
|
||||
docs.append(doc)
|
||||
self.assertEqual(docs, [{"a": 1}])
|
||||
self.assertEqual(len(calls), 3)
|
||||
|
||||
async def test_sample_iterator_stops_after_cursor_id_zero(self) -> None:
|
||||
client, calls = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
|
||||
proc = AsyncStreamProcessor(client=client, name="demo")
|
||||
cursor = proc.sample()
|
||||
first = await cursor.__anext__()
|
||||
self.assertEqual(first, {"a": 1})
|
||||
with self.assertRaises(StopAsyncIteration):
|
||||
await cursor.__anext__()
|
||||
# Only one wire call — the server returned cursorId:0 so we stopped.
|
||||
self.assertEqual(len(calls), 1)
|
||||
|
||||
async def test_sample_cursor_close_stops_iteration(self) -> None:
|
||||
client, calls = _spy_client()
|
||||
proc = AsyncStreamProcessor(client=client, name="demo")
|
||||
cursor = proc.sample()
|
||||
await cursor.close()
|
||||
with self.assertRaises(StopAsyncIteration):
|
||||
await cursor.__anext__()
|
||||
self.assertEqual(len(calls), 0)
|
||||
|
||||
async def test_sample_cursor_alive_property(self) -> None:
|
||||
client, _ = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
|
||||
proc = AsyncStreamProcessor(client=client, name="demo")
|
||||
cursor = proc.sample()
|
||||
self.assertTrue(cursor.alive)
|
||||
doc = await cursor.__anext__()
|
||||
self.assertEqual(doc, {"a": 1})
|
||||
# After exhaustion (cursorId:0 returned during refill), alive is False.
|
||||
self.assertFalse(cursor.alive)
|
||||
|
||||
async def test_sample_cursor_id_property(self) -> None:
|
||||
client, _ = _spy_client(responses=[{"cursorId": 42, "firstBatch": [{"a": 1}]}])
|
||||
proc = AsyncStreamProcessor(client=client, name="demo")
|
||||
cursor = proc.sample()
|
||||
self.assertIsNone(cursor.cursor_id)
|
||||
await cursor.__anext__()
|
||||
self.assertEqual(cursor.cursor_id, 42)
|
||||
|
||||
async def test_sample_cursor_async_context_manager(self) -> None:
|
||||
client, _ = _spy_client()
|
||||
proc = AsyncStreamProcessor(client=client, name="demo")
|
||||
async with proc.sample() as cur:
|
||||
self.assertIsInstance(cur, AsyncSampleCursor)
|
||||
self.assertTrue(cur._closed)
|
||||
|
||||
async def test_sample_cursor_session_propagates_to_refill(self) -> None:
|
||||
client, calls = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
|
||||
proc = AsyncStreamProcessor(client=client, name="demo")
|
||||
fake_session = MagicMock()
|
||||
async for _ in proc.sample(session=fake_session):
|
||||
pass
|
||||
self.assertIs(calls[0]["session"], fake_session)
|
||||
|
||||
def test_sample_cursor_does_not_inherit_from_cursor_base(self) -> None:
|
||||
self.assertEqual(AsyncSampleCursor.__bases__, (object,))
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test class 5 — error handling
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class AsyncTestErrorHandling(unittest.IsolatedAsyncioTestCase):
    """Server errors MUST reach the caller as raised — no rewrapping, no filtering."""

    def _failure(self, code: int = 9) -> OperationFailure:
        """Build an ``OperationFailure`` carrying *code* in both ``code`` and ``details``."""
        details = {"errmsg": "test", "code": code}
        return OperationFailure("test error", code=code, details=details)

    async def _assert_propagates(self, fn: Any, expected_code: int) -> None:
        """Await ``fn()`` and require an ``OperationFailure`` with *expected_code*."""
        with self.assertRaises(OperationFailure) as caught:
            await fn()
        self.assertEqual(caught.exception.code, expected_code)

    async def test_create_propagates_operation_failure(self) -> None:
        spy, _ = _spy_client(raises=self._failure(9))

        def attempt() -> Any:
            return AsyncStreamProcessors(spy).create("p", pipeline=[{"$x": 1}])

        await self._assert_propagates(attempt, 9)

    async def test_start_propagates_operation_failure(self) -> None:
        spy, _ = _spy_client(raises=self._failure(72))

        def attempt() -> Any:
            return AsyncStreamProcessor(client=spy, name="p").start()

        await self._assert_propagates(attempt, 72)

    async def test_stop_propagates_operation_failure(self) -> None:
        spy, _ = _spy_client(raises=self._failure(125))

        def attempt() -> Any:
            return AsyncStreamProcessor(client=spy, name="p").stop()

        await self._assert_propagates(attempt, 125)

    async def test_drop_propagates_operation_failure(self) -> None:
        spy, _ = _spy_client(raises=self._failure(1))

        def attempt() -> Any:
            return AsyncStreamProcessor(client=spy, name="p").drop()

        await self._assert_propagates(attempt, 1)

    async def test_get_info_propagates_operation_failure(self) -> None:
        spy, _ = _spy_client(raises=self._failure(9))

        def attempt() -> Any:
            return AsyncStreamProcessors(spy).get_info("p")

        await self._assert_propagates(attempt, 9)

    async def test_stats_propagates_operation_failure(self) -> None:
        spy, _ = _spy_client(raises=self._failure(72))

        def attempt() -> Any:
            return AsyncStreamProcessor(client=spy, name="p").stats()

        await self._assert_propagates(attempt, 72)

    async def test_get_stream_processor_samples_propagates_operation_failure(self) -> None:
        spy, _ = _spy_client(raises=self._failure(9))

        def attempt() -> Any:
            return AsyncStreamProcessor(client=spy, name="p").get_stream_processor_samples()

        await self._assert_propagates(attempt, 9)

    async def test_unknown_error_code_propagates(self) -> None:
        # 999 is deliberately outside the documented codes; it must surface all the same.
        spy, _ = _spy_client(raises=self._failure(999))

        def attempt() -> Any:
            return AsyncStreamProcessors(spy).create("p", pipeline=[{"$x": 1}])

        await self._assert_propagates(attempt, 999)

    def test_no_operation_failure_caught_in_module(self) -> None:
        """Structural: the async module's source has no handler that could hide server errors."""
        forbidden_snippets = (
            "except OperationFailure",
            "except PyMongoError",
            "exc.code in",
        )
        for mod in (pymongo.asynchronous.stream_processing,):
            src = inspect.getsource(mod)
            for snippet in forbidden_snippets:
                self.assertNotIn(snippet, src, mod.__name__)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test class 6 — spec compliance
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class AsyncTestSpecCompliance(unittest.IsolatedAsyncioTestCase):
    """Spec MUST-level compliance checks."""

    async def test_retryability_classification(self) -> None:
        """Each command must route as retryable or non-retryable per the spec.

        Every case gets its own spy client and runs inside ``subTest``, so a
        misclassified command is reported without hiding mismatches in the
        cases that follow it.
        """
        cases: list[tuple[str, Any, bool]] = [
            # (label, coroutine-factory, expected_retryable_read)
            (
                "create",
                lambda c: AsyncStreamProcessors(c).create("demo", pipeline=[{"$x": 1}]),
                False,
            ),
            ("start", lambda c: AsyncStreamProcessor(client=c, name="demo").start(), False),
            ("stop", lambda c: AsyncStreamProcessor(client=c, name="demo").stop(), False),
            ("drop", lambda c: AsyncStreamProcessor(client=c, name="demo").drop(), False),
            ("get_info", lambda c: AsyncStreamProcessors(c).get_info("demo"), True),
            ("stats", lambda c: AsyncStreamProcessor(client=c, name="demo").stats(), True),
            (
                "get_samples_initial",
                lambda c: AsyncStreamProcessor(
                    client=c, name="demo"
                ).get_stream_processor_samples(),
                False,
            ),
            (
                "get_samples_continuation",
                lambda c: AsyncStreamProcessor(client=c, name="demo").get_stream_processor_samples(
                    GetStreamProcessorSamplesOptions(cursor_id=42)
                ),
                False,
            ),
        ]

        # Commands whose reply is decoded need a realistic response; the
        # remaining commands are given a bare ``{"ok": 1}``.
        responses_for: dict[str, list[Mapping[str, Any]]] = {
            "get_info": [dict(_INFO_RESPONSE)],
            "get_samples_initial": [{"cursorId": 0, "firstBatch": []}],
            "get_samples_continuation": [{"cursorId": 0, "nextBatch": []}],
        }

        for label, make_coro, expected in cases:
            with self.subTest(command=label):
                resp = responses_for.get(label, [{"ok": 1}])
                client, calls = _spy_client(responses=resp)
                await make_coro(client)
                self.assertEqual(
                    calls[-1]["retryable_read"],
                    expected,
                    f"retryability mismatch for '{label}'",
                )

    def test_async_sample_cursor_class_hierarchy(self) -> None:
        """``AsyncSampleCursor`` derives directly from ``object``."""
        self.assertEqual(AsyncSampleCursor.__bases__, (object,))
|
||||
|
||||
|
||||
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
704
test/test_stream_processing.py
Normal file
704
test/test_stream_processing.py
Normal file
@ -0,0 +1,704 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the Atlas Stream Processing driver module.
|
||||
|
||||
All tests run offline — no live workspace is required. The wire layer is
|
||||
stubbed out by replacing ``_command`` on the client with a lightweight spy
|
||||
that records calls and returns pre-configured responses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import sys
|
||||
import unittest
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Mapping, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
import pymongo.synchronous.stream_processing
|
||||
from bson import Timestamp
|
||||
from pymongo import (
|
||||
CreateStreamProcessorOptions,
|
||||
GetStreamProcessorSamplesOptions,
|
||||
GetStreamProcessorSamplesResult,
|
||||
GetStreamProcessorStatsOptions,
|
||||
SampleCursor,
|
||||
StartStreamProcessorOptions,
|
||||
StreamProcessingClient,
|
||||
StreamProcessor,
|
||||
StreamProcessorInfo,
|
||||
StreamProcessors,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Spy helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _spy_client(
    responses: Optional[list[Mapping[str, Any]]] = None,
    raises: Optional[Exception] = None,
) -> tuple[StreamProcessingClient, list[dict[str, Any]]]:
    """Build a client whose ``_command`` is a recording stub; return *(client, calls)*.

    Each invocation appends ``{"cmd", "retryable_read", "session"}`` to
    *calls* (the command is shallow-copied). Replies are taken from
    *responses* in order and ``{}`` is returned once they run out; with no
    *responses* the stub starts with a single ``{}`` reply. When *raises* is
    given, every invocation is still recorded and then raises it.
    """
    recorded: list[dict[str, Any]] = []
    # Snapshot the replies up front so the caller's list is never consumed.
    pending = iter(list(responses) if responses is not None else [{}])

    def _command(
        cmd: dict[str, Any],
        *,
        retryable_read: bool = False,
        session: Any = None,
    ) -> Mapping[str, Any]:
        entry = dict(cmd=dict(cmd), retryable_read=retryable_read, session=session)
        recorded.append(entry)
        if raises is not None:
            raise raises
        return next(pending, {})

    # Bypass __init__ so no real MongoClient is created.
    stub = StreamProcessingClient.__new__(StreamProcessingClient)
    stub._client = MagicMock()
    stub._command = _command
    return stub, recorded
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal info response fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Canned reply used by the get_info tests. Tests hand the spy a ``dict(...)``
# copy of it rather than this object itself.
_INFO_RESPONSE: dict[str, Any] = {
    "ok": 1,
    "id": "abc",
    "name": "demo",
    "state": "CREATED",
    "pipeline": [],
    "pipelineVersion": 1,
    "enableAutoScaling": False,
    "failoverEnabled": False,
    "activeRegion": "us-east-1",
    "hasStarted": False,
}
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test class 1 — client constructor validation
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestStreamProcessingClientConfig(unittest.TestCase):
    """Argument validation and defaults applied by ``StreamProcessingClient``."""

    @staticmethod
    def _forwarded_kwargs(mock_client: MagicMock, uri: str, **options: Any) -> Any:
        """Construct a client against the patched ``MongoClient``; return the kwargs it got."""
        mock_client.return_value = MagicMock()
        StreamProcessingClient(uri, **options)
        return mock_client.call_args.kwargs

    def test_rejects_srv_uri(self) -> None:
        with self.assertRaises(ConfigurationError) as caught:
            StreamProcessingClient("mongodb+srv://example.com/")
        message = str(caught.exception).lower()
        self.assertIn("mongodb+srv", message)

    def test_rejects_tls_false_kwarg(self) -> None:
        with self.assertRaises(ConfigurationError):
            StreamProcessingClient("mongodb://example.com/", tls=False)

    def test_rejects_ssl_false_kwarg(self) -> None:
        with self.assertRaises(ConfigurationError):
            StreamProcessingClient("mongodb://example.com/", ssl=False)

    def test_rejects_tls_false_in_uri(self) -> None:
        with self.assertRaises(ConfigurationError):
            StreamProcessingClient("mongodb://example.com/?tls=false")

    @patch("pymongo.synchronous.stream_processing.MongoClient")
    def test_forces_tls_true(self, mock_client: MagicMock) -> None:
        passed = self._forwarded_kwargs(mock_client, "mongodb://example.com/")
        self.assertTrue(passed.get("tls"))

    @patch("pymongo.synchronous.stream_processing.MongoClient")
    def test_defaults_authsource_to_admin(self, mock_client: MagicMock) -> None:
        passed = self._forwarded_kwargs(mock_client, "mongodb://example.com/")
        self.assertEqual(passed.get("authSource"), "admin")

    @patch("pymongo.synchronous.stream_processing.MongoClient")
    def test_preserves_explicit_authsource_in_kwarg(self, mock_client: MagicMock) -> None:
        passed = self._forwarded_kwargs(
            mock_client, "mongodb://example.com/", authSource="other"
        )
        # A caller-supplied authSource wins over the injected default.
        self.assertEqual(passed.get("authSource"), "other")

    @patch("pymongo.synchronous.stream_processing.MongoClient")
    def test_preserves_explicit_authsource_in_uri(self, mock_client: MagicMock) -> None:
        passed = self._forwarded_kwargs(mock_client, "mongodb://example.com/?authSource=other")
        # The URI already names an authSource, so no kwarg may be added on top.
        self.assertNotIn("authSource", passed)

    @patch("pymongo.synchronous.stream_processing.MongoClient")
    def test_drops_ssl_kwarg_to_avoid_duplicate(self, mock_client: MagicMock) -> None:
        passed = self._forwarded_kwargs(mock_client, "mongodb://example.com/", ssl=True)
        self.assertNotIn("ssl", passed)
        self.assertTrue(passed.get("tls"))

    def test_workspace_endpoint_detection(self) -> None:
        from pymongo.synchronous.stream_processing import _is_workspace_endpoint

        workspace_hosts = (
            "atlas-stream-foo.virginia-usa.a.query.mongodb.net",
            "something.a.query.mongodb.net",
        )
        other_hosts = ("cluster0.mongodb.net", "localhost")
        for host in workspace_hosts:
            self.assertTrue(_is_workspace_endpoint(host))
        for host in other_hosts:
            self.assertFalse(_is_workspace_endpoint(host))
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test class 2 — options dataclass validation
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestStreamProcessingOptions(unittest.TestCase):
    """Validation performed by the stream-processing option and info classes."""

    @staticmethod
    def _minimal_doc(**overrides: Any) -> dict[str, Any]:
        """Return the smallest info document used here, with *overrides* applied."""
        doc: dict[str, Any] = {
            "id": "x",
            "name": "p",
            "state": "CREATED",
            "pipeline": [],
            "pipelineVersion": 1,
        }
        doc.update(overrides)
        return doc

    def test_start_options_mutex_start_after_and_at_op_time(self) -> None:
        conflicting = {
            "start_after": {"_id": 1},
            "start_at_operation_time": Timestamp(0, 0),
        }
        with self.assertRaises(InvalidOperation):
            StartStreamProcessorOptions(**conflicting)

    def test_start_options_invalid_tier(self) -> None:
        with self.assertRaises(InvalidOperation):
            StartStreamProcessorOptions(tier="SP1")
        # Each accepted tier must construct without raising.
        for accepted in ("SP2", "SP5", "SP10", "SP30", "SP50"):
            StartStreamProcessorOptions(tier=accepted)

    def test_start_options_invalid_workers(self) -> None:
        for rejected in (0, -1):
            with self.assertRaises(InvalidOperation):
                StartStreamProcessorOptions(workers=rejected)
        StartStreamProcessorOptions(workers=1)  # must not raise

    def test_stats_options_invalid_scale(self) -> None:
        for rejected in (0, -1):
            with self.assertRaises(InvalidOperation):
                GetStreamProcessorStatsOptions(scale=rejected)
        for accepted in (1, 1024):
            GetStreamProcessorStatsOptions(scale=accepted)

    def test_samples_options_invalid_cursor_id(self) -> None:
        with self.assertRaises(InvalidOperation):
            GetStreamProcessorSamplesOptions(cursor_id=-1)
        # Zero passes construction; it is the sampling method that refuses it.
        for accepted in (0, 42):
            GetStreamProcessorSamplesOptions(cursor_id=accepted)

    def test_samples_options_invalid_limit(self) -> None:
        with self.assertRaises(InvalidOperation):
            GetStreamProcessorSamplesOptions(limit=-1)
        for accepted in (0, 10):
            GetStreamProcessorSamplesOptions(limit=accepted)

    def test_samples_options_invalid_batch_size(self) -> None:
        with self.assertRaises(InvalidOperation):
            GetStreamProcessorSamplesOptions(batch_size=-1)
        for accepted in (0, 5):
            GetStreamProcessorSamplesOptions(batch_size=accepted)

    def test_create_options_all_optional(self) -> None:
        defaults = CreateStreamProcessorOptions()
        self.assertIsNone(defaults.dlq)
        self.assertIsNone(defaults.stream_meta_field_name)
        self.assertIsNone(defaults.tier)
        self.assertIsNone(defaults.failover)

    def test_stream_processor_info_from_response_full(self) -> None:
        """A fully populated response is decoded onto the expected attributes."""
        stamp = datetime.now(tz=timezone.utc)
        response: dict[str, Any] = {
            "ok": 1,
            "id": "abc123",
            "name": "demo",
            "state": "STARTED",
            "pipeline": [{"$source": {}}],
            "pipelineVersion": 3,
            "tier": "SP10",
            "dlq": {"connectionName": "c1"},
            "streamMetaFieldName": "ts",
            "enableAutoScaling": True,
            "failoverEnabled": True,
            "activeRegion": "us-east-1",
            "lastModifiedAt": stamp,
            "modifiedBy": "user@example.com",
            "lastStateChange": stamp,
            "lastHeartbeat": stamp,
            "hasStarted": True,
            "stats": {"inputMessageCount": 42},
            "errorMsg": None,
            "errorCode": None,
            "errorRetryable": None,
        }
        decoded = StreamProcessorInfo.from_response(response)
        self.assertEqual(decoded.id, "abc123")
        self.assertEqual(decoded.name, "demo")
        self.assertEqual(decoded.state, "STARTED")
        self.assertEqual(decoded.pipeline, [{"$source": {}}])
        self.assertEqual(decoded.pipeline_version, 3)
        self.assertEqual(decoded.tier, "SP10")
        self.assertEqual(decoded.dlq, {"connectionName": "c1"})
        self.assertEqual(decoded.stream_meta_field_name, "ts")
        self.assertTrue(decoded.enable_auto_scaling)
        self.assertTrue(decoded.failover_enabled)
        self.assertEqual(decoded.active_region, "us-east-1")
        self.assertTrue(decoded.has_started)
        self.assertEqual(decoded.stats, {"inputMessageCount": 42})
        self.assertEqual(decoded.raw, response)

    def test_stream_processor_info_from_response_minimal(self) -> None:
        """Fields absent from the response decode to ``None`` or a falsy value."""
        response = self._minimal_doc()
        decoded = StreamProcessorInfo.from_response(response)
        self.assertEqual(decoded.id, "x")
        self.assertIsNone(decoded.tier)
        self.assertIsNone(decoded.dlq)
        self.assertFalse(decoded.enable_auto_scaling)
        self.assertFalse(decoded.has_started)
        self.assertEqual(decoded.raw, response)

    def test_stream_processor_info_preserves_unknown_fields(self) -> None:
        """Fields the decoder does not know remain reachable through ``raw``."""
        response = self._minimal_doc(futureField="someValue")
        decoded = StreamProcessorInfo.from_response(response)
        self.assertEqual(decoded.raw["futureField"], "someValue")

    def test_stream_processor_info_state_is_string(self) -> None:
        """An unrecognised state value is passed through as the same string."""
        response = self._minimal_doc(state="FUTURE_UNKNOWN_STATE")
        decoded = StreamProcessorInfo.from_response(response)
        self.assertEqual(decoded.state, "FUTURE_UNKNOWN_STATE")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test class 3 — lifecycle commands
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestStreamProcessorsCommands(unittest.TestCase):
    """Command documents sent by ``StreamProcessors`` and ``StreamProcessor``."""

    def setUp(self) -> None:
        # Thirty canned replies are queued for the shared spy of each test.
        spy, recorded = _spy_client(responses=[{"ok": 1}] * 30)
        self.client = spy
        self.calls = recorded
        self.sps = StreamProcessors(self.client)
        self.proc = StreamProcessor(client=self.client, name="demo")

    def test_create_sends_correct_command(self) -> None:
        self.sps.create("demo", pipeline=[{"$source": {}}])
        latest = self.calls[-1]
        sent = latest["cmd"]
        self.assertEqual(sent.get("createStreamProcessor"), "demo")
        self.assertEqual(sent.get("pipeline"), [{"$source": {}}])
        self.assertNotIn("options", sent)
        self.assertFalse(latest["retryable_read"])

    def test_create_with_options_includes_them(self) -> None:
        create_options = CreateStreamProcessorOptions(dlq={"connectionName": "c"}, tier="SP10")
        self.sps.create("demo", pipeline=[{"$source": {}}], options=create_options)
        sent_options = self.calls[-1]["cmd"]["options"]
        self.assertEqual(sent_options["dlq"], {"connectionName": "c"})
        self.assertEqual(sent_options["tier"], "SP10")

    def test_create_rejects_empty_name(self) -> None:
        with self.assertRaises(InvalidOperation):
            self.sps.create("", pipeline=[{"$x": 1}])
        self.assertEqual(self.calls, [])

    def test_create_rejects_whitespace_name(self) -> None:
        with self.assertRaises(InvalidOperation):
            self.sps.create(" ", pipeline=[{"$x": 1}])
        self.assertEqual(self.calls, [])

    def test_create_rejects_empty_pipeline(self) -> None:
        with self.assertRaises(InvalidOperation):
            self.sps.create("demo", pipeline=[])
        self.assertEqual(self.calls, [])

    def test_get_returns_handle_without_calling_server(self) -> None:
        handle = self.sps.get("demo")
        self.assertEqual(self.calls, [])
        self.assertEqual(handle.name, "demo")
        self.assertIsInstance(handle, StreamProcessor)

    def test_get_rejects_empty_name(self) -> None:
        with self.assertRaises(InvalidOperation):
            self.sps.get("")

    def test_get_info_sends_correct_command_and_decodes(self) -> None:
        spy, recorded = _spy_client(responses=[dict(_INFO_RESPONSE)])
        decoded = StreamProcessors(spy).get_info("demo")
        first_call = recorded[0]
        self.assertEqual(first_call["cmd"].get("getStreamProcessor"), "demo")
        self.assertTrue(first_call["retryable_read"])
        self.assertIsInstance(decoded, StreamProcessorInfo)
        self.assertEqual(decoded.id, "abc")
        self.assertEqual(decoded.state, "CREATED")

    def test_get_info_preserves_unknown_response_fields(self) -> None:
        reply = dict(_INFO_RESPONSE, futureField="value")
        spy, _ = _spy_client(responses=[reply])
        decoded = StreamProcessors(spy).get_info("demo")
        self.assertEqual(decoded.raw["futureField"], "value")

    def test_start_sends_correct_command(self) -> None:
        self.proc.start()
        latest = self.calls[-1]
        sent = latest["cmd"]
        self.assertEqual(sent.get("startStreamProcessor"), "demo")
        self.assertNotIn("workers", sent)
        self.assertNotIn("options", sent)
        self.assertFalse(latest["retryable_read"])

    def test_start_with_options_includes_them(self) -> None:
        start_options = StartStreamProcessorOptions(
            workers=2, clear_checkpoints=True, tier="SP10"
        )
        self.proc.start(options=start_options)
        sent = self.calls[-1]["cmd"]
        # ``workers`` is checked at the top level, the rest under ``options``.
        self.assertEqual(sent.get("workers"), 2)
        self.assertTrue(sent["options"]["clearCheckpoints"])
        self.assertEqual(sent["options"]["tier"], "SP10")

    def test_stop_sends_correct_command(self) -> None:
        self.proc.stop()
        latest = self.calls[-1]
        self.assertEqual(latest["cmd"].get("stopStreamProcessor"), "demo")
        self.assertFalse(latest["retryable_read"])

    def test_drop_sends_correct_command(self) -> None:
        self.proc.drop()
        latest = self.calls[-1]
        self.assertEqual(latest["cmd"].get("dropStreamProcessor"), "demo")
        self.assertFalse(latest["retryable_read"])

    def test_stats_sends_correct_command_and_returns_raw_dict(self) -> None:
        reply: dict[str, Any] = {
            "ok": 1,
            "stats": {"inputMessageCount": 5},
            "futureField": "x",
        }
        spy, recorded = _spy_client(responses=[reply])
        returned = StreamProcessor(client=spy, name="demo").stats()
        first_call = recorded[0]
        self.assertEqual(first_call["cmd"].get("getStreamProcessorStats"), "demo")
        self.assertTrue(first_call["retryable_read"])
        self.assertEqual(returned, reply)

    def test_stats_with_options(self) -> None:
        spy, recorded = _spy_client(responses=[{"ok": 1}])
        stats_options = GetStreamProcessorStatsOptions(scale=1024, verbose=True)
        StreamProcessor(client=spy, name="demo").stats(options=stats_options)
        sent_options = recorded[0]["cmd"]["options"]
        self.assertEqual(sent_options["scale"], 1024)
        self.assertTrue(sent_options["verbose"])

    def test_session_propagates(self) -> None:
        session = MagicMock()
        self.proc.start(session=session)
        self.assertIs(self.calls[-1]["session"], session)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test class 4 — sample cursor
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSampleCursor(unittest.TestCase):
|
||||
"""Two-phase sample cursor protocol."""
|
||||
|
||||
def test_get_samples_initial_call_sends_start_command(self) -> None:
    """With no cursor id, a non-retryable start-sample command is sent."""
    spy, recorded = _spy_client(responses=[{"cursorId": 42, "firstBatch": [{"x": 1}]}])
    processor = StreamProcessor(client=spy, name="demo")
    outcome = processor.get_stream_processor_samples()
    first_call = recorded[0]
    sent = first_call["cmd"]
    self.assertIn("startSampleStreamProcessor", sent)
    self.assertNotIn("cursorId", sent)
    self.assertNotIn("batchSize", sent)
    self.assertFalse(first_call["retryable_read"])
    self.assertEqual(outcome.cursor_id, 42)
    self.assertEqual(outcome.documents, [{"x": 1}])
|
||||
|
||||
def test_get_samples_initial_call_includes_limit(self) -> None:
    """``limit`` is sent on the initial call and ``batchSize`` is not."""
    spy, recorded = _spy_client(responses=[{"cursorId": 1, "firstBatch": []}])
    processor = StreamProcessor(client=spy, name="demo")
    limited = GetStreamProcessorSamplesOptions(limit=100)
    processor.get_stream_processor_samples(limited)
    sent = recorded[0]["cmd"]
    self.assertEqual(sent.get("limit"), 100)
    self.assertNotIn("batchSize", sent)
|
||||
|
||||
def test_get_samples_continuation_sends_get_more(self) -> None:
    """A non-zero cursor id selects the get-more command and omits ``limit``."""
    spy, recorded = _spy_client(responses=[{"cursorId": 0, "nextBatch": [{"y": 2}]}])
    processor = StreamProcessor(client=spy, name="demo")
    resume = GetStreamProcessorSamplesOptions(cursor_id=42)
    outcome = processor.get_stream_processor_samples(resume)
    sent = recorded[0]["cmd"]
    self.assertIn("getMoreSampleStreamProcessor", sent)
    self.assertEqual(sent.get("cursorId"), 42)
    self.assertNotIn("limit", sent)
    self.assertEqual(outcome.cursor_id, 0)
    self.assertEqual(outcome.documents, [{"y": 2}])
|
||||
|
||||
def test_get_samples_continuation_includes_batch_size(self) -> None:
|
||||
client, calls = _spy_client(responses=[{"cursorId": 0, "nextBatch": []}])
|
||||
proc = StreamProcessor(client=client, name="demo")
|
||||
proc.get_stream_processor_samples(
|
||||
GetStreamProcessorSamplesOptions(cursor_id=42, batch_size=5)
|
||||
)
|
||||
self.assertEqual(calls[0]["cmd"].get("batchSize"), 5)
|
||||
|
||||
def test_get_samples_rejects_cursor_id_zero(self) -> None:
|
||||
client, calls = _spy_client()
|
||||
proc = StreamProcessor(client=client, name="demo")
|
||||
with self.assertRaises(InvalidOperation):
|
||||
proc.get_stream_processor_samples(GetStreamProcessorSamplesOptions(cursor_id=0))
|
||||
self.assertEqual(len(calls), 0)
|
||||
|
||||
def test_sample_iterator_drains_single_batch(self) -> None:
|
||||
client, calls = _spy_client(
|
||||
responses=[{"cursorId": 0, "firstBatch": [{"a": 1}, {"a": 2}, {"a": 3}]}]
|
||||
)
|
||||
proc = StreamProcessor(client=client, name="demo")
|
||||
docs = []
|
||||
for doc in proc.sample():
|
||||
docs.append(doc)
|
||||
self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}])
|
||||
self.assertEqual(len(calls), 1)
|
||||
|
||||
def test_sample_iterator_drains_multiple_batches(self) -> None:
|
||||
client, calls = _spy_client(
|
||||
responses=[
|
||||
{"cursorId": 11, "firstBatch": [{"a": 1}]},
|
||||
{"cursorId": 22, "nextBatch": [{"a": 2}]},
|
||||
{"cursorId": 0, "nextBatch": [{"a": 3}]},
|
||||
]
|
||||
)
|
||||
proc = StreamProcessor(client=client, name="demo")
|
||||
docs = []
|
||||
for doc in proc.sample():
|
||||
docs.append(doc)
|
||||
self.assertEqual(docs, [{"a": 1}, {"a": 2}, {"a": 3}])
|
||||
self.assertEqual(len(calls), 3)
|
||||
|
||||
def test_sample_iterator_handles_empty_batch_with_nonzero_cursor(self) -> None:
|
||||
client, calls = _spy_client(
|
||||
responses=[
|
||||
{"cursorId": 11, "firstBatch": []},
|
||||
{"cursorId": 22, "nextBatch": []},
|
||||
{"cursorId": 0, "nextBatch": [{"a": 1}]},
|
||||
]
|
||||
)
|
||||
proc = StreamProcessor(client=client, name="demo")
|
||||
docs = []
|
||||
for doc in proc.sample():
|
||||
docs.append(doc)
|
||||
self.assertEqual(docs, [{"a": 1}])
|
||||
self.assertEqual(len(calls), 3)
|
||||
|
||||
def test_sample_iterator_stops_after_cursor_id_zero(self) -> None:
|
||||
client, calls = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
|
||||
proc = StreamProcessor(client=client, name="demo")
|
||||
cursor = proc.sample()
|
||||
first = cursor.__next__()
|
||||
self.assertEqual(first, {"a": 1})
|
||||
with self.assertRaises(StopIteration):
|
||||
cursor.__next__()
|
||||
# Only one wire call — the server returned cursorId:0 so we stopped.
|
||||
self.assertEqual(len(calls), 1)
|
||||
|
||||
def test_sample_cursor_close_stops_iteration(self) -> None:
|
||||
client, calls = _spy_client()
|
||||
proc = StreamProcessor(client=client, name="demo")
|
||||
cursor = proc.sample()
|
||||
cursor.close()
|
||||
with self.assertRaises(StopIteration):
|
||||
cursor.__next__()
|
||||
self.assertEqual(len(calls), 0)
|
||||
|
||||
def test_sample_cursor_alive_property(self) -> None:
|
||||
client, _ = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
|
||||
proc = StreamProcessor(client=client, name="demo")
|
||||
cursor = proc.sample()
|
||||
self.assertTrue(cursor.alive)
|
||||
doc = cursor.__next__()
|
||||
self.assertEqual(doc, {"a": 1})
|
||||
# After exhaustion (cursorId:0 returned during refill), alive is False.
|
||||
self.assertFalse(cursor.alive)
|
||||
|
||||
def test_sample_cursor_id_property(self) -> None:
|
||||
client, _ = _spy_client(responses=[{"cursorId": 42, "firstBatch": [{"a": 1}]}])
|
||||
proc = StreamProcessor(client=client, name="demo")
|
||||
cursor = proc.sample()
|
||||
self.assertIsNone(cursor.cursor_id)
|
||||
cursor.__next__()
|
||||
self.assertEqual(cursor.cursor_id, 42)
|
||||
|
||||
def test_sample_cursor_async_context_manager(self) -> None:
|
||||
client, _ = _spy_client()
|
||||
proc = StreamProcessor(client=client, name="demo")
|
||||
with proc.sample() as cur:
|
||||
self.assertIsInstance(cur, SampleCursor)
|
||||
self.assertTrue(cur._closed)
|
||||
|
||||
def test_sample_cursor_session_propagates_to_refill(self) -> None:
|
||||
client, calls = _spy_client(responses=[{"cursorId": 0, "firstBatch": [{"a": 1}]}])
|
||||
proc = StreamProcessor(client=client, name="demo")
|
||||
fake_session = MagicMock()
|
||||
for _ in proc.sample(session=fake_session):
|
||||
pass
|
||||
self.assertIs(calls[0]["session"], fake_session)
|
||||
|
||||
def test_sample_cursor_does_not_inherit_from_cursor_base(self) -> None:
|
||||
self.assertEqual(SampleCursor.__bases__, (object,))
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test class 5 — error handling
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestErrorHandling(unittest.TestCase):
    """An OperationFailure raised by the client must surface with its code intact."""

    def _failure(self, code: int = 9) -> OperationFailure:
        # Build a server-style failure carrying ``code`` in both the exception
        # and its details document.
        details = {"errmsg": "test", "code": code}
        return OperationFailure("test error", code=code, details=details)

    def _assert_propagates(self, fn: Any, expected_code: int) -> None:
        # Call ``fn`` and require an OperationFailure with exactly ``expected_code``.
        with self.assertRaises(OperationFailure) as caught:
            fn()
        self.assertEqual(cm_code := caught.exception.code, expected_code) if False else self.assertEqual(
            caught.exception.code, expected_code
        )

    def test_create_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(9))

        def attempt() -> Any:
            return StreamProcessors(client).create("p", pipeline=[{"$x": 1}])

        self._assert_propagates(attempt, 9)

    def test_start_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(72))

        def attempt() -> Any:
            return StreamProcessor(client=client, name="p").start()

        self._assert_propagates(attempt, 72)

    def test_stop_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(125))

        def attempt() -> Any:
            return StreamProcessor(client=client, name="p").stop()

        self._assert_propagates(attempt, 125)

    def test_drop_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(1))

        def attempt() -> Any:
            return StreamProcessor(client=client, name="p").drop()

        self._assert_propagates(attempt, 1)

    def test_get_info_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(9))

        def attempt() -> Any:
            return StreamProcessors(client).get_info("p")

        self._assert_propagates(attempt, 9)

    def test_stats_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(72))

        def attempt() -> Any:
            return StreamProcessor(client=client, name="p").stats()

        self._assert_propagates(attempt, 72)

    def test_get_stream_processor_samples_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(9))

        def attempt() -> Any:
            return StreamProcessor(client=client, name="p").get_stream_processor_samples()

        self._assert_propagates(attempt, 9)

    def test_unknown_error_code_propagates(self) -> None:
        # Code 999 is not in the documented list; it must still propagate.
        client, _ = _spy_client(raises=self._failure(999))

        def attempt() -> Any:
            return StreamProcessors(client).create("p", pipeline=[{"$x": 1}])

        self._assert_propagates(attempt, 999)

    def test_no_operation_failure_caught_in_module(self) -> None:
        """Structural: the module source must not catch or filter server errors."""
        forbidden_snippets = ("except OperationFailure", "except PyMongoError", "exc.code in")
        for mod in (pymongo.synchronous.stream_processing,):
            src = inspect.getsource(mod)
            for snippet in forbidden_snippets:
                self.assertNotIn(snippet, src, mod.__name__)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Test class 6 — spec compliance
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSpecCompliance(unittest.TestCase):
    """Spec MUST-level compliance checks."""

    def test_retryability_classification(self) -> None:
        """Each command must route as retryable or non-retryable per the spec.

        Every case runs inside its own ``subTest`` so that a mismatch for one
        command is reported without hiding the results for the others.
        """
        cases: list[tuple[str, Any, bool]] = [
            # (label, factory that issues the command on a spy client,
            #  expected value of the recorded ``retryable_read`` flag)
            ("create", lambda c: StreamProcessors(c).create("demo", pipeline=[{"$x": 1}]), False),
            ("start", lambda c: StreamProcessor(client=c, name="demo").start(), False),
            ("stop", lambda c: StreamProcessor(client=c, name="demo").stop(), False),
            ("drop", lambda c: StreamProcessor(client=c, name="demo").drop(), False),
            ("get_info", lambda c: StreamProcessors(c).get_info("demo"), True),
            ("stats", lambda c: StreamProcessor(client=c, name="demo").stats(), True),
            (
                "get_samples_initial",
                lambda c: StreamProcessor(client=c, name="demo").get_stream_processor_samples(),
                False,
            ),
            (
                "get_samples_continuation",
                lambda c: StreamProcessor(client=c, name="demo").get_stream_processor_samples(
                    GetStreamProcessorSamplesOptions(cursor_id=42)
                ),
                False,
            ),
        ]

        # Canned replies for the commands whose result is parsed; every other
        # command gets a bare {"ok": 1}.
        responses_for: dict[str, list[Mapping[str, Any]]] = {
            "get_info": [dict(_INFO_RESPONSE)],
            "get_samples_initial": [{"cursorId": 0, "firstBatch": []}],
            "get_samples_continuation": [{"cursorId": 0, "nextBatch": []}],
        }

        for label, issue_command, expected in cases:
            with self.subTest(command=label):
                resp = responses_for.get(label, [{"ok": 1}])
                # A fresh spy per case keeps the recorded calls independent.
                client, calls = _spy_client(responses=resp)
                issue_command(client)
                self.assertEqual(
                    calls[-1]["retryable_read"],
                    expected,
                    f"retryability mismatch for '{label}'",
                )

    def test_async_sample_cursor_class_hierarchy(self) -> None:
        # SampleCursor derives directly from object.
        self.assertEqual(SampleCursor.__bases__, (object,))
|
||||
|
||||
|
||||
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
@ -32,6 +32,7 @@ replacements = {
|
||||
"AsyncDatabase": "Database",
|
||||
"_AsyncCursorBase": "_CursorBase",
|
||||
"AsyncCursor": "Cursor",
|
||||
"pymongo.asynchronous.stream_processing.AsyncMongoClient": "pymongo.synchronous.stream_processing.MongoClient",
|
||||
"AsyncMongoClient": "MongoClient",
|
||||
"AsyncCommandCursor": "CommandCursor",
|
||||
"AsyncRawBatchCursor": "RawBatchCursor",
|
||||
@ -67,6 +68,10 @@ replacements = {
|
||||
"_AsyncGridOutChunkIterator": "GridOutChunkIterator",
|
||||
"_a_grid_in_property": "_grid_in_property",
|
||||
"_a_grid_out_property": "_grid_out_property",
|
||||
"AsyncStreamProcessingClient": "StreamProcessingClient",
|
||||
"AsyncStreamProcessors": "StreamProcessors",
|
||||
"AsyncStreamProcessor": "StreamProcessor",
|
||||
"AsyncSampleCursor": "SampleCursor",
|
||||
"AsyncClientEncryption": "ClientEncryption",
|
||||
"AsyncMongoCryptCallback": "MongoCryptCallback",
|
||||
"AsyncExplicitEncrypter": "ExplicitEncrypter",
|
||||
@ -128,6 +133,12 @@ replacements = {
|
||||
"SpecRunnerTask": "SpecRunnerThread",
|
||||
"AsyncMockConnection": "MockConnection",
|
||||
"AsyncMockPool": "MockPool",
|
||||
"AsyncMockMonitor": "SyncMockMonitor",
|
||||
"AsyncMock": "MagicMock",
|
||||
"AsyncTestStreamProcessorsCommands": "TestStreamProcessorsCommands",
|
||||
"AsyncTestSampleCursor": "TestSampleCursor",
|
||||
"AsyncTestErrorHandling": "TestErrorHandling",
|
||||
"AsyncTestSpecCompliance": "TestSpecCompliance",
|
||||
"StopAsyncIteration": "StopIteration",
|
||||
"create_async_event": "create_event",
|
||||
"async_create_barrier": "create_barrier",
|
||||
@ -272,6 +283,7 @@ converted_tests = [
|
||||
"test_srv_polling.py",
|
||||
"test_ssl.py",
|
||||
"test_streaming_protocol.py",
|
||||
"test_stream_processing.py",
|
||||
"test_transactions.py",
|
||||
"test_transactions_unified.py",
|
||||
"test_unified_format.py",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user